feat: implement Elastic Environment Orchestrator (DEO) for environment microservices

This commit is contained in:
RUFFY-369 2026-04-03 01:58:30 +05:30
parent c20c85256e
commit 5a1ea7d3cb
8 changed files with 478 additions and 0 deletions

View file

@@ -0,0 +1,67 @@
import logging
import math
from typing import Optional, Dict, Any, List
from .metrics import WorkloadMetrics
logger = logging.getLogger(__name__)
class ScalingController:
    """
    Decides the "Desired Actor Count" based on workload metrics.

    Uses a dampened proportional calculation with hysteresis and a
    cooldown window to avoid flapping (rapid scale up/down cycles).
    """

    def __init__(
        self,
        min_actors: int = 1,
        max_actors: int = 20,
        target_pressure: float = 1.0,
        scaling_threshold: float = 0.2,  # ±20%
        cooldown_seconds: int = 60,
        max_step_change: int = 4
    ):
        """
        Args:
            min_actors: Lower bound on the desired actor count.
            max_actors: Upper bound on the desired actor count.
            target_pressure: Ideal rollout pressure (queue size / batch size).
            scaling_threshold: Dead band around target_pressure inside which
                no scaling decision is made (hysteresis).
            cooldown_seconds: Minimum time between two scaling decisions.
            max_step_change: Max actors added or removed in one decision.

        Raises:
            ValueError: If the configuration is internally inconsistent.
        """
        # Validate up front so a misconfiguration fails loudly here instead
        # of producing a ZeroDivisionError or nonsensical targets at runtime.
        if min_actors < 0 or max_actors < min_actors:
            raise ValueError(
                f"Require 0 <= min_actors <= max_actors, got {min_actors}/{max_actors}"
            )
        if target_pressure <= 0:
            raise ValueError(f"target_pressure must be > 0, got {target_pressure}")
        if scaling_threshold < 0:
            raise ValueError(f"scaling_threshold must be >= 0, got {scaling_threshold}")
        if max_step_change <= 0:
            raise ValueError(f"max_step_change must be > 0, got {max_step_change}")
        self.min_actors = min_actors
        self.max_actors = max_actors
        self.target_pressure = target_pressure
        self.scaling_threshold = scaling_threshold
        self.cooldown_seconds = cooldown_seconds
        self.max_step_change = max_step_change
        self.last_action_timestamp = 0  # metrics-clock time of last decision
        self.current_desired = min_actors  # last target we committed to

    def calculate_desired(self, metrics: "WorkloadMetrics", current_actors: int) -> int:
        """
        Decides the next target for the number of environment actors.

        Args:
            metrics: Snapshot providing ``timestamp`` and ``rollout_pressure``.
            current_actors: Number of actors currently running.

        Returns:
            The desired actor count, clamped to [min_actors, max_actors].
        """
        now = metrics.timestamp
        pressure = metrics.rollout_pressure
        # 1. Cooldown: hold the previous decision until the window expires.
        if now - self.last_action_timestamp < self.cooldown_seconds:
            return self.current_desired
        # 2. Sensitivity check (Hysteresis)
        # If work is roughly satisfying target, don't change anything.
        if abs(pressure - self.target_pressure) < self.scaling_threshold:
            return self.current_desired
        # 3. Calculate target
        # Target = Current * (Current_Pressure / Ideal_Pressure)
        # This is a dampened proportional controller.
        raw_target = math.ceil(current_actors * (pressure / self.target_pressure))
        # 4. Apply step constraints (Rate Limiting)
        # Don't add/remove more than max_step_change in a single move.
        diff = raw_target - current_actors
        if abs(diff) > self.max_step_change:
            raw_target = current_actors + (self.max_step_change if diff > 0 else -self.max_step_change)
        # 5. Apply world bounds
        final_target = max(self.min_actors, min(self.max_actors, raw_target))
        if final_target != current_actors:
            self.last_action_timestamp = now
            self.current_desired = final_target
            # Lazy %-style args so formatting is skipped when INFO is disabled.
            logging.getLogger(__name__).info(
                "Controller DECISION: Scale %d -> %d (Pressure: %.2f)",
                current_actors, final_target, pressure,
            )
        return final_target

View file

@@ -0,0 +1,57 @@
import time
import requests
from typing import Optional, Dict, Any
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class WorkloadMetrics:
    """Immutable-ish snapshot of the global rollout workload.

    Field names and order mirror the payload returned by the server's
    ``/global-status`` endpoint (see MetricsCollector.poll).
    """

    current_step: int
    queue_size: int
    total_rollouts: int
    unallocated_fraction: float
    num_envs: int
    batch_size: int
    timestamp: float

    @property
    def rollout_pressure(self) -> float:
        """Rollout Pressure (RP): queued rollouts per training batch.

        RP = queue_size / batch_size. A value above 1.0 means more than a
        full batch is waiting, i.e. the trainer is starving. A non-positive
        batch size yields 0.0 rather than dividing by zero.
        """
        if self.batch_size > 0:
            return self.queue_size / self.batch_size
        return 0.0
