mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00
feat: implement Elastic Environment Orchestrator (EEO) for environment microservices
This commit is contained in:
parent
c20c85256e
commit
5a1ea7d3cb
8 changed files with 478 additions and 0 deletions
67
atroposlib/orchestration/controller.py
Normal file
67
atroposlib/orchestration/controller.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
import logging
|
||||
import math
|
||||
from typing import Optional, Dict, Any, List
|
||||
from .metrics import WorkloadMetrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ScalingController:
    """Decides the "Desired Actor Count" based on workload metrics.

    Uses a dampened proportional calculation with hysteresis and a cooldown
    window to avoid flapping (rapid scale-up/scale-down cycles).
    """

    def __init__(
        self,
        min_actors: int = 1,
        max_actors: int = 20,
        target_pressure: float = 1.0,
        scaling_threshold: float = 0.2,  # ±20%
        cooldown_seconds: int = 60,
        max_step_change: int = 4,
    ):
        """Configure the controller.

        Args:
            min_actors: lower bound on the desired actor count.
            max_actors: upper bound on the desired actor count.
            target_pressure: ideal rollout pressure (used as a divisor).
            scaling_threshold: hysteresis band around target_pressure.
            cooldown_seconds: minimum time between scaling decisions.
            max_step_change: largest allowed change per decision.

        Raises:
            ValueError: if target_pressure is not positive (it divides the
                observed pressure) or min_actors exceeds max_actors.
        """
        if target_pressure <= 0:
            raise ValueError("target_pressure must be positive")
        if min_actors > max_actors:
            raise ValueError("min_actors must not exceed max_actors")
        self.min_actors = min_actors
        self.max_actors = max_actors
        self.target_pressure = target_pressure
        self.scaling_threshold = scaling_threshold
        self.cooldown_seconds = cooldown_seconds
        self.max_step_change = max_step_change

        # Timestamp of the last decision that changed the target.
        self.last_action_timestamp = 0.0
        # Last target we committed to; returned while cooling down.
        self.current_desired = min_actors

    def calculate_desired(self, metrics: "WorkloadMetrics", current_actors: int) -> int:
        """Decide the next target for the number of environment actors.

        Args:
            metrics: latest workload snapshot; only ``timestamp`` and
                ``rollout_pressure`` are read here.
            current_actors: number of actors currently running.

        Returns:
            The desired actor count, clamped to [min_actors, max_actors].
        """
        now = metrics.timestamp
        pressure = metrics.rollout_pressure

        # 1. Cooldown: hold the previous decision until the window expires.
        if now - self.last_action_timestamp < self.cooldown_seconds:
            return self.current_desired

        # 2. Hysteresis: if pressure roughly satisfies the target, do nothing.
        if abs(pressure - self.target_pressure) < self.scaling_threshold:
            return self.current_desired

        # 3. Dampened proportional controller:
        #    target = current * (observed pressure / ideal pressure)
        raw_target = math.ceil(current_actors * (pressure / self.target_pressure))

        # 4. Rate limiting: move at most max_step_change per decision.
        diff = raw_target - current_actors
        if abs(diff) > self.max_step_change:
            step = self.max_step_change if diff > 0 else -self.max_step_change
            raw_target = current_actors + step

        # 5. Clamp to the configured world bounds.
        final_target = max(self.min_actors, min(self.max_actors, raw_target))

        if final_target != current_actors:
            self.last_action_timestamp = now
            self.current_desired = final_target
            # Lazy %-args keep formatting cost off the disabled-logging path.
            logging.getLogger(__name__).info(
                "Controller DECISION: Scale %d -> %d (Pressure: %.2f)",
                current_actors,
                final_target,
                pressure,
            )
        # NOTE(review): when final_target == current_actors but differs from
        # self.current_desired, the stored desired value goes stale; assumes
        # current_actors always tracks the last decision — confirm with caller.

        return final_target
|
||||
57
atroposlib/orchestration/metrics.py
Normal file
57
atroposlib/orchestration/metrics.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import time
|
||||
import requests
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class WorkloadMetrics:
    """A point-in-time snapshot of the rollout workload reported by the server."""

    current_step: int
    queue_size: int
    total_rollouts: int
    unallocated_fraction: float
    num_envs: int
    batch_size: int
    timestamp: float

    @property
    def rollout_pressure(self) -> float:
        """Rollout Pressure (RP) = queue_size / batch_size.

        A value above 1.0 means the trainer is starving for rollouts.
        """
        if self.batch_size > 0:
            return self.queue_size / self.batch_size
        # Degenerate batch size: report zero pressure instead of dividing by zero.
        return 0.0
