diff --git a/examples/veRL/README.md b/examples/veRL/README.md index 5904cc8a..e9e2bd6c 100644 --- a/examples/veRL/README.md +++ b/examples/veRL/README.md @@ -15,7 +15,7 @@ Regarding vllm>0.7 see: [docs](https://verl.readthedocs.io/en/latest/README_vllm ### clone and install veRL -tested with verl HEAD 0dfcb7f99e299940e1792a386df13c7591df351a +tested with verl HEAD c34206925e2a50fd452e474db857b4d488f8602d ``` git clone https://github.com/volcengine/verl.git diff --git a/examples/veRL/chain_sum/config/grpo_trainer.yaml b/examples/veRL/chain_sum/config/grpo_trainer.yaml index b99fa89f..a4277028 100644 --- a/examples/veRL/chain_sum/config/grpo_trainer.yaml +++ b/examples/veRL/chain_sum/config/grpo_trainer.yaml @@ -83,8 +83,11 @@ actor_rollout_ref: enable_chunked_prefill: True # could get higher throughput # for hf rollout do_sample: True + use_fire_sampling: False # number of responses (i.e. num sample times) n: 16 # > 1 for grpo + val_kwargs: + do_sample: True critic: strategy: fsdp @@ -151,6 +154,7 @@ algorithm: kl_coef: 0.001 trainer: + balance_batch: True total_epochs: 30 total_training_steps: null project_name: verl_examples diff --git a/examples/veRL/chain_sum/config/ppo_trainer.yaml b/examples/veRL/chain_sum/config/ppo_trainer.yaml index a3d167ea..5e4e97be 100644 --- a/examples/veRL/chain_sum/config/ppo_trainer.yaml +++ b/examples/veRL/chain_sum/config/ppo_trainer.yaml @@ -83,8 +83,11 @@ actor_rollout_ref: enable_chunked_prefill: True # could get higher throughput # for hf rollout do_sample: True + use_fire_sampling: False # number of responses (i.e. num sample times) n: 1 # > 1 for grpo + val_kwargs: + do_sample: True critic: strategy: fsdp @@ -151,6 +154,7 @@ algorithm: kl_coef: 0.001 trainer: + balance_batch: True total_epochs: 30 total_training_steps: null project_name: verl_examples diff --git a/examples/veRL/chain_sum/main_ppo_custom_reward.py b/examples/veRL/chain_sum/main_ppo_custom_reward.py index 2addb8e9..6c5863f3 100644 --- a/examples/veRL/chain_sum/main_ppo_custom_reward.py +++ b/examples/veRL/chain_sum/main_ppo_custom_reward.py @@ -70,6 +70,7 @@ class ReasoningGymDataset(Dataset): row_dict["input_ids"] = input_ids[0] row_dict["attention_mask"] = attention_mask[0] row_dict["position_ids"] = position_ids[0] + row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False) # encode prompts without chat template if self.return_raw_chat: diff --git a/examples/veRL/chain_sum/main_ppo_custom_reward_server.py b/examples/veRL/chain_sum/main_ppo_custom_reward_server.py index 4a0eccc0..f97b25e9 100644 --- a/examples/veRL/chain_sum/main_ppo_custom_reward_server.py +++ b/examples/veRL/chain_sum/main_ppo_custom_reward_server.py @@ -118,6 +118,7 @@ class ReasoningGymDataset(Dataset): "entry_id": entry.entry_id, "metadata": entry.metadata, "index": index, + "raw_prompt_ids": self.tokenizer.encode(prompt, add_special_tokens=False), } # Add raw chat if requested