diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 4a94c6d8..985f8213 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -344,6 +344,14 @@ async def info(): @app.get("/batch") async def get_batch(): + # Check if trainer has registered first + if not hasattr(app.state, "started"): + return { + "status": "error", + "message": "Trainer not registered. Call /register first.", + "batch": [], + } + if not app.state.started: app.state.started = True diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index edbc5af0..bf2140ec 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -178,9 +178,14 @@ def register_trainer(config: TrainingConfig): ) -@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15)) +@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30)) def get_batch(): data = requests.get("http://localhost:8000/batch", timeout=10).json() + + # Check if there was an error (trainer not registered) + if data.get("status") == "error": + raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}") + return data diff --git a/example_trainer/vllm_patching/weight_updater.py b/example_trainer/vllm_patching/weight_updater.py index 5d39569e..40b446a4 100644 --- a/example_trainer/vllm_patching/weight_updater.py +++ b/example_trainer/vllm_patching/weight_updater.py @@ -143,6 +143,11 @@ def weight_updater_process( ) print("[Updater] ✓ NCCL group created", flush=True) + # Barrier synchronization to confirm both sides are ready + print("[Updater] Waiting for trainer to be ready...", flush=True) + dist.barrier(group=gloo_group) + print("[Updater] ✓ Trainer is ready, starting update loop", flush=True) + except Exception as e: print(f"[Updater] Failed to create process groups: {e}", flush=True) import traceback @@ -212,6 +217,17 @@ def weight_updater_process( if debug or (update_count % 50 == 0): print(f"[Updater] Updated {param_name} (#{update_count})", flush=True) + except torch.distributed.DistBackendError as e: + # NCCL communication failure - likely trainer crashed + error_str = str(e) + if "Broken pipe" in error_str or "Connection reset" in error_str: + print("[Updater] Trainer disconnected (broken pipe). Exiting.", flush=True) + break + else: + print(f"[Updater] NCCL error: {e}", flush=True) + import traceback + traceback.print_exc() + time.sleep(1) except Exception as e: print(f"[Updater] Error in update loop: {e}", flush=True) import traceback diff --git a/example_trainer/vllm_weight_bridge.py b/example_trainer/vllm_weight_bridge.py index fd3d66eb..5307bd54 100644 --- a/example_trainer/vllm_weight_bridge.py +++ b/example_trainer/vllm_weight_bridge.py @@ -336,6 +336,14 @@ class VLLMWeightBridge: group_name="weight_update_group", ) print("[Bridge] ✓ NCCL group created") + + # Barrier synchronization to ensure both sides are ready + print("[Bridge] Waiting for all ranks to be ready...") + try: + dist.barrier(group=self.gloo_group) + print("[Bridge] ✓ All ranks synchronized and ready") + except Exception as e: + print(f"[Bridge] Warning: Barrier sync failed: {e}") def _initialize_http_mode(self) -> None: """Initialize HTTP-based weight synchronization (fallback)."""