Welcome to your interview! This project provides a lightweight, self-contained framework for training language models using a simplified version of GRPO. The goal is to give you a hands-on opportunity to work with modern RLHF techniques, analyse model behaviour, and even design your own reward systems.
The codebase is written in PyTorch and intentionally avoids heavy dependencies like Hugging Face's TRL to ensure the core logic is transparent and accessible. You will be working directly with the training loop, loss implementation, and evaluation pipeline.
The interview is structured into a series of tasks that progress from core implementation to open-ended analysis and creative extension. There isn't much code left to write to get something minimal working; the main emphasis is on being creative and doing some higher-level analysis.
Good luck!
To get started, follow these steps to set up your environment.
- Create and activate a Python virtual environment:

  ```bash
  python -m venv venv
  source venv/bin/activate
  ```

- Install the required packages:

  ```bash
  pip install -r requirements.txt
  ```

- Flash Attention (Optional but Recommended): For significantly better performance (speed and memory), this repository is configured to use Flash Attention 2. Depending on your CUDA version and system setup, you may need to install it separately:

  ```bash
  pip install flash-attn --no-build-isolation
  ```

  If you encounter issues, you can comment out the `attn_implementation` line in `llms.py`, but ideally get it to work because you'll be generating lots of rollouts.
- Cloud Compute Setup (Vast.ai): Charlie will get you set up with an account on Vast.ai for GPU resources.
  - For training the smaller `Qwen2-0.5B` model, renting a single H100 GPU is sufficient and cost-effective.
  - For the scaling laws analysis in Task 3 involving larger Qwen models, you will likely need a node with 2-4 x H100 GPUs. Part of the task is managing the computational resources efficiently.
The project is organized into several key files:
- `main.py`: The main entry point for training and evaluation. It contains the primary training loop, calls the evaluator, and is where you will implement the loss function.
- `llms.py`: A simple utility for loading language models and tokenizers from Hugging Face.
- `evaluator.py`: Defines the reward evaluation logic. It includes the `RewardEvaluator` abstract base class and a concrete `GSM8kEvaluator` for the math reasoning task.
- `rl_datasets.py`: Manages data loading. Currently configured for the GSM8K dataset.
- `plotter.py`: A script to generate plots from training and evaluation logs, helping you visualize results.
- `utils.py`: Contains miscellaneous helper functions for seeding, logging, and calculating log probabilities.
Please complete the tasks in the order presented.
Your first task is to implement the core GRPO loss function. Navigate to the `compute_loss` function in `main.py`. It is currently empty.
Objective: Implement a per-token, KL-regularized policy gradient loss. The objective is to maximize rewards while penalizing KL divergence from the reference policy on a per-token basis.
Key Inputs:
- `model`: The policy model being trained.
- `base_model`: The reference model (a frozen copy of the initial model).
- `prompt_completion_ids`: The full token sequences.
- `completion_ids`: The tokens of the generated completions.
- `attention_mask`: Attention mask for the full sequence.
- `completion_mask`: A mask to ignore padding tokens in the completions.
- `advantages`: The calculated advantages for each completion, which represent the normalized rewards.
Implementation Steps:
1. Calculate Per-Token Log Probabilities:
   - Use the `utils.get_per_token_logps` function to get the log probabilities of the `completion_ids` under the `model`.
   - In `torch.inference_mode()`, do the same for the `base_model` to get the reference log probabilities (`ref_per_token_logps`).

2. Calculate Per-Token KL Divergence:
   - Compute the forward KL divergence between the reference and policy distributions for each token. The formula is:

     `kl = exp(ref_logps - policy_logps) - (ref_logps - policy_logps) - 1`

3. Calculate the Per-Token Policy Objective:
   - The core of the policy gradient update uses an importance sampling ratio.
   - Multiply this ratio by the `advantages` to scale the updates based on reward.

4. Combine into Final Per-Token Loss:
   - Combine the policy objective with the KL penalty.

5. Aggregate the Loss:
   - Apply the `completion_mask` to the `per_token_loss` to exclude padding tokens from the loss calculation.
   - Sum the masked losses along the sequence dimension and normalize by the number of completion tokens for each sequence.
   - Finally, take the mean across the batch to get the final scalar `loss`.

6. Return Metrics:
   - Return the `loss`.
   - Also compute and return a dictionary of metrics, including the mean KL divergence (`kl`) and the average `response_length`.
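To make these steps concrete, below is a minimal sketch of one way `compute_loss` could be written. The exact signature of `utils.get_per_token_logps` and the KL coefficient `beta` (an assumed keyword argument here) may differ from the actual code, so treat this as a shape to adapt rather than a drop-in implementation.

```python
import torch

import utils  # project-local helpers; get_per_token_logps signature is assumed


def compute_loss(model, base_model, prompt_completion_ids, completion_ids,
                 attention_mask, completion_mask, advantages, beta=0.04):
    # Per-token log probs of the completions under the policy (keeps gradients).
    per_token_logps = utils.get_per_token_logps(
        model, prompt_completion_ids, attention_mask, completion_ids)

    # Same quantity under the frozen reference model, with no autograd graph.
    with torch.inference_mode():
        ref_per_token_logps = utils.get_per_token_logps(
            base_model, prompt_completion_ids, attention_mask, completion_ids)

    # Per-token KL estimate: exp(ref - pi) - (ref - pi) - 1, always >= 0.
    delta = ref_per_token_logps - per_token_logps
    per_token_kl = torch.exp(delta) - delta - 1

    # Importance ratio with a detached denominator: its value is 1, but it
    # carries the policy gradient. Scale it by the per-completion advantage.
    ratio = torch.exp(per_token_logps - per_token_logps.detach())
    per_token_loss = -(ratio * advantages.unsqueeze(1) - beta * per_token_kl)

    # Mask padding, average over tokens per sequence, then mean over the batch.
    token_counts = completion_mask.sum(dim=1)
    loss = ((per_token_loss * completion_mask).sum(dim=1) / token_counts).mean()

    metrics = {
        "kl": ((per_token_kl * completion_mask).sum(dim=1)
               / token_counts).mean().item(),
        "response_length": token_counts.float().mean().item(),
    }
    return loss, metrics
```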
Once your loss function is implemented, train a small model on the GSM8K dataset.
Instructions:
- Run Training: Execute the main script to start training. We recommend starting with the `Qwen/Qwen2-0.5B-Instruct` model, which is fast to train.

  ```bash
  python main.py \
      --model_name "Qwen/Qwen2-0.5B-Instruct" \
      --num_train_iters 1000 \
      --eval_iterations 100 \
      --output_dir "output/Qwen-0.5B"
  ```

  Feel free to adjust the training arguments in `main.py` as you see fit.

- Monitor Output: Training logs, evaluation results, and model checkpoints will be saved in the specified `output_dir`.

- Visualize Metrics: Use the `plotter.py` script to visualize the training dynamics (the results saving and plotting could also use a fair amount of cleanup).

  ```bash
  python plotter.py --log_dir "output/Qwen-0.5B"
  ```

  This will generate a `training_plots.pdf` in the log directory. Review the plots for loss, rewards, KL divergence, and evaluation accuracy. Does the model appear to be learning successfully?
Now, conduct a small-scale analysis of how model size affects performance on the GSM8K task.
Objective: Explore the relationship between model scale, performance, and data efficiency.
Instructions:
- Train Multiple Models: Train at least two different model sizes from the same family (e.g., `Qwen2-0.5B` and `Qwen2-1.5B`). If you have the computational resources, you can add a larger model like `Qwen2-7B`. Train them for the same number of iterations.
- Plot Performance vs. Scale: Create a plot of final evaluation accuracy vs. model size (number of parameters); a minimal plotting sketch follows this list.
- Analyze Data Efficiency: Plot the evaluation accuracy over training steps for each model on the same graph. Does the larger model learn faster? Does it achieve a higher peak performance?
- Summarize Findings: Write a brief summary of your conclusions. What are the trade-offs you observed? You can create a new markdown file or a Jupyter notebook for this analysis.
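For the scale plot, a minimal matplotlib sketch might look like the following; the parameter counts and accuracies below are placeholder values to be replaced with your actual results.

```python
import matplotlib.pyplot as plt

# Placeholder values -- substitute your own final evaluation accuracies.
params = [0.5e9, 1.5e9, 7e9]      # model sizes in parameters
accuracy = [0.0, 0.0, 0.0]        # fill in from your runs

plt.figure()
plt.plot(params, accuracy, marker="o")
plt.xscale("log")                 # sizes span more than an order of magnitude
plt.xlabel("Parameters")
plt.ylabel("Final eval accuracy")
plt.title("GSM8K accuracy vs. model scale")
plt.savefig("scaling_plot.pdf")
```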
Compare the effectiveness and data efficiency of our RL approach (GRPO) against standard Supervised Fine-Tuning (SFT).
Objective: Investigate how RL training compares to fine-tuning on a small, high-quality dataset.
Instructions:
1. Generate a "Gold Standard" Dataset:
   - Write a script to generate high-quality solutions for ~500-1000 problems from the GSM8K training set (a generation sketch appears after this list).
   - Use a powerful LLM for generation (e.g., via an API for GPT-4, or a large local model like Mixtral or Llama-3-70B).
   - Ensure the generated solutions strictly follow the required `<reasoning>...</reasoning><answer>...</answer>` format. Save these as your SFT dataset.

2. Implement and Run SFT:
   - Write a new, simple script to perform SFT on a base model (e.g., `Qwen2-0.5B-Instruct`) using your gold-standard dataset. You can use Hugging Face's `Trainer` API for this (see the SFT sketch after this list).
   - Train the model and evaluate its accuracy on the GSM8K test set.

3. Compare and Analyze:
   - How does the SFT model's performance compare to the GRPO-trained model from Task 2?
   - To analyze data efficiency, train several SFT models on subsets of your gold data (e.g., 100, 250, 500 examples). Plot the SFT accuracy vs. the number of training examples.
   - Compare this plot to the accuracy-over-time plot from your GRPO run. Roughly how many GRPO training steps appear to be as effective as one high-quality SFT example?
   - Write up your analysis and conclusions.
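For step 1, a generation script could look roughly like the sketch below. It assumes the OpenAI Python client and the `gsm8k` dataset on the Hugging Face Hub; the model name is just an example, and any sufficiently strong LLM (including a large local model) would work.

```python
import json

from datasets import load_dataset
from openai import OpenAI  # assumes OPENAI_API_KEY is set in the environment

client = OpenAI()
SYSTEM = ("Solve the math problem. Respond exactly in the form "
          "<reasoning>...</reasoning><answer>...</answer>, with only the "
          "final number inside the answer tags.")

train = load_dataset("gsm8k", "main", split="train").select(range(500))
records = []
for example in train:
    response = client.chat.completions.create(
        model="gpt-4o",  # example model name; adjust to whatever you use
        messages=[{"role": "system", "content": SYSTEM},
                  {"role": "user", "content": example["question"]}],
    )
    completion = response.choices[0].message.content
    # Keep only completions that actually follow the required format.
    if "<reasoning>" in completion and "<answer>" in completion:
        records.append({"question": example["question"],
                        "completion": completion})

with open("sft_gold.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")
```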
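For step 2, a bare-bones SFT script using Hugging Face's `Trainer` might look like this sketch. For simplicity it computes the loss over the full prompt-plus-completion text; masking the prompt tokens out of the loss is a common refinement.

```python
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

model_name = "Qwen/Qwen2-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

dataset = load_dataset("json", data_files="sft_gold.jsonl", split="train")

def tokenize(example):
    # Concatenate prompt and gold completion into one training sequence.
    text = example["question"] + "\n" + example["completion"] + tokenizer.eos_token
    return tokenizer(text, truncation=True, max_length=1024)

tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="output/sft", num_train_epochs=3,
                           per_device_train_batch_size=4, learning_rate=2e-5),
    train_dataset=tokenized,
    # mlm=False gives standard causal-LM labels (inputs shifted by one).
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```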
If you have time, extend the framework to a non-verifiable, creative domain.
Objective: Implement an LLM-as-a-judge reward model and use it to train a policy on a creative writing task.
Instructions:
1. Define a Task and Dataset:
   - Choose a simple creative task, for example, "Given an opening line, write a short, one-paragraph story."
   - Create a small dataset of prompts (e.g., 50-100 opening lines). You will need to create a new `DataLoader` in `rl_datasets.py` (a rough sketch follows this list).

2. Implement an LLM-as-a-Judge Evaluator:
   - Create a new class, `CreativeWritingEvaluator`, in `evaluator.py` that inherits from `RewardEvaluator` (see the sketch after this list).
   - In its `compute_rewards` method, call an external LLM (this can be the same model you are training, another local model, or an API call) to act as a "judge".
   - Design a prompt for the judge model that asks it to rate the generated story on a numeric scale (e.g., 1 to 5) based on criteria like creativity, coherence, and engagement. You'll need to be careful here: many judges don't have good discriminatory power, and you'll need some responses to be judged as good and some as poor, otherwise GRPO won't work.
   - The reward for each completion will be the score returned by the judge LLM. Be sure to handle parsing the judge's output robustly.

3. Train the Model:
   - Run `main.py` using your new creative dataset and LLM-as-a-judge evaluator.
   - Examine the generated stories. Does the model's writing quality improve over time? What are the challenges of using an LLM as a reward source?
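For the dataset step, the sketch below shows the general shape of a prompt loader. The real class should mirror whatever interface the existing GSM8K loader in `rl_datasets.py` exposes; the file format and the `(prompt, answer)` return convention here are assumptions.

```python
import random

class CreativeWritingLoader:
    """Cycles through opening-line prompts (hypothetical interface sketch)."""

    SYSTEM = "Given the opening line below, write a short, one-paragraph story."

    def __init__(self, prompt_file, seed=0):
        # One opening line per line of the file.
        with open(prompt_file) as f:
            self.prompts = [line.strip() for line in f if line.strip()]
        random.Random(seed).shuffle(self.prompts)
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        prompt = self.prompts[self.index % len(self.prompts)]
        self.index += 1
        # There is no ground-truth answer in the creative domain.
        return self.SYSTEM + "\n" + prompt, None
```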
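For the judge, a sketch of `CreativeWritingEvaluator` follows. The `compute_rewards` signature is assumed and should be matched to the abstract method actually declared in `evaluator.py`; `call_judge` is a placeholder for however you query the judge model. Note the defensive score parsing, since judge replies are not always well formed.

```python
import re

# Assumes this class is added to evaluator.py, where RewardEvaluator is defined.

JUDGE_PROMPT = (
    "Rate the following short story from 1 (poor) to 5 (excellent) for "
    "creativity, coherence, and engagement. Reply with only the number.\n\n"
    "Opening line: {prompt}\n\nStory: {story}\n\nScore:"
)

class CreativeWritingEvaluator(RewardEvaluator):
    """LLM-as-a-judge reward source (sketch; align with the real base class)."""

    def compute_rewards(self, prompts, completions):
        # NOTE: assumed signature -- match the abstract method in evaluator.py.
        rewards = []
        for prompt, story in zip(prompts, completions):
            reply = self.call_judge(
                JUDGE_PROMPT.format(prompt=prompt, story=story))
            rewards.append(self.parse_score(reply))
        return rewards

    def call_judge(self, judge_input):
        # Placeholder: swap in a local model.generate() call or an API request.
        raise NotImplementedError

    @staticmethod
    def parse_score(reply):
        # Take the first digit 1-5 anywhere in the reply; fall back to the
        # scale midpoint if the judge's output is malformed.
        match = re.search(r"[1-5]", reply)
        return float(match.group()) if match else 3.0
```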
- Run