diff --git a/training/configs/qwen2.5_3b_grpo_composite.yaml b/training/configs/qwen2.5_3b_grpo_composite.yaml
index 3c489ffe..ae7fae4b 100644
--- a/training/configs/qwen2.5_3b_grpo_composite.yaml
+++ b/training/configs/qwen2.5_3b_grpo_composite.yaml
@@ -39,8 +39,8 @@ data:
   prompt_key: prompt
   max_prompt_length: 512
   max_response_length: 1024
-  train_batch_size: 64
-  val_batch_size: 64
+  train_batch_size: 32
+  val_batch_size: 32
   return_raw_chat: True
   return_raw_input_ids: True
 
@@ -56,7 +56,7 @@ actor_rollout_ref:
     strategy: fsdp # This is for backward-compatibility
     ppo_mini_batch_size: 32
     ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
-    ppo_micro_batch_size_per_gpu: 160
+    ppo_micro_batch_size_per_gpu: 8
     use_dynamic_bsz: False
     ppo_max_token_len_per_gpu: 12288 # n * ${data.max_prompt_length} + ${data.max_response_length}
     grad_clip: 1.0
@@ -107,7 +107,7 @@ actor_rollout_ref:
     free_cache_engine: True
     load_format: dummy_dtensor
     tensor_model_parallel_size: 4
-    max_num_batched_tokens: 8192
+    max_num_batched_tokens: 16384
     max_num_seqs: 1024
     log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
     log_prob_micro_batch_size_per_gpu: 160
@@ -118,7 +118,7 @@ actor_rollout_ref:
     # for hf rollout
     do_sample: True
     use_fire_sampling: False
-    max_model_len: 4096
+    max_model_len: 16384
     # number of responses (i.e. num sample times)
     n: 8 # > 1 for grpo
     val_kwargs:
@@ -178,7 +178,7 @@ critic:
       fsdp_size: -1
   ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
   ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
-  ppo_micro_batch_size_per_gpu: null
+  ppo_micro_batch_size_per_gpu: 8
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
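The changed values are coupled, so a quick consistency check helps when touching them together. Below is a minimal sketch, not part of this PR: the config path is taken from the diff header, plain PyYAML loading is an assumption (Hydra-style `${...}` interpolations load as strings, so only literal fields are read), and the `max_num_batched_tokens >= max_model_len` rule assumes vLLM with chunked prefill disabled.

```python
# Sanity-check the batch/length knobs changed in this diff.
# Sketch only: the config path and loader are assumptions, not repo code.
import yaml

with open("training/configs/qwen2.5_3b_grpo_composite.yaml") as f:
    cfg = yaml.safe_load(f)

data = cfg["data"]
actor = cfg["actor_rollout_ref"]["actor"]
rollout = cfg["actor_rollout_ref"]["rollout"]

# Longest sequence the rollout engine must hold: prompt + response tokens.
seq_len = data["max_prompt_length"] + data["max_response_length"]  # 512 + 1024 = 1536
assert rollout["max_model_len"] >= seq_len  # 16384 >= 1536

# With chunked prefill disabled, vLLM rejects max_num_batched_tokens smaller
# than max_model_len, which is why both values move to 16384 together here.
assert rollout["max_num_batched_tokens"] >= rollout["max_model_len"]

# The rollout batch splits into PPO mini-batches, and each mini-batch into
# per-GPU micro-batches for gradient accumulation; both splits must be even.
assert data["train_batch_size"] % actor["ppo_mini_batch_size"] == 0  # 32 % 32
assert actor["ppo_mini_batch_size"] % actor["ppo_micro_batch_size_per_gpu"] == 0  # 32 % 8

print("batch/length settings are self-consistent")
```

Read together, the edits look like a memory and validity fix: cutting `ppo_micro_batch_size_per_gpu` from 160 to 8 shrinks the per-step activation footprint, and once `max_model_len` rises to 16384, the matching `max_num_batched_tokens` bump follows from the vLLM constraint above.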