docs: Update TRL README with GRPO example details and usage instructions (#76)

This commit is contained in:
Andreas Köpf 2025-02-07 07:56:22 +01:00 committed by GitHub
parent d61db3772a
commit a8f9eafd43
3 changed files with 37 additions and 12 deletions

View file

@ -1,5 +1,32 @@
# TRL Examples
This directory contains examples using the [TRL (Transformer Reinforcement Learning) library](https://github.com/huggingface/trl) to fine-tune language models with reinforcement learning techniques.
## GRPO Example
The main example demonstrates using GRPO (Group Relative Policy Optimization) to fine-tune a language model on reasoning tasks from reasoning-gym. It includes:
- Custom reward functions for answer accuracy and format compliance
- Integration with reasoning-gym datasets
- Configurable training parameters via YAML config
- Wandb logging and model checkpointing
- Evaluation on held-out test sets
## Setup
1. Install the required dependencies:
```bash
pip install -r requirements.txt
```
## Usage
1. Configure the training parameters in `config/grpo.yaml`
2. Run the training script:
```bash
python main_grpo_reward.py
```
The model will be trained using GRPO with the specified reasoning-gym dataset and evaluation metrics will be logged to Weights & Biases.

View file

@ -33,8 +33,8 @@ class ReasoningGymDataset(Dataset):
return len(self.data) return len(self.data)
def __getitem__(self, idx): def __getitem__(self, idx):
metadata = self.data[idx] item = self.data[idx]
question = metadata["question"] question = item["question"]
chat = [] chat = []
@ -43,7 +43,7 @@ class ReasoningGymDataset(Dataset):
chat.append({"role": "user", "content": question}) chat.append({"role": "user", "content": question})
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
return {"prompt": prompt, "metadata": metadata} return {"prompt": prompt, "metadata": item}
class GRPOTrainerCustom(GRPOTrainer): class GRPOTrainerCustom(GRPOTrainer):
@ -54,7 +54,7 @@ class GRPOTrainerCustom(GRPOTrainer):
args: GRPOConfig, args: GRPOConfig,
tokenizer, tokenizer,
peft_config, peft_config,
seed1, seed,
size, size,
developer_role="system", developer_role="system",
): ):
@ -66,7 +66,7 @@ class GRPOTrainerCustom(GRPOTrainer):
peft_config=peft_config, peft_config=peft_config,
) )
developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"] developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
self.train_dataset = ReasoningGymDataset(dataset_name, seed1, size, tokenizer, developer_prompt, developer_role) self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)
def _format_reward(self, completions, **kwargs): def _format_reward(self, completions, **kwargs):
regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$" regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
@ -128,7 +128,7 @@ def main(script_args, training_args, model_args):
args=training_args, args=training_args,
tokenizer=tokenizer, tokenizer=tokenizer,
peft_config=peft_config, peft_config=peft_config,
seed1=training_args.seed, seed=training_args.seed,
size=script_args.train_size, size=script_args.train_size,
) )
@ -154,7 +154,7 @@ def main(script_args, training_args, model_args):
"finetuned_from": model_args.model_name_or_path, "finetuned_from": model_args.model_name_or_path,
"dataset": list(script_args.dataset_name), "dataset": list(script_args.dataset_name),
"dataset_tags": list(script_args.dataset_name), "dataset_tags": list(script_args.dataset_name),
"tags": ["open-r1"], "tags": ["reasoning-gym"],
} }
if trainer.accelerator.is_main_process: if trainer.accelerator.is_main_process:

View file

@ -1,6 +1,4 @@
torch>=2.6.0
torchvision --index-url https://download.pytorch.org/whl/cu124
torchaudio --index-url https://download.pytorch.org/whl/cu124
datasets
peft
transformers