docs: Update TRL README with GRPO example details and usage instructions (#76)

This commit is contained in:
Andreas Köpf 2025-02-07 07:56:22 +01:00 committed by GitHub
parent d61db3772a
commit a8f9eafd43
3 changed files with 37 additions and 12 deletions

View file

@ -33,8 +33,8 @@ class ReasoningGymDataset(Dataset):
return len(self.data)
def __getitem__(self, idx):
metadata = self.data[idx]
question = metadata["question"]
item = self.data[idx]
question = item["question"]
chat = []
@ -43,7 +43,7 @@ class ReasoningGymDataset(Dataset):
chat.append({"role": "user", "content": question})
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
return {"prompt": prompt, "metadata": metadata}
return {"prompt": prompt, "metadata": item}
class GRPOTrainerCustom(GRPOTrainer):
@ -54,7 +54,7 @@ class GRPOTrainerCustom(GRPOTrainer):
args: GRPOConfig,
tokenizer,
peft_config,
seed1,
seed,
size,
developer_role="system",
):
@ -66,7 +66,7 @@ class GRPOTrainerCustom(GRPOTrainer):
peft_config=peft_config,
)
developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
self.train_dataset = ReasoningGymDataset(dataset_name, seed1, size, tokenizer, developer_prompt, developer_role)
self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)
def _format_reward(self, completions, **kwargs):
regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
@ -128,7 +128,7 @@ def main(script_args, training_args, model_args):
args=training_args,
tokenizer=tokenizer,
peft_config=peft_config,
seed1=training_args.seed,
seed=training_args.seed,
size=script_args.train_size,
)
@ -154,7 +154,7 @@ def main(script_args, training_args, model_args):
"finetuned_from": model_args.model_name_or_path,
"dataset": list(script_args.dataset_name),
"dataset_tags": list(script_args.dataset_name),
"tags": ["open-r1"],
"tags": ["reasoning-gym"],
}
if trainer.accelerator.is_main_process: