mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
tutorial(training): Add a minimal example with trl (#473)
* v0 * 2 gpu setup * improve parsing from yaml * update yaml dataset example * remove restriction on flash attn * more comments * first version of the readme * pin torch * simplify requirements * just flash attn * use set env instead * simpler set env * readme * add wandb project to setup * update template * update model id * post init to capture the config and weight * extract metadata * update config * update dataset config * move env for wandb project * pre-commit * remove qwen-math from training * more instructions * unused import * remove trl old * warmup ratio * warmup ratio * change model id * change model_id * add info about CUDA_VISIBLE_DEVICES
This commit is contained in:
parent
49f3821098
commit
56ce2e79a7
59 changed files with 382 additions and 155340 deletions
|
|
@ -1,32 +1,56 @@
|
|||
# TRL Examples
|
||||
# Training with TRL
|
||||
|
||||
This directory contains examples using the [TRL (Transformer Reinforcement Learning) library](https://github.com/huggingface/trl) to fine-tune language models with reinforcement learning techniques.
|
||||
Training stack:
|
||||
- TRL for reinforcement learning training
|
||||
- Accelerate (with DeepSpeed) for distributed training
|
||||
- vLLM for rollouts
|
||||
|
||||
## GRPO Example
|
||||
|
||||
The main example demonstrates using GRPO (Group Relative Policy Optimization) to fine-tune a language model on reasoning tasks from reasoning-gym. It includes:
|
||||
|
||||
- Custom reward functions for answer accuracy and format compliance
|
||||
- Integration with reasoning-gym datasets
|
||||
- Configurable training parameters via YAML config
|
||||
- Wandb logging and model checkpointing
|
||||
- Evaluation on held-out test sets
|
||||
|
||||
## Setup
|
||||
|
||||
1. Install the required dependencies:
|
||||
This tutorial uses CUDA 11.8, Python 3.10, and PyTorch 2.5.1
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
Moreover, we assume that you have 2 GPUs on your machine, the last of which is used for vLLM rollouts.
|
||||
|
||||
If you have more than 2 GPUs, adjust the `./config/grpo.yaml` file so that the `vllm_device` is set to the last index of your GPU. For example, if you have 4 GPUs, set it to 3:
|
||||
```yaml
|
||||
vllm_device: 3 # If you have 4 GPUs, set this to 3
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
1. Configure the training parameters in `config/grpo.yaml`
|
||||
2. Run the training script:
|
||||
|
||||
Moreover, you would need to update the `CUDA_VISIBLE_DEVICES` environment variable in the `train.sh` script to include all your available GPUs. For example, if you have 4 GPUs, set it to:
|
||||
```bash
|
||||
python main_grpo_reward.py
|
||||
# ./train.sh
|
||||
|
||||
# ... beginning of the script
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
# ... rest of the script
|
||||
```
|
||||
|
||||
The model will be trained using GRPO with the specified reasoning-gym dataset and evaluation metrics will be logged to Weights & Biases.
|
||||
|
||||
|
||||
1. Install the required packages:
|
||||
```bash
|
||||
# First, give execute permissions to the script
|
||||
# chmod +x ./set_env.sh
|
||||
|
||||
# Then, run the setup script
|
||||
./set_env.sh
|
||||
```
|
||||
|
||||
2. (Optional) Log in to Weights & Biases for experiment tracking:
|
||||
```bash
|
||||
# First, set your WANDB_API_KEY as an environment variable
|
||||
export WANDB_API_KEY=your_wandb_api_key
|
||||
|
||||
# Set the project name
|
||||
export WANDB_PROJECT=your_wandb_project_name
|
||||
```
|
||||
|
||||
3. Run the training script
|
||||
```bash
|
||||
# First, give execute permissions to the script
|
||||
# chmod +x ./train.sh
|
||||
|
||||
# Then, run the training script
|
||||
./train.sh
|
||||
```
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue