Update server.py

This commit is contained in:
Ragnar 2025-09-22 00:32:39 +02:00 committed by GitHub
parent 4380dc41d2
commit 60addb9a7d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -13,6 +13,9 @@ from atroposlib.api.utils import (
grab_exact_from_heterogeneous_queue,
)
# Constants
MIN_ENV_WEIGHT = 0.01 # Minimum weight to prevent environments from being completely starved
# Message import removed - using Dict[str, Any] for more flexible validation
app = FastAPI(title="AtroposLib API")
@ -109,9 +112,8 @@ class Info(BaseModel):
@app.post("/register")
async def register(registration: Registration):
try:
isinstance(app.state.queue, list)
except AttributeError:
# Initialize app state if not already done
if not hasattr(app.state, 'queue'):
app.state.queue = []
app.state.group = registration.wandb_group
app.state.project = registration.wandb_project
@ -125,34 +127,29 @@ async def register(registration: Registration):
app.state.started = False
app.state.envs = []
app.state.buffer = {} # Buffer for mixed-size groups per environment
try:
app.state.requesters.append(uuid.uuid4().int)
except AttributeError:
# If requesters doesn't exist, create it
app.state.requesters = [uuid.uuid4().int]
# Initialize requesters list if not already done
if not hasattr(app.state, 'requesters'):
app.state.requesters = []
app.state.requesters.append(uuid.uuid4().int)
return {"uuid": app.state.requesters[-1]}
@app.post("/register-env")
async def register_env_url(register_env: RegisterEnv):
try:
if not app.state.started:
return {
"status": "wait for trainer to start",
}
except AttributeError:
# Check if trainer has started
if not hasattr(app.state, 'started') or not app.state.started:
return {
"status": "wait for trainer to start",
}
try:
isinstance(app.state.envs, list)
except AttributeError:
# Initialize envs list if not already done
if not hasattr(app.state, 'envs'):
app.state.envs = []
checkpoint_dir = ""
try:
checkpoint_dir = app.state.checkpoint_dir
except AttributeError:
pass
# Get checkpoint directory safely
checkpoint_dir = getattr(app.state, 'checkpoint_dir', "")
real_name = (
f"{register_env.desired_name}_"
f"{len([x for x in app.state.envs if x['desired_name'] == register_env.desired_name])}"
@ -416,8 +413,8 @@ async def get_status_env(env: EnvIdentifier):
/ total
)
env_weight = max(
0.01, env_weight
) # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this
MIN_ENV_WEIGHT, env_weight
) # Ensure minimum weight to prevent environment starvation
# Calculate total minimum allocations
total_min_allocation = 0.0