diff --git a/example_trainer/cli.py b/example_trainer/cli.py
index 1d70832a..b9f13fc0 100644
--- a/example_trainer/cli.py
+++ b/example_trainer/cli.py
@@ -82,6 +82,12 @@ def parse_args() -> argparse.Namespace:
         default="trained_model_checkpoints",
         help="Directory to save model checkpoints",
     )
+    parser.add_argument(
+        "--checkpoint-interval",
+        type=int,
+        default=3,
+        help="Save checkpoint every N training steps (0 = only save final)",
+    )
 
     # === vLLM Arguments ===
     parser.add_argument(
@@ -258,6 +264,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
         optimizer=args.optimizer,
         device=args.device,
         save_path=args.save_path,
+        checkpoint_interval=getattr(args, "checkpoint_interval", 3),
         vllm_restart_interval=args.vllm_restart_interval,
         vllm_port=args.vllm_port,
         vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
diff --git a/example_trainer/config.py b/example_trainer/config.py
index c7673855..858472e2 100644
--- a/example_trainer/config.py
+++ b/example_trainer/config.py
@@ -49,6 +49,13 @@ class TrainingConfig(BaseModel):
         "trained_model_checkpoints", description="Base path to save model checkpoints"
     )
 
+    checkpoint_interval: int = Field(
+        3,
+        description=(
+            "Save checkpoint every N training steps. "
+            "Set to 0 to only save final checkpoint."
+        ),
+    )
 
     # === vLLM Server Configuration ===
     vllm_restart_interval: int = Field(
diff --git a/example_trainer/scripts/run_concurrent_tests.sh b/example_trainer/scripts/run_concurrent_tests.sh
index c9779b02..5178c13f 100644
--- a/example_trainer/scripts/run_concurrent_tests.sh
+++ b/example_trainer/scripts/run_concurrent_tests.sh
@@ -97,6 +97,8 @@
 echo "  PID: $LORA_VLLM_PID"
 echo ""
 echo "[2/6] Starting Single-Copy vLLM server (GPU 4)..."
+# NOTE: --enforce-eager is REQUIRED for single-copy mode!
+# Without it, CUDA graphs freeze weights and updates won't be visible to inference.
 CUDA_VISIBLE_DEVICES=4 \
 VLLM_ENABLE_SHARED_WEIGHTS=1 \
 LOGDIR="$SINGLE_COPY_CHECKPOINT_DIR" \
@@ -106,6 +108,7 @@ python -u example_trainer/vllm_api_server.py \
     --port $SINGLE_COPY_VLLM_PORT \
     --dtype bfloat16 \
     --gpu-memory-utilization 0.5 \
+    --enforce-eager \
     > "${LOG_DIR}/single_copy_vllm.log" 2>&1 &
 SINGLE_COPY_VLLM_PID=$!
 echo "  PID: $SINGLE_COPY_VLLM_PID"
diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py
index d0cc0dac..c005a924 100644
--- a/example_trainer/trainers.py
+++ b/example_trainer/trainers.py
@@ -569,7 +569,7 @@ def train_shared_vllm(config: TrainingConfig):
         log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
 
         # Periodic checkpoint (for recovery, not for vLLM sync)
-        if (step + 1) % config.vllm_restart_interval == 0:
+        if config.checkpoint_interval > 0 and (step + 1) % config.checkpoint_interval == 0:
             save_checkpoint(model, tokenizer, config.save_path, step + 1)
 
     # === Cleanup ===
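
---

Reviewer note, not part of the patch: a minimal sketch of how the new
checkpoint_interval gate behaves. should_checkpoint is a hypothetical helper
written here only for illustration; the expression inside it mirrors the
changed line in train_shared_vllm.

    def should_checkpoint(step: int, checkpoint_interval: int) -> bool:
        """True when a periodic checkpoint should be written after this step.

        A non-positive interval disables periodic checkpoints entirely
        (only the final checkpoint is saved), matching the new gate in
        trainers.py: interval > 0 and (step + 1) % interval == 0.
        """
        return checkpoint_interval > 0 and (step + 1) % checkpoint_interval == 0

    # Default interval of 3: steps 3, 6, 9, ... trigger a save.
    assert [s + 1 for s in range(10) if should_checkpoint(s, 3)] == [3, 6, 9]
    # Interval of 0: no periodic saves; only the final checkpoint is written.
    assert not any(should_checkpoint(s, 0) for s in range(10))

Decoupling checkpoint_interval from vllm_restart_interval lets recovery
checkpoints be written on their own schedule instead of piggybacking on vLLM
restarts, and the getattr(args, "checkpoint_interval", 3) fallback in cli.py
keeps any caller that constructs the argparse namespace by hand on the default.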