class MetricsCollector:
    """Polls an Atropos server for global workload metrics over HTTP."""

    def __init__(self, server_url: str):
        # Normalize the base URL once so endpoint paths join cleanly.
        self.server_url = server_url.rstrip("/")
        # Most recent successful snapshot; None until the first poll succeeds.
        self.last_metrics: Optional["WorkloadMetrics"] = None

    def poll(self) -> Optional["WorkloadMetrics"]:
        """
        Polls the Atropos server for global metrics.

        Returns the parsed snapshot on success, or None on any failure
        (network error, bad status, malformed/missing payload fields).
        """
        try:
            response = requests.get(f"{self.server_url}/global-status", timeout=5)
            response.raise_for_status()
            payload = response.json()
            snapshot = WorkloadMetrics(
                current_step=payload["current_step"],
                queue_size=payload["queue_size"],
                total_rollouts=payload["total_rollouts_processed"],
                unallocated_fraction=payload["unallocated_fraction"],
                num_envs=payload["num_connected_envs"],
                batch_size=payload["batch_size"],
                timestamp=time.time(),
            )
        except Exception as e:
            # Best-effort boundary: log and report "no data" to the caller.
            logger.error(f"Failed to poll metrics from {self.server_url}: {e}")
            return None
        self.last_metrics = snapshot
        return snapshot

View file

@@ -0,0 +1,77 @@
import subprocess
import os
import signal
import logging
from abc import ABC, abstractmethod
from typing import List, Dict
logger = logging.getLogger(__name__)
class ScalingStrategy(ABC):
    """
    Abstract interface for scaling environment actors.

    Concrete strategies (such as LocalActor's subprocess pool below)
    decide how actor instances are created, counted and destroyed.
    """

    @abstractmethod
    def set_instance_count(self, target_count: int):
        """Converge the number of running actors to ``target_count``."""

    @abstractmethod
    def get_current_count(self) -> int:
        """Return how many actor instances are currently alive."""

    @abstractmethod
    def cleanup(self):
        """Tear down every actor instance this strategy manages."""
class LocalActor(ScalingStrategy):
    """
    Manages local environment server processes via subprocess.

    Each actor is one child process running ``command``. Scaling up spawns
    new children; scaling down terminates the most recently started ones
    (LIFO), escalating to SIGKILL if they ignore SIGTERM.
    """

    def __init__(self, command: List[str], cwd: str = "."):
        """
        Args:
            command: argv list used to launch one actor process.
            cwd: Working directory for spawned processes.
        """
        self.command = command
        self.cwd = cwd
        self.processes: List[subprocess.Popen] = []

    def set_instance_count(self, target_count: int):
        """Spawn or terminate children until ``target_count`` are alive."""
        # Prune exited children first so we scale relative to what is
        # actually alive, not a stale bookkeeping list (bug fix: the raw
        # len(self.processes) could count processes that already died).
        current_count = self.get_current_count()
        if target_count > current_count:
            # Scale UP
            to_add = target_count - current_count
            logger.info(f"LocalActor: Scaling UP by {to_add} (Total: {target_count})")
            for _ in range(to_add):
                proc = subprocess.Popen(
                    self.command,
                    cwd=self.cwd,
                    stdout=subprocess.DEVNULL,  # Should probably be configurable
                    stderr=subprocess.STDOUT
                )
                self.processes.append(proc)
        elif target_count < current_count:
            # Scale DOWN
            to_remove = current_count - target_count
            logger.info(f"LocalActor: Scaling DOWN by {to_remove} (Total: {target_count})")
            for _ in range(to_remove):
                self._stop(self.processes.pop(), timeout=5)

    def _stop(self, proc: subprocess.Popen, timeout: float):
        """Terminate one child, escalating to kill after ``timeout`` seconds."""
        proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()  # reap so the killed child doesn't linger as a zombie

    def get_current_count(self) -> int:
        """Return the number of live children, pruning exited ones."""
        self.processes = [p for p in self.processes if p.poll() is None]
        return len(self.processes)

    def cleanup(self):
        """Terminate every managed process and clear the registry."""
        logger.info("LocalActor: Cleaning up all managed processes...")
        # Signal everything first so shutdowns overlap, then reap each one.
        for proc in self.processes:
            proc.terminate()
        for proc in self.processes:
            try:
                proc.wait(timeout=2)
            except subprocess.TimeoutExpired:
                # Was a bare `except:` — that also swallowed KeyboardInterrupt
                # and SystemExit; only the wait timeout should trigger a kill.
                proc.kill()
                proc.wait()  # reap the killed child
        self.processes = []