From 695aad4dbc83649a0aa5d0db9a3e2ee3645fcb2d Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Fri, 28 Mar 2025 09:45:17 +0100 Subject: [PATCH] fix(training): Prepend `<think>` token in format reward (#396) * prepend think token in format reward * pre commit + fix some default vals * add checkpoint config --- .gitignore | 5 +++++ training/configs/llama3.1_1b_grpo.yaml | 3 +++ training/configs/qwen2.5_1.5b_grpo.yaml | 6 +++++- training/trainers/ray_grpo_trainer.py | 3 +++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index d1e0d496..4846b2a0 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,8 @@ htmlcov/ # Jupyter Notebook .ipynb_checkpoints/ .virtual_documents/ + +# logs +wandb/ +outputs/ +*.log diff --git a/training/configs/llama3.1_1b_grpo.yaml b/training/configs/llama3.1_1b_grpo.yaml index 74200cad..34a5f8d3 100644 --- a/training/configs/llama3.1_1b_grpo.yaml +++ b/training/configs/llama3.1_1b_grpo.yaml @@ -35,6 +35,7 @@ reward: format_reward: enable: True scaling_factor: 0.2 + prepend_think_token: False # Set to True only when the tokenizer's prompt template pre-fills the generation with <think>, such as in the case of (distilled) r1 models length_reward: enable: True scaling_factor: 0.2 @@ -75,6 +76,8 @@ actor_rollout_ref: ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size + checkpoint: contents: ['model', 'hf_model', 'optimizer', 'extra'] optim: lr: 1e-6 lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime diff --git a/training/configs/qwen2.5_1.5b_grpo.yaml b/training/configs/qwen2.5_1.5b_grpo.yaml index 3ad49d60..1bee4782 100644 --- a/training/configs/qwen2.5_1.5b_grpo.yaml +++ b/training/configs/qwen2.5_1.5b_grpo.yaml @@ -35,6 +35,7 @@ reward: format_reward: enable: True scaling_factor: 0.2 + prepend_think_token: False # Set to True only when the tokenizer's prompt template pre-fills the generation with <think>, such as in the case of (distilled) r1 models length_reward: enable: True scaling_factor: 0.2 @@ -75,6 +76,8 @@ actor_rollout_ref: ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size + checkpoint: contents: ['model', 'hf_model', 'optimizer', 'extra'] optim: lr: 1e-6 lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime @@ -116,6 +119,7 @@ actor_rollout_ref: tensor_model_parallel_size: 2 max_num_batched_tokens: 8192 max_num_seqs: 1024 + max_model_len: 1024 log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: 160 log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} @@ -144,7 +148,7 @@ trainer: total_epochs: 10 total_training_steps: null project_name: rg-test - experiment_name: verl_grpo_llama3.1_1b + experiment_name: verl_grpo_qwen2.5_1.5b logger: [ 'console', 'wandb' ] val_generations_to_log_to_wandb: 0 nnodes: 1 diff --git a/training/trainers/ray_grpo_trainer.py b/training/trainers/ray_grpo_trainer.py index a1d17824..869d2dd8 100644 --- a/training/trainers/ray_grpo_trainer.py +++ b/training/trainers/ray_grpo_trainer.py @@ -31,6 +31,7 @@ class RayGRPOTrainer(RayPPOTrainer): self.max_output_length = max_output_length self.format_reward_scaling_factor = config.reward.format_reward.scaling_factor + self.format_reward_prepend_think_token = config.reward.format_reward.prepend_think_token self.length_reward_scaling_factor = config.reward.length_reward.scaling_factor train_reward_fn = 
lambda data: self._score_output(data, num_examine=0) @@ -99,6 +100,8 @@ class RayGRPOTrainer(RayPPOTrainer): def _compute_format_reward(self, solution_str: str) -> float: """Reward use of exactly one correctly structured <think> and <answer> block.""" + if self.format_reward_prepend_think_token: + solution_str = "<think>" + solution_str scaling_factor = self.format_reward_scaling_factor # check <think> and <answer> blocks are present pattern = r"<think>\s*.*?</think>\s*<answer>.*?</answer>"