diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py
index 3a0fb999..02f734ba 100644
--- a/atroposlib/api/server.py
+++ b/atroposlib/api/server.py
@@ -233,6 +233,8 @@ def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]:
 
     app.state.queue.append(data_dict)
     app.state.latest = data_dict
+    if hasattr(app.state, "total_rollouts_processed"):
+        app.state.total_rollouts_processed += 1
     return {"status": "received"}
 
 
@@ -253,6 +255,19 @@ class Info(BaseModel):
     batch_size: int = -1
 
 
+class GlobalStatus(BaseModel):
+    """
+    Basemodel for global orchestration metrics
+    """
+
+    current_step: int
+    queue_size: int
+    total_rollouts_processed: int
+    unallocated_fraction: float
+    num_connected_envs: int
+    batch_size: int
+
+
 @app.post("/register")
 async def register(registration: Registration):
     # Initialize app state if not already done
@@ -270,6 +285,7 @@ async def register(registration: Registration):
         app.state.started = False
         app.state.envs = []
         app.state.buffer = {}  # Buffer for mixed-size groups per environment
+        app.state.total_rollouts_processed = 0
 
     # Initialize requesters list if not already done
     if not hasattr(app.state, "requesters"):
@@ -468,6 +484,42 @@ async def get_status():
         return {"current_step": 0, "queue_size": 0}
 
 
