diff --git a/example_trainer/README.md b/example_trainer/README.md
index 880f4d21..9be1cf66 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -129,9 +129,13 @@ python -m example_trainer.vllm_api_server \
 # Important: Use server_type=vllm to get logprobs (required for GRPO)
 python environments/gsm8k_server.py serve \
   --env.group_size 4 \
-  --env.max_num 200 \
-  --slurm.num_requests_per_time_interval 16 \
-  --slurm.time_interval 10 \
+  --env.batch_size 16 \
+  --env.total_steps 200 \
+  --env.steps_per_eval 50 \
+  --env.max_num_workers_per_node 8 \
+  --env.rollout_server_url "http://localhost:8002" \
+  --env.use_wandb true \
+  --env.wandb_name "gsm8k-lora-only-env" \
   --openai.api_key "dummy" \
   --openai.base_url "http://localhost:9001/v1" \
   --openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \
@@ -171,7 +175,18 @@ python -m example_trainer.vllm_api_server --model ... --enable-lora --enforce-ea
 while ! curl -s http://localhost:9001/health > /dev/null; do sleep 1; done
 
 # 4. Start environment (MUST use --openai.server_type vllm for logprobs)
-python environments/gsm8k_server.py serve ...
+python environments/gsm8k_server.py serve \
+  --env.group_size 4 \
+  --env.batch_size 16 \
+  --env.total_steps 200 \
+  --env.steps_per_eval 50 \
+  --env.max_num_workers_per_node 8 \
+  --env.rollout_server_url "http://localhost:8002" \
+  --env.use_wandb true \
+  --env.wandb_name "gsm8k-train-env" \
+  --openai.base_url "http://localhost:9001/v1" \
+  --openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \
+  --openai.server_type vllm
 
 # 5. Start trainer (will register with API and begin training)
 python -m example_trainer.grpo --weight-bridge-mode lora_only ...
@@ -226,8 +241,13 @@ python environments/gsm8k_server.py serve \
   --openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \
   --openai.server_type vllm \
   --env.group_size 4 \
-  --slurm.num_requests_per_time_interval 16 \
-  --slurm.time_interval 10
+  --env.batch_size 16 \
+  --env.total_steps 200 \
+  --env.steps_per_eval 50 \
+  --env.max_num_workers_per_node 8 \
+  --env.rollout_server_url "http://localhost:8002" \
+  --env.use_wandb true \
+  --env.wandb_name "gsm8k-shared-vllm-env"
 ```
 
 **Terminal 4: Trainer**
diff --git a/example_trainer/run_gsm8k_lora_matrix.sh b/example_trainer/run_gsm8k_lora_matrix.sh
index 2b77669c..abe94205 100755
--- a/example_trainer/run_gsm8k_lora_matrix.sh
+++ b/example_trainer/run_gsm8k_lora_matrix.sh
@@ -60,6 +60,10 @@ LORA_ONLY_VLLM_GPU="${LORA_ONLY_VLLM_GPU:-2}"
 LORA_RESTART_TRAINER_GPU="${LORA_RESTART_TRAINER_GPU:-3}"
 LORA_RESTART_VLLM_GPU="${LORA_RESTART_VLLM_GPU:-4}"
 DRY_RUN="${DRY_RUN:-0}"
+ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}"
+ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}"
+ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-8}"
+ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}"
 
 SHARED_API_PORT="$START_API_PORT"
 SHARED_VLLM_PORT="$START_VLLM_PORT"
@@ -77,6 +81,10 @@ log() {
 
 kill_port() {
   local port="$1"
+  if [[ "$DRY_RUN" == "1" ]]; then
+    log "[DRY RUN] skip port cleanup for :${port}"
+    return 0
+  fi
   if lsof -i ":${port}" -sTCP:LISTEN >/dev/null 2>&1; then
     lsof -ti ":${port}" | xargs -r kill -9 || true
   fi
@@ -156,20 +164,34 @@ common_trainer_flags() {
 }
 
 start_gsm8k_env() {
-  local vllm_port="$1"
-  local logfile="$2"
+  local api_port="$1"
+  local vllm_port="$2"
+  local env_wandb_name="$3"
+  local logfile="$4"
   start_process "gsm8k_env" "$logfile" \
     "$PYTHON_BIN" environments/gsm8k_server.py serve \
     --env.group_size 4 \
-    --env.max_num 200 \
-    --slurm.num_requests_per_time_interval 16 \
-    --slurm.time_interval 10 \
+    --env.batch_size "$ENV_BATCH_SIZE" \
+    --env.total_steps "$ENV_TOTAL_STEPS" \
+    --env.steps_per_eval "$ENV_STEPS_PER_EVAL" \
+    --env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \
+    --env.max_token_length "$MAX_MODEL_LEN" \
+    --env.rollout_server_url "http://localhost:${api_port}" \
+    --env.use_wandb true \
+    --env.wandb_name "$env_wandb_name" \
     --openai.api_key "dummy" \
     --openai.base_url "http://localhost:${vllm_port}/v1" \
    
 --openai.model_name "$MODEL_NAME" \
     --openai.server_type vllm
 }
 
+start_gsm8k_env_shared() {
+  local vllm_port="$1"
+  local logfile="$2"
+  local api_port="$SHARED_API_PORT"
+  start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-shared-vllm-env" "$logfile"
+}
+
 run_shared_vllm() {
   log "========== RUN: shared_vllm =========="
   local api_port="$SHARED_API_PORT"
@@ -207,7 +229,7 @@ run_shared_vllm() {
     wait_for_http "http://localhost:${vllm_port}/health" 300 "shared vLLM"
   fi
 
-  start_gsm8k_env "$vllm_port" "$mode_dir/env.log"
+  start_gsm8k_env_shared "$vllm_port" "$mode_dir/env.log"
 
   log "Starting trainer: shared_vllm"
   if [[ "$DRY_RUN" == "1" ]]; then
@@ -273,7 +295,7 @@ run_lora_only() {
     wait_for_http "http://localhost:${vllm_port}/health" 300 "lora_only vLLM"
   fi
 
-  start_gsm8k_env "$vllm_port" "$mode_dir/env.log"
+  start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-lora-only-env" "$mode_dir/env.log"
 
   log "Starting trainer: lora_only"
   if [[ "$DRY_RUN" == "1" ]]; then
@@ -351,7 +373,7 @@ run_lora_restart() {
       $(add_lora_layer_flag)
     printf '\n'
    
 log "[DRY RUN] then wait for http://localhost:${vllm_port}/health"
-    log "[DRY RUN] then start GSM8K env pointed at http://localhost:${vllm_port}/v1"
+    log "[DRY RUN] then start GSM8K env pointed at http://localhost:${vllm_port}/v1 and rollout server http://localhost:${api_port}"
    
 log "[DRY RUN] trainer log path: $mode_dir/trainer.log"
  
 else
    
 env CUDA_VISIBLE_DEVICES="$LORA_RESTART_TRAINER_GPU" "$PYTHON_BIN" -m example_trainer.grpo \
@@ -372,7 +394,7 @@ run_lora_restart() {
     run_pids+=("$trainer_pid")
 
     wait_for_http "http://localhost:${vllm_port}/health" 420 "lora_restart vLLM"
-    start_gsm8k_env "$vllm_port" "$mode_dir/env.log"
+    start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-lora-restart-env" "$mode_dir/env.log"
 
     wait "$trainer_pid"
     cat "$mode_dir/trainer.log"
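
For quick verification of the new `ENV_*` overrides introduced above, here is a minimal usage sketch. The override values shown are illustrative, not recommended settings; only `DRY_RUN` and the four `ENV_*` variables come from the diff, and `DRY_RUN=1` now also skips port cleanup, so a dry run previews the composed commands without touching listening processes.

```bash
# Preview the full matrix without starting any processes or killing ports.
DRY_RUN=1 ./example_trainer/run_gsm8k_lora_matrix.sh

# Short smoke-test pass with reduced step counts and batch size.
ENV_TOTAL_STEPS=20 \
ENV_STEPS_PER_EVAL=10 \
ENV_BATCH_SIZE=8 \
ENV_MAX_WORKERS_PER_NODE=4 \
./example_trainer/run_gsm8k_lora_matrix.sh
```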