reasoning-gym/examples/OpenRLHF/custom_reward_ppo.sh

#!/bin/bash
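# Launches PPO training (custom_reward.py) with a custom reward function via
# OpenRLHF, using the reasoning-gym chain_sum task as the prompt dataset.
# The commented-out values next to some flags show a larger Llama-3-8B setup.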
args=(
    custom_reward.py
    --pretrain meta-llama/Llama-3.2-1B-Instruct # OpenRLHF/Llama-3-8b-sft-mixture
    --save_path ./checkpoint/Llama-3.2-1b-lr # ./checkpoint/llama-3-8b-rlhf
    --save_steps -1 # -1: save only at the end of training
    --logging_steps 1
    --eval_steps -1 # -1: skip periodic evaluation
    --micro_train_batch_size 2
    --train_batch_size 16
    --micro_rollout_batch_size 4
    --rollout_batch_size 64 # 1024
    --max_epochs 1
    --prompt_max_len 1024
    --generate_max_len 1024
    --zero_stage 2 # DeepSpeed ZeRO optimization stage
    --bf16
    --actor_learning_rate 5e-7
    --init_kl_coef 0.01 # initial KL penalty coefficient
    --prompt_data chain_sum # leg_counting
    --input_key question # dataset field containing the prompt
    --apply_chat_template
    --normalize_reward
    --adam_offload # offload Adam optimizer states to CPU
    --flash_attn
    --gradient_checkpointing
    --max_samples 100000
    --critic_learning_rate 9e-6
)
# Add wandb argument only if wandb_token is set
if [[ -n "${wandb_token}" ]]; then
    args+=(--use_wandb "${wandb_token}")
fi
deepspeed "${args[@]}"
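
# Example invocation; export wandb_token to enable Weights & Biases logging:
#   wandb_token=<your-wandb-api-key> bash custom_reward_ppo.sh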