|
||||
|
||||
class MetricsCollector:
    """Fetches the Atropos server's global status and parses it into WorkloadMetrics."""

    def __init__(self, server_url: str):
        # Normalise: drop any trailing slash so endpoint joins stay clean.
        self.server_url = server_url.rstrip("/")
        # Most recent successful snapshot, if any.
        self.last_metrics: Optional["WorkloadMetrics"] = None

    def poll(self) -> Optional["WorkloadMetrics"]:
        """Poll the Atropos server for global metrics.

        Returns the parsed snapshot on success; on any request or parsing
        failure the error is logged and None is returned (never raises).
        """
        try:
            response = requests.get(f"{self.server_url}/global-status", timeout=5)
            response.raise_for_status()
            data = response.json()

            snapshot = WorkloadMetrics(
                current_step=data["current_step"],
                queue_size=data["queue_size"],
                total_rollouts=data["total_rollouts_processed"],
                unallocated_fraction=data["unallocated_fraction"],
                num_envs=data["num_connected_envs"],
                batch_size=data["batch_size"],
                timestamp=time.time(),
            )
        except Exception as e:
            # Deliberately broad: polling is best-effort and must not raise.
            logging.getLogger(__name__).error(
                f"Failed to poll metrics from {self.server_url}: {e}"
            )
            return None
        self.last_metrics = snapshot
        return snapshot
|
||||
77
atroposlib/orchestration/strategy.py
Normal file
77
atroposlib/orchestration/strategy.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import subprocess
|
||||
import os
|
||||
import signal
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ScalingStrategy(ABC):
    """Abstract interface for scaling environment actors up and down.

    Concrete strategies (local subprocesses, containers, remote nodes, ...)
    decide how instances are launched, counted and torn down.
    """

    @abstractmethod
    def set_instance_count(self, target_count: int):
        """Converge the number of running instances toward ``target_count``."""

    @abstractmethod
    def get_current_count(self) -> int:
        """Return the number of instances currently alive."""

    @abstractmethod
    def cleanup(self):
        """Tear down every instance this strategy manages."""
|
||||
|
||||
class LocalActor(ScalingStrategy):
    """Manages local environment server processes via subprocess.

    Each instance is one copy of ``command`` launched in ``cwd``.  Scaling
    down stops the most recently started processes first (LIFO).
    """

    def __init__(self, command: List[str], cwd: str = "."):
        # Argument-list form (shell=False) avoids shell-injection issues.
        self.command = command
        self.cwd = cwd
        self.processes: List[subprocess.Popen] = []

    def _prune_dead(self) -> None:
        """Drop handles for processes that have already exited."""
        self.processes = [p for p in self.processes if p.poll() is None]

    @staticmethod
    def _stop(proc: subprocess.Popen, timeout: float) -> None:
        """Terminate one process, escalating to kill after ``timeout`` seconds."""
        proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()  # reap so no zombie is left behind

    def set_instance_count(self, target_count: int):
        """Start or stop processes until the live count equals ``target_count``."""
        # Count only live processes (matches get_current_count) so crashed
        # actors get replaced instead of being counted as running.
        self._prune_dead()
        current_count = len(self.processes)
        log = logging.getLogger(__name__)

        if target_count > current_count:
            # Scale UP
            to_add = target_count - current_count
            log.info("LocalActor: Scaling UP by %d (Total: %d)", to_add, target_count)
            for _ in range(to_add):
                proc = subprocess.Popen(
                    self.command,
                    cwd=self.cwd,
                    stdout=subprocess.DEVNULL,  # Should probably be configurable
                    stderr=subprocess.STDOUT,
                )
                self.processes.append(proc)

        elif target_count < current_count:
            # Scale DOWN
            to_remove = current_count - target_count
            log.info("LocalActor: Scaling DOWN by %d (Total: %d)", to_remove, target_count)
            for _ in range(to_remove):
                self._stop(self.processes.pop(), timeout=5)

    def get_current_count(self) -> int:
        """Return the number of managed processes that are still alive."""
        self._prune_dead()
        return len(self.processes)

    def cleanup(self):
        """Terminate every managed process, escalating to kill if needed."""
        log = logging.getLogger(__name__)
        log.info("LocalActor: Cleaning up all managed processes...")
        # Signal everyone first so shutdowns overlap, then wait/escalate.
        for proc in self.processes:
            proc.terminate()
        for proc in self.processes:
            try:
                proc.wait(timeout=2)
            except subprocess.TimeoutExpired:
                # Narrowed from a bare `except:` which also swallowed
                # KeyboardInterrupt/SystemExit.
                proc.kill()
                proc.wait()
            except Exception:
                # Cleanup stays best-effort: log unexpected failures, never raise.
                log.exception("LocalActor: error while waiting for process shutdown")
        self.processes = []
|
||||
Loading…
Add table
Add a link
Reference in a new issue