reasoning-gym/examples/trl/grpo.py
joesharratt1229 1c98584f28
Feat/unsloth example (#482)
* cleaned up examples

* updated failing hooks

* updated readme

* corrected linting checks
2025-06-28 17:04:38 +01:00

266 lines
8.9 KiB
Python

import logging
import os
import re
import sys
from dataclasses import dataclass, field
from typing import Optional
import torch
import transformers
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer, ModelConfig, TrlParser
import reasoning_gym
from reasoning_gym.coaching.experiment import Experiment
from reasoning_gym.composite import DatasetSpec
from reasoning_gym.dataset import ProceduralDataset
from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer
@dataclass
class DatasetConfigItem:
    """Per-dataset settings: sampling weight plus dataset-specific config overrides."""

    # Relative sampling weight of this dataset inside the composite mixture.
    weight: Optional[float] = 1.0
    # Keyword configuration forwarded to the underlying dataset; default_factory
    # gives every instance its own dict (mutable defaults must not be shared).
    config: Optional[dict] = field(default_factory=dict)
@dataclass
class DatasetConfig:
    """Configuration describing the composite reasoning-gym dataset to train on.

    Attributes:
        dataset_size: number of samples generated for each split.
        developer_prompt: key into ``SYSTEM_PROMPTS`` selecting the system prompt text.
        developer_role: chat role under which the developer prompt is injected.
        datasets: mapping of dataset name -> DatasetConfigItem; raw dicts (as parsed
            from a YAML/JSON config file) are coerced in ``__post_init__``.
    """

    dataset_size: int = field(default=10000)
    developer_prompt: str = field(default="DeepSeekZero")
    developer_role: str = field(default="system")
    # Fix: was annotated `dict[str, DatasetConfigItem]` despite defaulting to None;
    # Optional[...] reflects the actual contract.
    datasets: Optional[dict[str, DatasetConfigItem]] = field(default=None)

    def __post_init__(self):
        """Coerce plain-dict entries into DatasetConfigItem instances in place."""
        if self.datasets:
            self.datasets = {
                name: DatasetConfigItem(**item) if isinstance(item, dict) else item
                for name, item in self.datasets.items()
            }
class ReasoningGymDataset(Dataset):
    """Torch map-style dataset rendering reasoning-gym items as chat-templated prompts."""

    def __init__(
        self,
        tokenizer,
        procedural_dataset: Optional[ProceduralDataset] = None,
        experiment: Optional[Experiment] = None,
        developer_prompt: Optional[str] = None,
        developer_role: Optional[str] = None,
    ):
        self.tokenizer = tokenizer
        # Exactly one source is expected; an Experiment exposes its composite dataset.
        self.data = procedural_dataset or experiment.composite
        self.experiment = experiment
        self.developer_prompt = developer_prompt
        self.developer_role = developer_role

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> dict:
        entry = self.data[idx]
        messages = []
        # NOTE: the developer prompt is only injected when a role is configured.
        if self.developer_role is not None:
            messages.append({"role": self.developer_role, "content": self.developer_prompt})
        messages.append({"role": "user", "content": entry["question"]})
        rendered = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return {"prompt": rendered, "item": entry}
class CustomGRPOTrainer(GRPOTrainer):
    """GRPO trainer wired with reasoning-gym accuracy and format-tag reward functions."""

    def __init__(
        self,
        model,
        args: GRPOConfig,
        tokenizer,
        train_dataset: ReasoningGymDataset,
        eval_dataset: ReasoningGymDataset,
    ):
        rewards = [self._accuracy_reward, self._format_reward]
        super().__init__(
            model=model,
            reward_funcs=rewards,
            args=args,
            processing_class=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

    def _accuracy_reward(self, completions: list[str], **kwargs) -> list[float]:
        """Score each completion's extracted answer against its originating item."""
        assert "item" in kwargs, "The 'item' argument must be provided to compute accuracy reward."
        assert len(kwargs["item"]) == len(completions), "Items and completions must have the same length."
        assert all(isinstance(item, dict) for item in kwargs["item"]), "Each item must be a dictionary."
        # NOTE(review): scoring always goes through the *train* dataset's scorer,
        # even during evaluation — presumably fine since both splits share specs; confirm.
        scores = []
        for completion, item in zip(completions, kwargs["item"]):
            scores.append(self.train_dataset.data.score_answer(extract_answer(completion), item))
        return scores

    def _format_reward(self, completions: list[str], **kwargs) -> list[float]:
        """Award 0.25 per present tag: <think>, </think>, <answer>, </answer> (max 1.0)."""
        tag_patterns = (r"\s*<think>\s*", r"\s*</think>\s*", r"\s*<answer>\s*", r"\s*</answer>\s*")

        def score(text: str) -> float:
            total = 0.0
            for pattern in tag_patterns:
                if re.search(pattern, text):
                    total += 0.25
            return total

        return [score(c) for c in completions]
def make_dataset(
    tokenizer,
    data_source: Experiment | ProceduralDataset,
    developer_prompt: str,
    developer_role: Optional[str] = None,
) -> ReasoningGymDataset:
    """Create a ReasoningGymDataset from an Experiment or ProceduralDataset."""
    # The two branches differ only in which keyword receives the data source.
    shared = dict(
        tokenizer=tokenizer,
        developer_prompt=developer_prompt,
        developer_role=developer_role,
    )
    if isinstance(data_source, Experiment):
        return ReasoningGymDataset(experiment=data_source, **shared)
    return ReasoningGymDataset(procedural_dataset=data_source, **shared)
def prepare_datasets(
    config: DatasetConfig,
    tokenizer,
) -> tuple[ReasoningGymDataset, ReasoningGymDataset]:
    """Prepare the training and eval datasets."""
    developer_prompt = SYSTEM_PROMPTS[config.developer_prompt]
    specs = []
    for name, item in config.datasets.items():
        specs.append(DatasetSpec(name=name, weight=item.weight, config=item.config))

    def build(seed: int) -> ReasoningGymDataset:
        # Train and eval share the same spec mixture; only the seed differs.
        source = reasoning_gym.create_dataset("composite", seed=seed, size=config.dataset_size, datasets=specs)
        return make_dataset(
            tokenizer=tokenizer,
            data_source=source,
            developer_prompt=developer_prompt,
            developer_role=config.developer_role,
        )

    # Seed 1 -> training split, seed 2 -> evaluation split (both of dataset_size items).
    return build(1), build(2)
def main():
    """Parse configs, build the model and reasoning-gym datasets, then run GRPO training."""
    # -----------
    # Parse args
    # -----------
    # TrlParser splits CLI/config-file args into the three dataclass namespaces.
    parser = TrlParser((DatasetConfig, GRPOConfig, ModelConfig))
    reasoning_gym_args, training_args, model_args = parser.parse_args_and_config()
    set_seed(training_args.seed)
    # ---------------
    # Set up logging
    # ---------------
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Align this script's and transformers' verbosity with the per-process level
    # (main process logs more; worker ranks are quieter).
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Training parameters {training_args}")
    # -----------
    # Load model
    # -----------
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        attn_implementation=model_args.attn_implementation,
        # KV-cache is incompatible with gradient checkpointing, so disable it when checkpointing.
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    # --------------------
    # Instantiate trainer
    # --------------------
    # NOTE(review): reasoning_gym config is attached to training_args but not read
    # elsewhere in this file — presumably so it gets serialized with the run args; confirm.
    training_args.reasoning_gym = reasoning_gym_args
    train_dataset, eval_dataset = prepare_datasets(reasoning_gym_args, tokenizer)
    trainer = CustomGRPOTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    # ------------------------------
    # See if we can resume training
    # ------------------------------
    logger.info("Starting training...")
    # Check for last checkpoint: an explicit resume path wins; otherwise scan output_dir.
    ckpt = None
    if training_args.resume_from_checkpoint is not None:
        ckpt = training_args.resume_from_checkpoint
    elif os.path.isdir(training_args.output_dir):
        ckpt = get_last_checkpoint(training_args.output_dir)
    if ckpt:
        logger.info(f"\nCheckpoint detected, resuming training at {ckpt=}.")
    else:
        logger.info("\nNo checkpoint detected, starting training from scratch.")
    # ---------------
    # Start training
    # ---------------
    train_result = trainer.train(resume_from_checkpoint=ckpt)
    train_metrics = train_result.metrics
    trainer.log_metrics("train", train_metrics)
    trainer.save_metrics("train", train_metrics)
    trainer.save_state()
    # ---------
    # Clean up
    # ---------
    # Drop the trainer reference before emptying the CUDA cache so GPU memory can be reclaimed.
    del trainer
    torch.cuda.empty_cache()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()