mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
wandb integration
This commit is contained in:
parent
dc9df00570
commit
344d87562b
1 changed files with 43 additions and 7 deletions
|
|
@ -20,6 +20,8 @@ set -e
|
|||
MODEL="${1:-Qwen/Qwen3-4B-Instruct-2507}"
|
||||
TRAINING_STEPS="${2:-20}"
|
||||
BATCH_SIZE="${BATCH_SIZE:-2}"
|
||||
USE_WANDB="${USE_WANDB:-true}" # Set USE_WANDB=false to disable
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-lora-mode-comparison}"
|
||||
|
||||
# Port allocation (separate ports for each mode)
|
||||
LORA_ONLY_VLLM_PORT=9001
|
||||
|
|
@ -45,6 +47,7 @@ echo "============================================================"
|
|||
echo "Model: $MODEL"
|
||||
echo "Steps: $TRAINING_STEPS"
|
||||
echo "Batch: $BATCH_SIZE"
|
||||
echo "Wandb: $USE_WANDB (project: $WANDB_PROJECT)"
|
||||
echo ""
|
||||
echo "GPU Allocation:"
|
||||
echo " GPU $LORA_ONLY_GPU: lora_only (ports $LORA_ONLY_API_PORT, $LORA_ONLY_VLLM_PORT)"
|
||||
|
|
@ -183,16 +186,24 @@ echo ""
|
|||
echo "[LORA_ONLY] Starting GSM8k environment..."
|
||||
python -u environments/gsm8k_server.py serve \
|
||||
--env.tokenizer_name "$MODEL" \
|
||||
--env.use_wandb=False \
|
||||
--env.use_wandb=$USE_WANDB \
|
||||
--env.wandb_name "lora-only-env" \
|
||||
--env.rollout_server_url "http://localhost:${LORA_ONLY_API_PORT}" \
|
||||
--openai.model_name "$MODEL" \
|
||||
--openai.base_url "http://localhost:${LORA_ONLY_VLLM_PORT}/v1" \
|
||||
--openai.server_type vllm \
|
||||
--slurm false \
|
||||
> "$LOG_DIR/env_lora_only.log" 2>&1 &
|
||||
2>&1 | tee "$LOG_DIR/env_lora_only.log" &
|
||||
LORA_ONLY_ENV_PID=$!
|
||||
|
||||
echo "[LORA_ONLY] Starting trainer..."
|
||||
|
||||
# Build wandb args
|
||||
WANDB_ARGS=""
|
||||
if [ "$USE_WANDB" = "true" ]; then
|
||||
WANDB_ARGS="--use-wandb --wandb-project $WANDB_PROJECT --wandb-group lora-only"
|
||||
fi
|
||||
|
||||
CUDA_VISIBLE_DEVICES=$LORA_ONLY_GPU python -m example_trainer.grpo \
|
||||
--model-name "$MODEL" \
|
||||
--weight-bridge-mode lora_only \
|
||||
|
|
@ -204,8 +215,9 @@ CUDA_VISIBLE_DEVICES=$LORA_ONLY_GPU python -m example_trainer.grpo \
|
|||
--lora-alpha 32 \
|
||||
--vllm-restart-interval 5 \
|
||||
--save-path "$LOG_DIR/checkpoints_lora_only" \
|
||||
$WANDB_ARGS \
|
||||
--benchmark \
|
||||
> "$LOG_DIR/trainer_lora_only.log" 2>&1 &
|
||||
2>&1 | tee "$LOG_DIR/trainer_lora_only.log" &
|
||||
LORA_ONLY_TRAINER_PID=$!
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -213,6 +225,13 @@ LORA_ONLY_TRAINER_PID=$!
|
|||
# -----------------------------------------------------------------------------
|
||||
echo ""
|
||||
echo "[LORA_RESTART] Starting trainer (manages vLLM internally)..."
|
||||
|
||||
# Build wandb args for lora_restart
|
||||
WANDB_ARGS_RESTART=""
|
||||
if [ "$USE_WANDB" = "true" ]; then
|
||||
WANDB_ARGS_RESTART="--use-wandb --wandb-project $WANDB_PROJECT --wandb-group lora-restart"
|
||||
fi
|
||||
|
||||
CUDA_VISIBLE_DEVICES=$LORA_RESTART_GPU python -m example_trainer.grpo \
|
||||
--model-name "$MODEL" \
|
||||
--weight-bridge-mode lora_restart \
|
||||
|
|
@ -225,8 +244,9 @@ CUDA_VISIBLE_DEVICES=$LORA_RESTART_GPU python -m example_trainer.grpo \
|
|||
--lora-alpha 32 \
|
||||
--vllm-restart-interval 5 \
|
||||
--save-path "$LOG_DIR/checkpoints_lora_restart" \
|
||||
$WANDB_ARGS_RESTART \
|
||||
--benchmark \
|
||||
> "$LOG_DIR/trainer_lora_restart.log" 2>&1 &
|
||||
2>&1 | tee "$LOG_DIR/trainer_lora_restart.log" &
|
||||
LORA_RESTART_TRAINER_PID=$!
|
||||
|
||||
# Wait for lora_restart's internal vLLM to start
|
||||
|
|
@ -249,13 +269,14 @@ wait_for_health $LORA_RESTART_VLLM_PORT "lora_restart internal vLLM" 180 || {
|
|||
echo "[LORA_RESTART] Starting GSM8k environment..."
|
||||
python -u environments/gsm8k_server.py serve \
|
||||
--env.tokenizer_name "$MODEL" \
|
||||
--env.use_wandb=False \
|
||||
--env.use_wandb=$USE_WANDB \
|
||||
--env.wandb_name "lora-restart-env" \
|
||||
--env.rollout_server_url "http://localhost:${LORA_RESTART_API_PORT}" \
|
||||
--openai.model_name "$MODEL" \
|
||||
--openai.base_url "http://localhost:${LORA_RESTART_VLLM_PORT}/v1" \
|
||||
--openai.server_type vllm \
|
||||
--slurm false \
|
||||
> "$LOG_DIR/env_lora_restart.log" 2>&1 &
|
||||
2>&1 | tee "$LOG_DIR/env_lora_restart.log" &
|
||||
LORA_RESTART_ENV_PID=$!
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -266,10 +287,25 @@ echo "━━━━━━━━━━━━━━━━━━━━━━━━
|
|||
echo "Both trainers running in parallel. Waiting for completion..."
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo ""
|
||||
echo "Monitor progress:"
|
||||
echo "📊 WANDB: https://wandb.ai (project: $WANDB_PROJECT)"
|
||||
echo ""
|
||||
echo "📋 MONITOR LOGS (in another terminal):"
|
||||
echo ""
|
||||
echo " # Trainer logs (main output):"
|
||||
echo " tail -f $LOG_DIR/trainer_lora_only.log"
|
||||
echo " tail -f $LOG_DIR/trainer_lora_restart.log"
|
||||
echo ""
|
||||
echo " # Environment logs (rollouts, scores):"
|
||||
echo " tail -f $LOG_DIR/env_lora_only.log"
|
||||
echo " tail -f $LOG_DIR/env_lora_restart.log"
|
||||
echo ""
|
||||
echo " # vLLM logs:"
|
||||
echo " tail -f $LOG_DIR/vllm_lora_only.log"
|
||||
echo " tail -f $LOG_DIR/checkpoints_lora_restart/vllm_internal.log"
|
||||
echo ""
|
||||
echo " # All logs at once:"
|
||||
echo " tail -f $LOG_DIR/*.log"
|
||||
echo ""
|
||||
|
||||
# Wait for trainers
|
||||
LORA_ONLY_EXIT=0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue