mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
docs: Update TRL README with GRPO example details and usage instructions (#76)
This commit is contained in:
parent
d61db3772a
commit
a8f9eafd43
3 changed files with 37 additions and 12 deletions
|
|
@ -33,8 +33,8 @@ class ReasoningGymDataset(Dataset):
|
|||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
metadata = self.data[idx]
|
||||
question = metadata["question"]
|
||||
item = self.data[idx]
|
||||
question = item["question"]
|
||||
|
||||
chat = []
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ class ReasoningGymDataset(Dataset):
|
|||
chat.append({"role": "user", "content": question})
|
||||
|
||||
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
||||
return {"prompt": prompt, "metadata": metadata}
|
||||
return {"prompt": prompt, "metadata": item}
|
||||
|
||||
|
||||
class GRPOTrainerCustom(GRPOTrainer):
|
||||
|
|
@ -54,7 +54,7 @@ class GRPOTrainerCustom(GRPOTrainer):
|
|||
args: GRPOConfig,
|
||||
tokenizer,
|
||||
peft_config,
|
||||
seed1,
|
||||
seed,
|
||||
size,
|
||||
developer_role="system",
|
||||
):
|
||||
|
|
@ -66,7 +66,7 @@ class GRPOTrainerCustom(GRPOTrainer):
|
|||
peft_config=peft_config,
|
||||
)
|
||||
developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
|
||||
self.train_dataset = ReasoningGymDataset(dataset_name, seed1, size, tokenizer, developer_prompt, developer_role)
|
||||
self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)
|
||||
|
||||
def _format_reward(self, completions, **kwargs):
|
||||
regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
|
||||
|
|
@ -128,7 +128,7 @@ def main(script_args, training_args, model_args):
|
|||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
peft_config=peft_config,
|
||||
seed1=training_args.seed,
|
||||
seed=training_args.seed,
|
||||
size=script_args.train_size,
|
||||
)
|
||||
|
||||
|
|
@ -154,7 +154,7 @@ def main(script_args, training_args, model_args):
|
|||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": list(script_args.dataset_name),
|
||||
"dataset_tags": list(script_args.dataset_name),
|
||||
"tags": ["open-r1"],
|
||||
"tags": ["reasoning-gym"],
|
||||
}
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue