readme fix

This commit is contained in:
Jai Suphavadeeprasit 2026-02-19 15:49:29 -05:00
parent 657945fa1d
commit ef9f29dbde
2 changed files with 57 additions and 15 deletions

View file

@ -60,6 +60,10 @@ LORA_ONLY_VLLM_GPU="${LORA_ONLY_VLLM_GPU:-2}"
LORA_RESTART_TRAINER_GPU="${LORA_RESTART_TRAINER_GPU:-3}"
LORA_RESTART_VLLM_GPU="${LORA_RESTART_VLLM_GPU:-4}"
DRY_RUN="${DRY_RUN:-0}"
ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}"
ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}"
ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-8}"
ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}"
SHARED_API_PORT="$START_API_PORT"
SHARED_VLLM_PORT="$START_VLLM_PORT"
@ -77,6 +81,10 @@ log() {
kill_port() {
local port="$1"
if [[ "$DRY_RUN" == "1" ]]; then
log "[DRY RUN] skip port cleanup for :${port}"
return 0
fi
if lsof -i ":${port}" -sTCP:LISTEN >/dev/null 2>&1; then
lsof -ti ":${port}" | xargs -r kill -9 || true
fi
@ -156,20 +164,34 @@ common_trainer_flags() {
}
start_gsm8k_env() {
local vllm_port="$1"
local logfile="$2"
local api_port="$1"
local vllm_port="$2"
local env_wandb_name="$3"
local logfile="$4"
start_process "gsm8k_env" "$logfile" \
"$PYTHON_BIN" environments/gsm8k_server.py serve \
--env.group_size 4 \
--env.max_num 200 \
--slurm.num_requests_per_time_interval 16 \
--slurm.time_interval 10 \
--env.batch_size "$ENV_BATCH_SIZE" \
--env.total_steps "$ENV_TOTAL_STEPS" \
--env.steps_per_eval "$ENV_STEPS_PER_EVAL" \
--env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \
--env.max_token_length "$MAX_MODEL_LEN" \
--env.rollout_server_url "http://localhost:${api_port}" \
--env.use_wandb true \
--env.wandb_name "$env_wandb_name" \
--openai.api_key "dummy" \
--openai.base_url "http://localhost:${vllm_port}/v1" \
--openai.model_name "$MODEL_NAME" \
--openai.server_type vllm
}
start_gsm8k_env_shared() {
local vllm_port="$1"
local logfile="$2"
local api_port="$SHARED_API_PORT"
start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-shared-vllm-env" "$logfile"
}
run_shared_vllm() {
log "========== RUN: shared_vllm =========="
local api_port="$SHARED_API_PORT"
@ -207,7 +229,7 @@ run_shared_vllm() {
wait_for_http "http://localhost:${vllm_port}/health" 300 "shared vLLM"
fi
start_gsm8k_env "$vllm_port" "$mode_dir/env.log"
start_gsm8k_env_shared "$vllm_port" "$mode_dir/env.log"
log "Starting trainer: shared_vllm"
if [[ "$DRY_RUN" == "1" ]]; then
@ -273,7 +295,7 @@ run_lora_only() {
wait_for_http "http://localhost:${vllm_port}/health" 300 "lora_only vLLM"
fi
start_gsm8k_env "$vllm_port" "$mode_dir/env.log"
start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-lora-only-env" "$mode_dir/env.log"
log "Starting trainer: lora_only"
if [[ "$DRY_RUN" == "1" ]]; then
@ -351,7 +373,7 @@ run_lora_restart() {
$(add_lora_layer_flag)
printf '\n'
log "[DRY RUN] then wait for http://localhost:${vllm_port}/health"
log "[DRY RUN] then start GSM8K env pointed at http://localhost:${vllm_port}/v1"
log "[DRY RUN] then start GSM8K env pointed at http://localhost:${vllm_port}/v1 and rollout server http://localhost:${api_port}"
log "[DRY RUN] trainer log path: $mode_dir/trainer.log"
else
env CUDA_VISIBLE_DEVICES="$LORA_RESTART_TRAINER_GPU" "$PYTHON_BIN" -m example_trainer.grpo \
@ -372,7 +394,7 @@ run_lora_restart() {
run_pids+=("$trainer_pid")
wait_for_http "http://localhost:${vllm_port}/health" 420 "lora_restart vLLM"
start_gsm8k_env "$vllm_port" "$mode_dir/env.log"
start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-lora-restart-env" "$mode_dir/env.log"
wait "$trainer_pid"
cat "$mode_dir/trainer.log"