# This example is an adapted version of the HuggingFace trl GRPO code:
# link: https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py
import logging
import os
import re
import sys
from dataclasses import dataclass
from typing import Optional

import datasets
import torch
import transformers
from grpo_config import ScriptArguments
from peft import LoraConfig
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer, ModelConfig, TrlParser

import reasoning_gym
from reasoning_gym.utils import extract_answer

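# Overview: ReasoningGymDataset wraps a procedurally generated reasoning-gym task
# as a torch Dataset of chat-formatted prompts; GRPOTrainerCustom trains a causal
# LM with GRPO using a format reward and an accuracy reward; after training, the
# model is scored on a freshly generated evaluation split of the same task.
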
class ReasoningGymDataset(Dataset):
    """Wraps a reasoning_gym procedural dataset as a torch Dataset of chat prompts."""

    def __init__(self, dataset_name, seed, size, tokenizer, developer_prompt, developer_role="system") -> None:
        super().__init__()
        self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
        self.tokenizer = tokenizer
        self.developer_role = developer_role
        self.developer_prompt = developer_prompt

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item["question"]

        chat = []
        if self.developer_role is not None:
            chat.append({"role": self.developer_role, "content": self.developer_prompt})
        chat.append({"role": "user", "content": question})

        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        # Keep the raw item alongside the prompt so the reward functions can score
        # completions later via score_answer().
        return {"prompt": prompt, "metadata": item}

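# GRPOTrainerCustom plugs two reward signals into trl's GRPOTrainer:
#   _format_reward   — 1.0 if the completion is exactly <think>...</think>\n<answer>...</answer>, else 0.0
#   _accuracy_reward — the reasoning-gym score for the extracted answer
#                      (score_answer typically returns a float in [0, 1])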
class GRPOTrainerCustom(GRPOTrainer):
    def __init__(
        self,
        model,
        dataset_name,
        args: GRPOConfig,
        tokenizer,
        peft_config,
        seed,
        size,
        developer_role="system",
    ):
        super().__init__(
            model,
            reward_funcs=[self._accuracy_reward, self._format_reward],
            args=args,
            processing_class=tokenizer,
            peft_config=peft_config,
        )
        developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
        self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)

    def _format_reward(self, completions, **kwargs):
        # The think block may contain "<" as long as it does not open or close a think tag.
        regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)</think>\n<answer>([\s\S]*?)</answer>$"
        matches = [re.match(regex, completion, flags=re.DOTALL) for completion in completions]
        return [1.0 if match else 0.0 for match in matches]

    def _accuracy_reward(self, completions, metadata, **kwargs):
        # Extract the <answer> contents and let the reasoning-gym dataset score each one.
        answers = [extract_answer(completion) for completion in completions]
        return [self.train_dataset.data.score_answer(answer, entry=obj) for (answer, obj) in zip(answers, metadata)]

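# A completion that earns format reward 1.0 (illustrative; note the single newline
# required between </think> and <answer>):
#   <think>
#   2 + 3 = 5
#   </think>
#   <answer>5</answer>
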
def main(script_args, training_args, model_args):
    set_seed(training_args.seed)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    logger = logging.getLogger(__name__)

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)  # Set for module-level logger

    # Configure third-party library log levels
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.info(f"Training arguments: {training_args}")
    logger.info(f"Model arguments: {model_args}")
    logger.info(f"Script arguments: {script_args}")

    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    peft_config = LoraConfig(
        r=16,
        lora_alpha=64,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

    trainer = GRPOTrainerCustom(
        model,
        dataset_name=script_args.dataset_name,
        args=training_args,
        tokenizer=tokenizer,
        peft_config=peft_config,
        seed=training_args.seed,
        size=script_args.train_size,
    )

    # Training loop: prefer an explicit --resume_from_checkpoint, else fall back
    # to the checkpoint auto-detected in output_dir above.
    logger.info("Training model...")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    train_results = trainer.train(resume_from_checkpoint=checkpoint)

    metrics = train_results.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # dataset_name is a single task name; wrap it in a list for the model card.
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": [script_args.dataset_name],
        "dataset_tags": [script_args.dataset_name],
        "tags": ["reasoning-gym"],
    }

    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    # Nested inside main so it can read training_args (for max_completion_length)
    # via closure.
    def evaluate_model(model, tokenizer, dataset, *args, **kwargs):
        model.eval()
        correct_preds = 0
        total_preds = 0

        for i in range(len(dataset)):
            item = dataset[i]
            prompt = item["prompt"]
            metadata = item["metadata"]
            inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

            with torch.no_grad():
                outputs = model.generate(
                    *args,
                    **inputs,
                    max_new_tokens=training_args.max_completion_length,
                    pad_token_id=tokenizer.eos_token_id,
                    **kwargs,
                )

            # Decode the full sequence; extract_answer pulls out the <answer> block.
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer = reasoning_gym.utils.extract_answer(generated_text)
            score = dataset.data.score_answer(answer, entry=metadata)
            correct_preds += score
            total_preds += 1

        return correct_preds / total_preds

    ## Evaluate model
    logger.info("Evaluating model...")
    eval_dataset = ReasoningGymDataset(
        script_args.dataset_name,
        training_args.eval_seed,
        script_args.eval_size,
        tokenizer,
        reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"],
    )

    eval_results = evaluate_model(model, tokenizer, eval_dataset)
    trainer.log_metrics("eval", {"accuracy": eval_results})
    trainer.save_metrics("eval", {"accuracy": eval_results})
    logger.info(f"Evaluation results: {eval_results}")

    if training_args.push_to_hub:
        logger.info("Pushing model to hub...")
        trainer.push_to_hub()

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
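
# Example launch (illustrative; the script name is assumed, and dataset_name,
# train_size, eval_size are fields of ScriptArguments in grpo_config.py as
# used above — TrlParser exposes each dataclass field as a CLI flag):
#   python grpo.py --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
#       --dataset_name chain_sum --train_size 5000 --eval_size 200 \
#       --output_dir outputs/grpo-chain-sum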