From cc0312e4462048e85fae692053a5ebed5108f944 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Tue, 28 Jan 2025 14:40:02 +0000 Subject: [PATCH] add first example with OpenRLHF --- .pre-commit-config.yaml | 2 +- examples/OpenRLHF/.gitignore | 1 + examples/OpenRLHF/custom_reward.py | 734 +++++++++++++++++++++++++ examples/OpenRLHF/custom_reward_ppo.sh | 31 ++ reasoning_gym/__init__.py | 13 +- reasoning_gym/dataset.py | 14 + reasoning_gym/utils.py | 22 + 7 files changed, 815 insertions(+), 2 deletions(-) create mode 100644 examples/OpenRLHF/.gitignore create mode 100644 examples/OpenRLHF/custom_reward.py create mode 100755 examples/OpenRLHF/custom_reward_ppo.sh create mode 100644 reasoning_gym/utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72513dd3..72c22f89 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: rev: 24.1.1 hooks: - id: black - language_version: python3.12 + language_version: python3.11 - repo: https://github.com/pycqa/isort rev: 5.13.2 diff --git a/examples/OpenRLHF/.gitignore b/examples/OpenRLHF/.gitignore new file mode 100644 index 00000000..583bdcb5 --- /dev/null +++ b/examples/OpenRLHF/.gitignore @@ -0,0 +1 @@ +checkpoint/ diff --git a/examples/OpenRLHF/custom_reward.py b/examples/OpenRLHF/custom_reward.py new file mode 100644 index 00000000..9c0c705b --- /dev/null +++ b/examples/OpenRLHF/custom_reward.py @@ -0,0 +1,734 @@ +# tested with OpenRLHF 707b970e992154952a91607ca5491cc49b8665c3 + +import argparse +import itertools +import math +import os +from dataclasses import dataclass +from datetime import datetime +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn as nn +from openrlhf.datasets import PromptDataset, SFTDataset +from openrlhf.models import Actor, get_llm_for_sequence_regression +from openrlhf.models.utils import compute_approx_kl, masked_mean +from openrlhf.trainer import PPOTrainer +from openrlhf.trainer.ppo_utils.experience_maker import Experience, NaiveExperienceMaker, Samples +from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer +from torch.utils.data import Dataset +from transformers.trainer import get_scheduler + +import reasoning_gym +from reasoning_gym.dataset import ProceduralDataset +from reasoning_gym.utils import extract_answer + +DEBUG = False + + +def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str: + if apply_chat_template: + chat = data[input_key] + if isinstance(chat, str): + chat = [{"role": "user", "content": chat}] + prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + else: + prompt = data[input_key] + if input_template: + prompt = input_template.format(prompt) + return prompt + + +class ReasoningGymDataset(Dataset): + def __init__( + self, + dataset, + tokenizer, + developer_prompt: Optional[str] = None, + developer_role: str = "system", + ) -> None: + super().__init__() + + self.dataset = dataset + self.tokenizer = tokenizer + self.developer_prompt = developer_prompt + self.developer_role = developer_role + + def __len__(self): + length = len(self.dataset) + return length + + def __getitem__(self, idx: int) -> tuple[str, dict]: + x = self.dataset[idx] + + q = x["question"] + chat = [] + if self.developer_prompt is not None: + chat.append({"role": self.developer_role, "content": self.developer_prompt}) + chat.append({"role": "user", "content": q}) + prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + + 
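+ # return the raw dataset entry together with the prompt; it travels through the dataloader as metadata so the experience maker can score the generated answer against it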
return prompt, x + + +@dataclass +class SamplesWithMetadata(Samples): + metadata: Optional[dict] + + +class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker): + def __init__( + self, + dataset: ProceduralDataset, + actor: Actor, + critic: nn.Module, + reward_model: nn.Module, + initial_model: Actor, + tokenizer, + prompt_max_len: int, + kl_controller, + strategy=None, + remote_rm_url: str = None, + reward_fn=None, + ) -> None: + super().__init__( + actor=actor, + critic=critic, + reward_model=reward_model, + initial_model=initial_model, + tokenizer=tokenizer, + prompt_max_len=prompt_max_len, + kl_controller=kl_controller, + strategy=strategy, + remote_rm_url=remote_rm_url, + reward_fn=reward_fn, + ) + self.dataset = dataset + + @torch.no_grad() + def generate_samples(self, all_prompts: List[Tuple[str, Any]], **generate_kwargs) -> List[Samples]: + """ + Generate samples and return in batches. + """ + assert not getattr(self, "packing_samples", False) + args = self.strategy.args + self.actor.eval() + + # prepare inputs to sample multiple response + repeated_prompts = [] + repeated_metadata = [] + for prompt, metadata in all_prompts: + for _ in range(args.n_samples_per_prompt): + repeated_prompts.append(prompt) + repeated_metadata.append(metadata) + + samples_list = [] + for i in range(0, len(repeated_prompts), args.micro_rollout_batch_size): + prompts = repeated_prompts[i : i + args.micro_rollout_batch_size] + metadata = repeated_metadata[i : i + args.micro_rollout_batch_size] + inputs = self.tokenize_fn(prompts, self.prompt_max_len, device="cuda") + sequences, attention_mask, action_mask = self.actor.generate(**inputs, **generate_kwargs) + samples = SamplesWithMetadata( + sequences=sequences, + attention_mask=attention_mask, + action_mask=action_mask, + num_actions=action_mask.size(1), + packed_seq_lens=None, + response_length=action_mask.float().sum(dim=-1), + total_length=attention_mask.float().sum(dim=-1), + metadata=metadata, + ) + samples_list.append(samples) + return samples_list + + @torch.no_grad() + def make_experience(self, samples: Samples) -> Experience: + """ + Turn samples into experience by calculating logprobs, values, rewards, and kl divergence. 
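+ Rewards are computed here by scoring each decoded completion against its dataset entry via ProceduralDataset.score_answer; no learned reward model is queried.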
+ """ + self.actor.eval() + self.initial_model.eval() + if self.reward_model is not None: + self.reward_model.eval() + if self.critic is not None: + self.critic.eval() + + # extract values from samples + sequences = samples.sequences + attention_mask = samples.attention_mask + action_mask = samples.action_mask + num_actions = samples.num_actions + if isinstance(samples, SamplesWithMetadata): + metadata = samples.metadata + else: + metadata = None + + # log probs + action_log_probs = self.actor(sequences, num_actions, attention_mask) + + # init log probs + base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask) + + # values + if self.critic is not None: + value = self.critic(sequences, num_actions, attention_mask) + else: + value = None + + # determine outcome reward + completions = self.tokenizer.batch_decode(sequences.cpu(), skip_special_tokens=True) + returns = [ + self.dataset.score_answer(extract_answer(c, tag_name="answer"), entry=m) + for c, m in zip(completions, metadata) + ] + r = torch.tensor(returns, device=sequences.device) + + kl = compute_approx_kl( + action_log_probs, + base_action_log_probs, + action_mask=action_mask, + use_kl_estimator_k3=self.strategy.args.use_kl_estimator_k3, + ) + + info = { + "kl": masked_mean(kl, action_mask, dim=-1), + "reward": r, + "response_length": samples.response_length, + "total_length": samples.total_length, + "num_actions": num_actions, + } + # reset model state + self.actor.train() + if self.critic is not None: + self.critic.train() + + return Experience( + sequences, + action_log_probs, + value, + None, + None, + attention_mask, + action_mask, + info, + kl, + ) + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + + # configure model + # load huggingface model + actor = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + ) + + if args.actor_init_on_gpu: + actor = actor.to(torch.cuda.current_device()) + + if args.critic_pretrain: + critic = get_llm_for_sequence_regression( + args.critic_pretrain, + "critic", + normalize_reward=args.normalize_reward, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=False), + value_head_prefix=args.value_head_prefix, + init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain, + ) + else: + critic = None + + reward_model = None + + strategy.print(actor) + strategy.print(critic) + + # configure tokenizer + tokenizer = get_tokenizer( + args.pretrain, + actor.model, + "left", + strategy, + use_fast=not args.disable_fast_tokenizer, + ) + + # load weights for reference actor + initial_model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=False), + ) + + if args.enable_ema: + ema_model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=True), + ) + else: + ema_model = None + + # gradient_checkpointing + if 
args.gradient_checkpointing: + actor.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + if critic is not None: + critic.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # configure optimizer + actor_optim = strategy.create_optimizer( + actor, lr=args.actor_learning_rate, betas=args.adam_betas, weight_decay=args.l2 + ) + if args.critic_pretrain: + critic_optim = strategy.create_optimizer( + critic, + lr=args.critic_learning_rate, + betas=args.adam_betas, + weight_decay=args.l2, + ) + else: + critic_optim = None + + # prepare datasets + print("prompt_data", args.prompt_data) + + # DeepSeek Zero system prompt + system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. +The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> +<answer> answer here </answer> +""" + + prompts_data = reasoning_gym.create_dataset(args.prompt_data, size=args.max_samples) + prompts_dataset = ReasoningGymDataset(prompts_data, tokenizer, developer_prompt=system_prompt) + + if args.pretrain_data: + pretrain_data = blending_datasets( + args.pretrain_data, + args.pretrain_data_probs, + strategy, + args.seed, + return_eval=False, + train_split=args.pretrain_split, + ) + pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len + pretrain_dataset = SFTDataset( + pretrain_data.select( + range( + min( + len(pretrain_data), + args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt, + ) + ) + ), + tokenizer, + pretrain_max_len, + strategy, + pretrain_mode=True, + ) + + def collate_prompt_and_metadata(xs: list[tuple]) -> list[tuple]: + """dummy collate function to pass on the metadata dict unchanged""" + return xs + + # prepare dataloader + prompts_dataloader = strategy.setup_dataloader( + prompts_dataset, + args.rollout_batch_size // strategy.world_size, + True, + True, + collate_fn=collate_prompt_and_metadata, + ) + if args.pretrain_data: + pretrain_dataloader = itertools.cycle( + iter( + strategy.setup_dataloader( + pretrain_dataset, + args.micro_train_batch_size, + True, + True, + pretrain_dataset.collate_fn, + ) + ) + ) + else: + pretrain_dataloader = None + + # configure scheduler + num_update_steps_per_episodes = ( + len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs + ) + max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes) + + actor_scheduler = get_scheduler( + "cosine_with_min_lr", + actor_optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.actor_learning_rate * 0.1}, + ) + + if args.critic_pretrain: + critic_scheduler = get_scheduler( + "cosine_with_min_lr", + critic_optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.critic_learning_rate * 0.1}, + ) + else: + critic_scheduler = None + + # prepare models/optimizers...
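+ # hand models and optimizers to the strategy so they are wrapped for distributed (DeepSpeed) training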
+ ( + (actor, actor_optim, actor_scheduler), + (critic, critic_optim, critic_scheduler), + reward_model, + initial_model, + ) = strategy.prepare( + (actor, actor_optim, actor_scheduler), + (critic, critic_optim, critic_scheduler), + reward_model, + initial_model, + is_rlhf=True, + ) + + if ema_model: + ema_model._offload = True + ema_model = strategy.prepare(ema_model, is_rlhf=True) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): + _, states = strategy.load_ckpt(actor.model, os.path.join(args.ckpt_path, "_actor")) + if args.critic_pretrain: + strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic")) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # configure Trainer + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + ema_model, + actor_optim, + critic_optim, + actor_scheduler, + critic_scheduler, + max_epochs=args.max_epochs, + micro_train_batch_size=args.micro_train_batch_size, + micro_rollout_batch_size=args.micro_rollout_batch_size, + gradient_checkpointing=args.gradient_checkpointing, + tokenizer=tokenizer, + prompt_max_len=args.prompt_max_len, + value_clip=args.value_clip, + eps_clip=args.eps_clip, + gamma=args.gamma, + lambd=args.lambd, + init_kl_coef=args.init_kl_coef, + kl_target=args.kl_target, + ema_beta=0.992, + ptx_coef=args.ptx_coef, + max_norm=args.max_norm, + # for GPT generation + do_sample=True, + max_new_tokens=args.generate_max_len, + max_length=args.max_len, + temperature=args.temperature, + top_p=args.top_p, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + # remote reward model + remote_rm_url=args.remote_rm_url, + save_hf_ckpt=args.save_hf_ckpt, + disable_ds_ckpt=args.disable_ds_ckpt, + ) + + # patch experience maker
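+ # swap the default NaiveExperienceMaker for AlgorithmicRewardExperienceMaker so rewards come from the procedural dataset's score_answer instead of a reward model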
+ xp = trainer.experience_maker + trainer.experience_maker = AlgorithmicRewardExperienceMaker( + dataset=prompts_dataset.dataset, + actor=xp.actor, + critic=xp.critic, + reward_model=xp.reward_model, + initial_model=xp.initial_model, + tokenizer=xp.tokenizer, + prompt_max_len=xp.prompt_max_len, + kl_controller=xp.kl_ctl, + strategy=xp.strategy, + remote_rm_url=xp.remote_rm_url, + reward_fn=xp.reward_fn, + ) + xp = None + + trainer.fit( + args, + prompts_dataloader, + pretrain_dataloader, + consumed_samples, + num_update_steps_per_episodes, + ) + + # save model checkpoint after fitting on only rank0 + strategy.save_model( + ema_model if args.enable_ema else actor, + tokenizer, + args.save_path, + ) + + if args.critic_pretrain and args.save_value_network: + strategy.save_model( + critic, + tokenizer, + args.save_path + "_critic", + ) + + +if __name__ == "__main__": + + # Only attach debugger on rank 0 process + if DEBUG and int(os.getenv("LOCAL_RANK", "0")) == 0: + import debugpy # imported lazily so the script runs without debugpy when DEBUG is off + + debugpy.listen(("localhost", 5678)) + print("Waiting for debugger attach on localhost:5678") + debugpy.wait_for_client() + + parser = argparse.ArgumentParser() + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--save_hf_ckpt", action="store_true", default=False) + parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # PPO + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--rollout_batch_size", type=int, default=512) + parser.add_argument("--micro_rollout_batch_size", type=int, default=8) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt") + parser.add_argument( + "--generate_max_len", + type=int, + default=1024, + help="Max tokens to generate in PPO", + ) + parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len") + parser.add_argument("--max_samples", type=int, default=1000000) + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef") + parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") + parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") + parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd") + parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") + parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument( + "--normalize_reward", + action="store_true", + default=False, + help="Enable Reward Normalization", + ) + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) +
parser.add_argument( + "--freezing_actor_steps", + type=int, + default=-1, + help="Used for critic initialization", + ) + parser.add_argument( + "--n_samples_per_prompt", + type=int, + default=1, + help="number of responses for each prompt in generation", + ) + parser.add_argument( + "--save_value_network", + action="store_true", + default=False, + help="Save critic model", + ) + parser.add_argument("--actor_learning_rate", type=float, default=1e-6) + parser.add_argument("--critic_learning_rate", type=float, default=9e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--kl_target", type=float, default=None) + parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO") + parser.add_argument( + "--use_kl_estimator_k3", + action="store_true", + default=False, + help=( + "Use the k3 estimator in http://joschu.net/blog/kl-approx.html" + "to ensure the KL divergence calculated is non-negative" + ), + ) + parser.add_argument( + "--adam_betas", + type=float, + nargs=2, + default=(0.9, 0.95), + help="Betas for Adam optimizer", + ) + parser.add_argument( + "--reward_clip_range", + type=float, + nargs=2, + default=(-10, 10), + help="Reward clip range", + ) + + # DeepSpeed + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument( + "--adam_offload", + action="store_true", + default=False, + help="Offload Adam Optimizer", + ) + parser.add_argument("--actor_init_on_gpu", action="store_true", default=False) + parser.add_argument( + "--flash_attn", + action="store_true", + default=False, + help="Enable FlashAttention2", + ) + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # Reinforce + parser.add_argument( + "--advantage_estimator", + type=str, + choices=["gae", "reinforce", "rloo"], + default="gae", + help="Choose advantage estimation method: gae, reinforce, rloo", + ) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # Models + parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API") + parser.add_argument("--critic_pretrain", type=str, default=None, help="HF 
model name or path") + parser.add_argument("--value_head_prefix", type=str, default="score") + + # Custom dataset + parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--prompt_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--prompt_split", type=str, default="train") + parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--pretrain_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--pretrain_split", type=str, default="train") + parser.add_argument("--input_key", type=str, default="question", help="JSON dataset key") + parser.add_argument("--input_template", type=str, default=None) + parser.add_argument( + "--apply_chat_template", + action="store_true", + default=False, + help="Use HF tokenizer chat template", + ) + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_ppo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + args = parser.parse_args() + + # if args.advantage_estimator not in ["gae"]: + # args.critic_pretrain = None + # elif args.critic_pretrain is None: + # if not args.remote_rm_url: + # args.critic_pretrain = args.reward_pretrain + # else: + # args.critic_pretrain = args.pretrain + + args.critic_pretrain = args.pretrain ## temp + + if args.advantage_estimator == "rloo": + assert args.n_samples_per_prompt > 1, "RLOO requires n_samples_per_prompt > 1" + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." + ) + + train(args) diff --git a/examples/OpenRLHF/custom_reward_ppo.sh b/examples/OpenRLHF/custom_reward_ppo.sh new file mode 100755 index 00000000..52d14636 --- /dev/null +++ b/examples/OpenRLHF/custom_reward_ppo.sh @@ -0,0 +1,31 @@ +#!/bin/bash +args=( + custom_reward.py + --pretrain meta-llama/Llama-3.2-1B-Instruct # OpenRLHF/Llama-3-8b-sft-mixture + --save_path ./checkpoint/Llama-3.2-1b-lr # ./checkpoint/llama-3-8b-rlhf + --save_steps -1 + --logging_steps 1 + --eval_steps -1 + --micro_train_batch_size 2 + --train_batch_size 16 + --micro_rollout_batch_size 4 + --rollout_batch_size 64 # 1024 + --max_epochs 1 + --prompt_max_len 1024 + --generate_max_len 1024 + --zero_stage 2 + --bf16 + --actor_learning_rate 5e-7 + --init_kl_coef 0.01 + --prompt_data leg_counting + --input_key question + --apply_chat_template + --normalize_reward + --adam_offload + --flash_attn + --gradient_checkpointing + --max_samples 1024 # 100000 + --critic_learning_rate 9e-6 + # --use_wandb {wandb_token} +) +deepspeed ${args[@]} diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index 20afc9f8..b25eb134 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -6,4 +6,15 @@ from . 
import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic from .factory import create_dataset, register_dataset __version__ = "0.1.1" -__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"] +__all__ = [ + "arithmetic", + "algorithmic", + "algebra", + "cognition", + "data", + "games", + "graphs", + "logic", + "create_dataset", + "register_dataset", +] diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index 3527e289..2eba2e6a 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -49,3 +49,17 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): - metadata: dict """ raise NotImplementedError + + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + """Override this method in derived classes if a single oracle answer is not available.""" + # graded reward: exact match 1.0, oracle answer contained in the response 0.5, otherwise 0.01 (0 if no answer was extracted) + oracle_answer = entry["answer"] + reward = 0 + if answer is not None: + if answer == oracle_answer: + reward = 1.0 + elif oracle_answer in answer: + reward = 0.5 + else: + reward = 0.01 + + return reward diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py new file mode 100644 index 00000000..d47eee3c --- /dev/null +++ b/reasoning_gym/utils.py @@ -0,0 +1,22 @@ +import re +from typing import Optional + +# DeepSeek Zero system prompt +SYSTEM_PROMPTS = { + "DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. +The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> +<answer> answer here </answer> +""" +} + + +def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]: + regex = f"<{tag_name}>(.*?)</{tag_name}>" + answer_match = re.search( + regex, + completion, + flags=re.DOTALL, + ) + if not answer_match: + return None + return answer_match.group(1)
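+ + +# Hypothetical usage sketch (illustration only, not added by this patch): +#   import reasoning_gym +#   from reasoning_gym.utils import extract_answer +#   data = reasoning_gym.create_dataset("leg_counting", size=10) +#   completion = "<think>3 spiders have 3 x 8 = 24 legs</think><answer>24</answer>" +#   answer = extract_answer(completion)  # -> "24" +#   reward = data.score_answer(answer, entry=data[0])  # 1.0 exact, 0.5 contained, else 0.01 (0 if None)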