Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2026-04-28 17:29:39 +00:00
docs: Update TRL README with GRPO example details and usage instructions (#76)
parent d61db3772a
commit a8f9eafd43
3 changed files with 37 additions and 12 deletions
TRL README:

````diff
@@ -1,5 +1,32 @@
-1. Install the requirements in the txt file
-```
+# TRL Examples
+
+This directory contains examples using the [TRL (Transformer Reinforcement Learning) library](https://github.com/huggingface/trl) to fine-tune language models with reinforcement learning techniques.
+
+## GRPO Example
+
+The main example demonstrates using GRPO (Group Relative Policy Optimization) to fine-tune a language model on reasoning tasks from reasoning-gym. It includes:
+
+- Custom reward functions for answer accuracy and format compliance
+- Integration with reasoning-gym datasets
+- Configurable training parameters via YAML config
+- Wandb logging and model checkpointing
+- Evaluation on held-out test sets
+
+## Setup
+
+1. Install the required dependencies:
+
+```bash
 pip install -r requirements.txt
 ```
 
+## Usage
+
+1. Configure the training parameters in `config/grpo.yaml`
+2. Run the training script:
+
+```bash
+python main_grpo_reward.py
+```
+
+The model will be trained using GRPO with the specified reasoning-gym dataset and evaluation metrics will be logged to Weights & Biases.
````
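The README's first bullet, custom reward functions for answer accuracy, has no counterpart in this diff. Below is a minimal sketch of what such a function can look like, assuming reasoning-gym's `create_dataset()` / `score_answer()` API; the `accuracy_reward` name, the `<answer>` extraction, and the `leg_counting` task are illustrative choices, not the script's code:

```python
# Sketch only: assumes reasoning-gym's create_dataset()/score_answer() API;
# accuracy_reward and the "leg_counting" task are illustrative, not the
# script's actual reward.
import re

import reasoning_gym

# Procedural task generator; each entry carries "question", "answer", "metadata".
data = reasoning_gym.create_dataset("leg_counting", size=10, seed=42)

ANSWER_RE = re.compile(r"<answer>([\s\S]*?)</answer>")

def accuracy_reward(completions, metadata, **kwargs):
    """Score each completion's <answer> block against the task's oracle."""
    rewards = []
    for completion, entry in zip(completions, metadata):
        match = ANSWER_RE.search(completion)
        answer = match.group(1).strip() if match else completion
        # score_answer returns a float reward, typically 1.0 for a correct answer.
        rewards.append(data.score_answer(answer=answer, entry=entry))
    return rewards
```

TRL's GRPOTrainer passes extra dataset columns to reward functions as keyword arguments, which is presumably how the `metadata` entries returned by `ReasoningGymDataset.__getitem__` (next file) reach the scorer.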
GRPO training script (run as `main_grpo_reward.py` per the README):

```diff
@@ -33,8 +33,8 @@ class ReasoningGymDataset(Dataset):
         return len(self.data)
 
     def __getitem__(self, idx):
-        metadata = self.data[idx]
-        question = metadata["question"]
+        item = self.data[idx]
+        question = item["question"]
 
         chat = []
 
@@ -43,7 +43,7 @@ class ReasoningGymDataset(Dataset):
         chat.append({"role": "user", "content": question})
 
         prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-        return {"prompt": prompt, "metadata": metadata}
+        return {"prompt": prompt, "metadata": item}
 
 
 class GRPOTrainerCustom(GRPOTrainer):
```
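The rename is cosmetic but apt: each entry of a reasoning-gym dataset is a full task record (question, oracle answer, generation parameters), not just metadata, so `item` is the more accurate name, and returning it whole keeps the oracle available to the reward functions. For context, here is a self-contained run of the same prompt-building call; the checkpoint and the abbreviated system message are placeholders (the script uses `reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]` with `developer_role="system"`):

```python
# Standalone illustration of the __getitem__ prompt construction above.
# The checkpoint and system message are placeholders, not the script's values.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

chat = [
    {"role": "system", "content": "Reason in <think> tags, then answer in <answer> tags."},
    {"role": "user", "content": "How many legs do 3 spiders and 2 ducks have?"},
]

# tokenize=False returns the formatted prompt as a string;
# add_generation_prompt=True appends the assistant header so generation
# continues as a reply instead of a new user turn.
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
print(prompt)
```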
```diff
@@ -54,7 +54,7 @@ class GRPOTrainerCustom(GRPOTrainer):
         args: GRPOConfig,
         tokenizer,
         peft_config,
-        seed1,
+        seed,
         size,
         developer_role="system",
     ):
@@ -66,7 +66,7 @@ class GRPOTrainerCustom(GRPOTrainer):
             peft_config=peft_config,
         )
         developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
-        self.train_dataset = ReasoningGymDataset(dataset_name, seed1, size, tokenizer, developer_prompt, developer_role)
+        self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)
 
     def _format_reward(self, completions, **kwargs):
         regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
```
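The format regex is easiest to read by example: it accepts exactly one `<think>...</think>` block, a newline, then one `<answer>...</answer>` block, and nothing else. A standalone check follows; the 0/1 scoring is an illustrative guess at `_format_reward`'s logic, which the diff does not show beyond the regex:

```python
# Demonstrates the format-compliance regex from the diff; the 0/1 scoring
# is an assumption, not the script's actual _format_reward body.
import re

FORMAT_RE = re.compile(
    r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
)

def format_reward(completions):
    # Full reward only when the completion matches the template exactly.
    return [1.0 if FORMAT_RE.match(c) else 0.0 for c in completions]

good = "<think>3 spiders have 24 legs, 2 ducks have 4.</think>\n<answer>28</answer>"
bad = "The answer is 28."
assert format_reward([good, bad]) == [1.0, 0.0]
```

The inner group `([^<]*(?:<(?!/?think>)[^<]*)*)` permits stray `<` characters inside the think block as long as they do not open or close another `<think>` tag.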
```diff
@@ -128,7 +128,7 @@ def main(script_args, training_args, model_args):
         args=training_args,
         tokenizer=tokenizer,
         peft_config=peft_config,
-        seed1=training_args.seed,
+        seed=training_args.seed,
         size=script_args.train_size,
     )
 
```
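This is the call site of the constructor fixed in the earlier hunk, so the rename lands on both ends. As for where `script_args`, `training_args`, and `model_args` come from, here is a plausible sketch using TRL's `TrlParser`, which merges a YAML config (the README's `config/grpo.yaml`, passed via `--config`) with CLI flags. The `GRPOScriptArguments` dataclass and its defaults are assumptions; the diff only shows that `script_args` carries `dataset_name` and `train_size`:

```python
# Sketch of the likely entry point; GRPOScriptArguments and its defaults
# are assumptions based on the fields the diff shows script_args using.
from dataclasses import dataclass, field

from trl import GRPOConfig, ModelConfig, TrlParser

@dataclass
class GRPOScriptArguments:
    dataset_name: list[str] = field(default_factory=lambda: ["leg_counting"])
    train_size: int = 1000

if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    # Handles e.g. `python main_grpo_reward.py --config config/grpo.yaml`.
    script_args, training_args, model_args = parser.parse_args_and_config()
    # main(script_args, training_args, model_args)
```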
```diff
@@ -154,7 +154,7 @@ def main(script_args, training_args, model_args):
         "finetuned_from": model_args.model_name_or_path,
         "dataset": list(script_args.dataset_name),
         "dataset_tags": list(script_args.dataset_name),
-        "tags": ["open-r1"],
+        "tags": ["reasoning-gym"],
     }
 
     if trainer.accelerator.is_main_process:
```
requirements.txt:

```diff
@@ -1,6 +1,4 @@
-torch --index-url https://download.pytorch.org/whl/cu124
-torchvision --index-url https://download.pytorch.org/whl/cu124
-torchaudio --index-url https://download.pytorch.org/whl/cu124
+torch>=2.6.0
 datasets
 peft
 transformers
```
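The new requirements pin `torch>=2.6.0` and drop the CUDA 12.4 wheel index along with `torchvision`/`torchaudio`, so pip now resolves torch from the default PyPI index. Anyone who still needs a CUDA-specific build can install it up front, reusing the index URL from the removed lines (an optional step, not part of the documented setup):

```bash
# Optional: pre-install a CUDA 12.4 torch build, then the rest as documented.
pip install "torch>=2.6.0" --index-url https://download.pytorch.org/whl/cu124
pip install -r requirements.txt
```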