+@app.get("/global-status", response_model=GlobalStatus)
+async def get_global_status():
+    """
+    Returns global metrics for the Elastic Orchestrator to monitor workload pressure.
+    """
+    try:
+        # Calculate total unallocated fraction
+        total_min_allocation = 0.0
+        connected_envs = 0
+        for env_config in getattr(app.state, "envs", []):
+            if env_config.get("connected", False):
+                connected_envs += 1
+                if env_config.get("min_batch_allocation") is not None:
+                    total_min_allocation += env_config["min_batch_allocation"]
+
+        unallocated_fraction = 1.0 - min(total_min_allocation, 1.0)
+
+        return {
+            "current_step": getattr(app.state, "status_dict", {}).get("step", 0),
+            "queue_size": len(getattr(app.state, "queue", [])),
+            "total_rollouts_processed": getattr(app.state, "total_rollouts_processed", 0),
+            "unallocated_fraction": unallocated_fraction,
+            "num_connected_envs": connected_envs,
+            "batch_size": getattr(app.state, "batchsize", -1),
+        }
+    except AttributeError:
+        return {
+            "current_step": 0,
+            "queue_size": 0,
+            "total_rollouts_processed": 0,
+            "unallocated_fraction": 1.0,
+            "num_connected_envs": 0,
+            "batch_size": -1,
+        }
+
+
 @app.get("/status-env")
 async def get_status_env(env: EnvIdentifier):
     total = sum(
diff --git a/atroposlib/cli/orchestrate.py b/atroposlib/cli/orchestrate.py
new file mode 100644
index 00000000..b399df5e
--- /dev/null
+++ b/atroposlib/cli/orchestrate.py
@@ -0,0 +1,79 @@
+import argparse
+import time
+import logging
+import signal
+import sys
+from atroposlib.orchestration.metrics import MetricsCollector
+from atroposlib.orchestration.controller import ScalingController
+from atroposlib.orchestration.strategy import LocalActor
+
+# Setup logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s'
+)
+logger = logging.getLogger("DEO")
+
+def main():
+    parser = argparse.ArgumentParser(description="Atropos Elastic Orchestrator (DEO)")
+    parser.add_argument("--server-url", type=str, default="http://localhost:8000", help="Atropos server URL")
+    parser.add_argument("--env-command", type=str, required=True, help="Command to launch environment server")
+    parser.add_argument("--min-actors", type=int, default=1, help="Min environment actors")
+    parser.add_argument("--max-actors", type=int, default=20, help="Max environment actors")
+    parser.add_argument("--target-pressure", type=float, default=1.0, help="Target Rollout Pressure (Queue/BatchSize)")
+    parser.add_argument("--poll-interval", type=int, default=10, help="Poll interval in seconds")
+    parser.add_argument("--cooldown", type=int, default=60, help="Scaling cooldown in seconds")
+    parser.add_argument("--max-step", type=int, default=4, help="Max actors to add/remove at once")
+
+    args = parser.parse_args()
+
+    # 1. Initialize metrics collector
+    collector = MetricsCollector(args.server_url)
+
+    # 2. Initialize Scaling Controller
+    controller = ScalingController(
+        min_actors=args.min_actors,
+        max_actors=args.max_actors,
+        target_pressure=args.target_pressure,
+        cooldown_seconds=args.cooldown,
+        max_step_change=args.max_step
+    )
+
+    # 3. Initialize Strategy (LocalActor)
+    # Convert command string to list
+    env_command_list = args.env_command.split()
+    actor = LocalActor(env_command_list)
+
+    logger.info(f"Starting DEO against {args.server_url}...")
+    logger.info(f"Command: {args.env_command}")
+
+    # Graceful shutdown handler
+    def handle_shutdown(sig, frame):
+        logger.info("Shutdown signal received. Cleaning up...")
+        actor.cleanup()
+        sys.exit(0)
+
+    signal.signal(signal.SIGINT, handle_shutdown)
+    signal.signal(signal.SIGTERM, handle_shutdown)
+
+    # 4. Main Control Loop
+    try:
+        while True:
+            metrics = collector.poll()
+            if metrics:
+                current_actors = actor.get_current_count()
+                target_actors = controller.calculate_desired(metrics, current_actors)
+
+                if target_actors != current_actors:
+                    actor.set_instance_count(target_actors)
+            else:
+                logger.warning("Could not fetch metrics. Check if Atropos server is running.")
+
+            time.sleep(args.poll_interval)
+
+    except Exception as e:
+        logger.error(f"DEO loop crashed: {e}")
+        actor.cleanup()
+
+if __name__ == "__main__":
+    main()
diff --git a/atroposlib/orchestration/controller.py b/atroposlib/orchestration/controller.py
new file mode 100644
index 00000000..bce2a9c2
--- /dev/null
+++ b/atroposlib/orchestration/controller.py
@@ -0,0 +1,67 @@
+import logging
+import math
+from typing import Optional, Dict, Any, List
+from .metrics import WorkloadMetrics
+
+logger = logging.getLogger(__name__)
+
+class ScalingController:
+    """
+    Decides the "Desired Actor Count" based on workload metrics.
+    Uses a dampened calculation with hysteresis to avoid flapping.
+    """
+    def __init__(
+        self,
+        min_actors: int = 1,
+        max_actors: int = 20,
+        target_pressure: float = 1.0,
+        scaling_threshold: float = 0.2,  # ±20%
+        cooldown_seconds: int = 60,
+        max_step_change: int = 4
+    ):
+        self.min_actors = min_actors
+        self.max_actors = max_actors
+        self.target_pressure = target_pressure
+        self.scaling_threshold = scaling_threshold
+        self.cooldown_seconds = cooldown_seconds
+        self.max_step_change = max_step_change
+
+        self.last_action_timestamp = 0
+        self.current_desired = min_actors
+
+    def calculate_desired(self, metrics: WorkloadMetrics, current_actors: int) -> int:
+        """
+        Decides the next target for the number of environment actors.
+        """
+        now = metrics.timestamp
+        pressure = metrics.rollout_pressure
+
+        # 1. Check cooldown
+        if now - self.last_action_timestamp < self.cooldown_seconds:
+            return self.current_desired
+
+        # 2. Sensitivity check (Hysteresis)
+        # If work is roughly satisfying target, don't change anything.
+        if abs(pressure - self.target_pressure) < self.scaling_threshold:
+            return self.current_desired
+
+        # 3. Calculate target
+        # Target = Current * (Current_Pressure / Ideal_Pressure)
+        # This is a dampened proportional controller.
+        raw_target = math.ceil(current_actors * (pressure / self.target_pressure))
+
+        # 4. Apply step constraints (Rate Limiting)
+        # Don't add/remove more than max_step_change in a single move.
+        diff = raw_target - current_actors
+        if abs(diff) > self.max_step_change:
+            raw_target = current_actors + (self.max_step_change if diff > 0 else -self.max_step_change)
+
+        # 5. Apply world bounds
+        final_target = max(self.min_actors, min(self.max_actors, raw_target))
+
+        if final_target != current_actors:
+            self.last_action_timestamp = now
+            self.current_desired = final_target
+            logger.info(f"Controller DECISION: Scale {current_actors} -> {final_target} (Pressure: {pressure:.2f})")
+
+        return final_target
diff --git a/atroposlib/orchestration/metrics.py b/atroposlib/orchestration/metrics.py
new file mode 100644
index 00000000..88173f92
--- /dev/null
+++ b/atroposlib/orchestration/metrics.py
@@ -0,0 +1,57 @@
+import time
+import requests
+from typing import Optional, Dict, Any
+from dataclasses import dataclass
+import logging
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class WorkloadMetrics:
+    current_step: int
+    queue_size: int
+    total_rollouts: int
+    unallocated_fraction: float
+    num_envs: int
+    batch_size: int
+    timestamp: float
+
+    @property
+    def rollout_pressure(self) -> float:
+        """
+        Calculates the "Rollout Pressure" (RP).
+        RP = (Queue Size / Batch Size).
+        If RP > 1.0, the trainer is starving.
+        """
+        if self.batch_size <= 0:
+            return 0.0
+        return self.queue_size / self.batch_size
+
+class MetricsCollector:
+    def __init__(self, server_url: str):
+        self.server_url = server_url.rstrip("/")
+        self.last_metrics: Optional[WorkloadMetrics] = None
+
+    def poll(self) -> Optional[WorkloadMetrics]:
+        """
+        Polls the Atropos server for global metrics.
+        """
+        try:
+            response = requests.get(f"{self.server_url}/global-status", timeout=5)
+            response.raise_for_status()
+            data = response.json()
+
+            metrics = WorkloadMetrics(
+                current_step=data["current_step"],
+                queue_size=data["queue_size"],
+                total_rollouts=data["total_rollouts_processed"],
+                unallocated_fraction=data["unallocated_fraction"],
+                num_envs=data["num_connected_envs"],
+                batch_size=data["batch_size"],
+                timestamp=time.time()
+            )
+            self.last_metrics = metrics
+            return metrics
+        except Exception as e:
+            logger.error(f"Failed to poll metrics from {self.server_url}: {e}")
+            return None
diff --git a/atroposlib/orchestration/strategy.py b/atroposlib/orchestration/strategy.py
new file mode 100644
index 00000000..39ca0e08
--- /dev/null
+++ b/atroposlib/orchestration/strategy.py
@@ -0,0 +1,77 @@
+import subprocess
+import os
+import signal
+import logging
+from abc import ABC, abstractmethod
+from typing import List, Dict
+
+logger = logging.getLogger(__name__)
+
+class ScalingStrategy(ABC):
+    """
+    Abstract interface for scaling environment actors.
+    """
+    @abstractmethod
+    def set_instance_count(self, target_count: int):
+        pass
+
+    @abstractmethod
+    def get_current_count(self) -> int:
+        pass
+
+    @abstractmethod
+    def cleanup(self):
+        pass
+
+class LocalActor(ScalingStrategy):
+    """
+    Manages local environment server processes via subprocess.
+    """
+    def __init__(self, command: List[str], cwd: str = "."):
+        self.command = command
+        self.cwd = cwd
+        self.processes: List[subprocess.Popen] = []
+
+    def set_instance_count(self, target_count: int):
+        current_count = len(self.processes)
+
+        if target_count > current_count:
+            # Scale UP
+            to_add = target_count - current_count
+            logger.info(f"LocalActor: Scaling UP by {to_add} (Total: {target_count})")
+            for _ in range(to_add):
+                proc = subprocess.Popen(
+                    self.command,
+                    cwd=self.cwd,
+                    stdout=subprocess.DEVNULL,  # Should probably be configurable
+                    stderr=subprocess.STDOUT
+                )
+                self.processes.append(proc)
+
+        elif target_count < current_count:
+            # Scale DOWN
+            to_remove = current_count - target_count
+            logger.info(f"LocalActor: Scaling DOWN by {to_remove} (Total: {target_count})")
+            for _ in range(to_remove):
+                proc = self.processes.pop()
+                proc.terminate()
+                try:
+                    proc.wait(timeout=5)
+                except subprocess.TimeoutExpired:
+                    proc.kill()
+
+    def get_current_count(self) -> int:
+        # Filter out dead processes
+        self.processes = [p for p in self.processes if p.poll() is None]
+        return len(self.processes)
+
+    def cleanup(self):
+        logger.info("LocalActor: Cleaning up all managed processes...")
+        for proc in self.processes:
+            proc.terminate()
+        for proc in self.processes:
+            try:
+                proc.wait(timeout=2)
+            except subprocess.TimeoutExpired:
+                proc.kill()
+        self.processes = []
diff --git a/integration_tests/dummy_actor.py b/integration_tests/dummy_actor.py
new file mode 100644
index 00000000..734b9b6a
--- /dev/null
+++ b/integration_tests/dummy_actor.py
@@ -0,0 +1,34 @@
+import requests
+import time
+import argparse
+import sys
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--server", type=str, default="http://localhost:8000")
+    args = parser.parse_args()
+
+    print(f"Dummy Actor starting against {args.server}")
+
+    # 1. Register with server
+    try:
+        resp = requests.post(f"{args.server}/register-env", json={
+            "max_token_length": 1024,
+            "desired_name": "dummy_actor",
+            "weight": 1.0,
+            "group_size": 1
+        })
+        resp.raise_for_status()
+        data = resp.json()
+        print(f"Registered (ID: {data['env_id']})")
+
+        # 2. Stay alive
+        while True:
+            time.sleep(10)
+
+    except Exception as e:
+        print(f"Dummy Actor failed: {e}")
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/integration_tests/simulate_scaling.py b/integration_tests/simulate_scaling.py
new file mode 100644
index 00000000..206539af
--- /dev/null
+++ b/integration_tests/simulate_scaling.py
@@ -0,0 +1,111 @@
+import subprocess
+import time
+import requests
+import os
+import signal
+import sys
+
+def main():
+    print("🚀 Starting Atropos Elasticity Simulation...")
+
+    # 1. Start Atropos Server
+    server_process = subprocess.Popen(
+        ["run-api", "--port", "8008"],
+        stdout=subprocess.DEVNULL,
+        stderr=subprocess.STDOUT
+    )
+    time.sleep(3)
+    server_url = "http://localhost:8008"
+
+    # 2. Register Trainer
+    # This must happen before /register-env will work on server
+    print("Registering Trainer (BatchSize=32)...")
+    requests.post(f"{server_url}/register", json={
+        "wandb_group": "sim",
+        "wandb_project": "sim",
+        "batch_size": 32,
+        "max_token_len": 512,
+        "checkpoint_dir": "/tmp",
+        "save_checkpoint_interval": 10,
+        "starting_step": 0,
+        "num_steps": 100
+    })
+
+    # 3. Start Orchestrator
+    print("Starting Orchestrator (Monitoring RP=1.0)...")
+    # Command for dummy actor: "python integration_tests/dummy_actor.py --server http://localhost:8008"
+    orchestrator_process = subprocess.Popen(
+        [
+            "python", "-m", "atroposlib.cli.orchestrate",
+            "--server-url", server_url,
+            "--env-command", f"python integration_tests/dummy_actor.py --server {server_url}",
+            "--min-actors", "1",
+            "--max-actors", "10",
+            "--poll-interval", "5",
+            "--cooldown", "10"  # Short cooldown for simulation
+        ],
+        cwd=".",
+        # Use shell if needed for complex commands, but list is safer
+    )
+
+    time.sleep(10)
+
+    def get_status():
+        try:
+            return requests.get(f"{server_url}/global-status").json()
+        except Exception:
+            return None
+
+    try:
+        # Pass 1: "Quiet Mode"
+        print("\n--- TEST 1: Baseline Scaling (Goal: 1 Actor) ---")
+        status = get_status()
+        print(f"Queue Size: {status['queue_size']}, Actors: {status['num_connected_envs']}")
+
+        # Pass 2: "Burst Mode"
+        print("\n--- TEST 2: Burst Scaling (Goal: Target RP > 1) ---")
+        # RP = (Queue / BatchSize). To get RP=5 with batch_size=32, we need 160 rollouts.
+        print("Pushing 160 dummy rollouts to increase Rollout Pressure...")
+        fake_rollout = {
+            "tokens": [[1]*10],
+            "masks": [[1]*10],
+            "scores": [1.0],
+            "env_id": 0  # This might cause some warnings if not registered, but we just need it in queue
+        }
+        # scored_data_list is faster
+        requests.post(f"{server_url}/scored_data_list", json=[fake_rollout]*160)
+
+        print("Waiting for Orchestrator to react (approx 15s)...")
+        for _ in range(4):
+            time.sleep(5)
+            status = get_status()
+            print(f"Queue Size: {status['queue_size']}, Actors: {status['num_connected_envs']} (RP: {status['queue_size']/status['batch_size']:.2f})")
+            if status['num_connected_envs'] > 1:
+                print("✅ SCALING UP DETECTED.")
+
+        # Pass 3: "Emptying Mode"
+        print("\n--- TEST 3: Drain Scaling (Goal: Scale Back Down) ---")
+        print("Cleaning server queue...")
+        requests.get(f"{server_url}/reset_data")
+        # Re-register trainer
+        requests.post(f"{server_url}/register", json={
+            "wandb_group": "sim", "wandb_project": "sim", "batch_size": 32, "max_token_len": 512, "checkpoint_dir": "/tmp", "save_checkpoint_interval": 10, "starting_step": 0, "num_steps": 100
+        })
+
+        print("Waiting for Orchestrator to react (approx 20s)...")
+        for _ in range(5):
+            time.sleep(5)
+            status = get_status()
+            print(f"Queue Size: {status['queue_size']}, Actors: {status['num_connected_envs']}")
+            if status['num_connected_envs'] < 4:
+                print("✅ SCALING DOWN DETECTED.")
+
+    finally:
+        print("\nCleaning up simulation...")
+        orchestrator_process.terminate()
+        server_process.terminate()
+        os.system("pkill -f dummy_actor")  # Kill any orphans
+        print("Done.")
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index 6f23666c..c1b62cae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ view-run = "atroposlib.cli.view_run:main"
 view-run-multimodal = "atroposlib.cli.view_run_multimodal:main"
 atropos-sft-gen = "atroposlib.cli.sft:main"
 atropos-dpo-gen = "atroposlib.cli.dpo:main"
+atropos-orchestrate = "atroposlib.cli.orchestrate:main"
 atropos-grpo = "example_trainer.grpo:main"
 atropos-grpo-run = "example_trainer.run:main"
 