diff --git a/examples/OpenRLHF/.gitignore b/examples/OpenRLHF/.gitignore index 583bdcb5..740d90b5 100644 --- a/examples/OpenRLHF/.gitignore +++ b/examples/OpenRLHF/.gitignore @@ -1 +1,2 @@ checkpoint/ +wandb/ diff --git a/examples/OpenRLHF/custom_reward.py b/examples/OpenRLHF/custom_reward.py index 9c0c705b..15b7ee1e 100644 --- a/examples/OpenRLHF/custom_reward.py +++ b/examples/OpenRLHF/custom_reward.py @@ -16,6 +16,8 @@ from openrlhf.models.utils import compute_approx_kl, masked_mean from openrlhf.trainer import PPOTrainer from openrlhf.trainer.ppo_utils.experience_maker import Experience, NaiveExperienceMaker, Samples from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer +from openrlhf.utils.logging_utils import init_logger + from torch.utils.data import Dataset from transformers.trainer import get_scheduler @@ -23,6 +25,8 @@ import reasoning_gym from reasoning_gym.dataset import ProceduralDataset from reasoning_gym.utils import extract_answer +logger = init_logger(__name__) + DEBUG = False @@ -176,12 +180,13 @@ class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker): value = None # determine outcome reward - completions = self.tokenizer.batch_decode(sequences.cpu(), skip_special_tokens=True) + completions = sequences[:, -action_mask.size(1):].cpu() + completions = self.tokenizer.batch_decode(completions, skip_special_tokens=True) returns = [ self.dataset.score_answer(extract_answer(c, tag_name="answer"), entry=m) for c, m in zip(completions, metadata) ] - r = torch.tensor(returns, device=sequences.device) + r = torch.tensor(returns, dtype=torch.float32, device=sequences.device) kl = compute_approx_kl( action_log_probs, @@ -197,6 +202,9 @@ class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker): "total_length": samples.total_length, "num_actions": num_actions, } + + logger.info(f"info={info}") + # reset model state self.actor.train() if self.critic is not None: @@ -667,7 +675,7 @@ if __name__ == "__main__": 
parser.add_argument("--value_head_prefix", type=str, default="score") # Custom dataset - parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument("--prompt_data", type=str, default="chain_sum", help="HF dataset name or path") parser.add_argument( "--prompt_data_probs", type=str, @@ -708,15 +716,10 @@ if __name__ == "__main__": args = parser.parse_args() - # if args.advantage_estimator not in ["gae"]: - # args.critic_pretrain = None - # elif args.critic_pretrain is None: - # if not args.remote_rm_url: - # args.critic_pretrain = args.reward_pretrain - # else: - # args.critic_pretrain = args.pretrain - - args.critic_pretrain = args.pretrain ## temp + if args.advantage_estimator not in ["gae"]: + args.critic_pretrain = None + elif args.critic_pretrain is None: + args.critic_pretrain = args.pretrain ## temp if args.advantage_estimator == "rloo": assert args.n_samples_per_prompt > 1, "RLOO requires n_samples_per_prompt > 1" diff --git a/examples/OpenRLHF/custom_reward_ppo.sh b/examples/OpenRLHF/custom_reward_ppo.sh index 52d14636..f4af1833 100755 --- a/examples/OpenRLHF/custom_reward_ppo.sh +++ b/examples/OpenRLHF/custom_reward_ppo.sh @@ -17,15 +17,18 @@ args=( --bf16 --actor_learning_rate 5e-7 --init_kl_coef 0.01 - --prompt_data leg_counting + --prompt_data chain_sum # leg_counting --input_key question --apply_chat_template --normalize_reward --adam_offload --flash_attn --gradient_checkpointing - --max_samples 1024 # 100000 + --max_samples 100000 --critic_learning_rate 9e-6 - # --use_wandb {wandb_token} ) +# Add wandb argument only if wandb_token is set +if [[ -n "${wandb_token}" ]]; then + args+=(--use_wandb "${wandb_token}") +fi deepspeed ${args[@]} diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index 2eba2e6a..07649a41 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -53,7 +53,7 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): def score_answer(self, 
answer: Optional[str], entry: Dict[str, any]) -> float: """Overwrite this method in derived classes if a single oracle answer is not available.""" oracle_answer = entry["answer"] - reward = 0 + reward = 0.0 if answer is not None: if answer == oracle_answer: reward = 1.0 diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py index d47eee3c..aaa54f72 100644 --- a/reasoning_gym/utils.py +++ b/reasoning_gym/utils.py @@ -12,11 +12,13 @@ The assistant first thinks about the reasoning process in the mind and then prov def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]: regex = f"<{tag_name}>(.*?)</{tag_name}>" - answer_match = re.search( - regex, - completion, - flags=re.DOTALL, + matches = list( + re.finditer( + regex, + completion, + flags=re.DOTALL, + ) ) - if not answer_match: + if not matches: return None - return answer_match.group(1) + return matches[-1].group(1)