mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Update server.py
This commit is contained in:
parent
4380dc41d2
commit
60addb9a7d
1 changed files with 21 additions and 24 deletions
|
|
@ -13,6 +13,9 @@ from atroposlib.api.utils import (
|
|||
grab_exact_from_heterogeneous_queue,
|
||||
)
|
||||
|
||||
# Constants
|
||||
MIN_ENV_WEIGHT = 0.01 # Minimum weight to prevent environments from being completely starved
|
||||
|
||||
# Message import removed - using Dict[str, Any] for more flexible validation
|
||||
|
||||
app = FastAPI(title="AtroposLib API")
|
||||
|
|
@ -109,9 +112,8 @@ class Info(BaseModel):
|
|||
|
||||
@app.post("/register")
|
||||
async def register(registration: Registration):
|
||||
try:
|
||||
isinstance(app.state.queue, list)
|
||||
except AttributeError:
|
||||
# Initialize app state if not already done
|
||||
if not hasattr(app.state, 'queue'):
|
||||
app.state.queue = []
|
||||
app.state.group = registration.wandb_group
|
||||
app.state.project = registration.wandb_project
|
||||
|
|
@ -125,34 +127,29 @@ async def register(registration: Registration):
|
|||
app.state.started = False
|
||||
app.state.envs = []
|
||||
app.state.buffer = {} # Buffer for mixed-size groups per environment
|
||||
try:
|
||||
app.state.requesters.append(uuid.uuid4().int)
|
||||
except AttributeError:
|
||||
# If requesters doesn't exist, create it
|
||||
app.state.requesters = [uuid.uuid4().int]
|
||||
|
||||
# Initialize requesters list if not already done
|
||||
if not hasattr(app.state, 'requesters'):
|
||||
app.state.requesters = []
|
||||
|
||||
app.state.requesters.append(uuid.uuid4().int)
|
||||
return {"uuid": app.state.requesters[-1]}
|
||||
|
||||
|
||||
@app.post("/register-env")
|
||||
async def register_env_url(register_env: RegisterEnv):
|
||||
try:
|
||||
if not app.state.started:
|
||||
return {
|
||||
"status": "wait for trainer to start",
|
||||
}
|
||||
except AttributeError:
|
||||
# Check if trainer has started
|
||||
if not hasattr(app.state, 'started') or not app.state.started:
|
||||
return {
|
||||
"status": "wait for trainer to start",
|
||||
}
|
||||
try:
|
||||
isinstance(app.state.envs, list)
|
||||
except AttributeError:
|
||||
|
||||
# Initialize envs list if not already done
|
||||
if not hasattr(app.state, 'envs'):
|
||||
app.state.envs = []
|
||||
checkpoint_dir = ""
|
||||
try:
|
||||
checkpoint_dir = app.state.checkpoint_dir
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Get checkpoint directory safely
|
||||
checkpoint_dir = getattr(app.state, 'checkpoint_dir', "")
|
||||
real_name = (
|
||||
f"{register_env.desired_name}_"
|
||||
f"{len([x for x in app.state.envs if x['desired_name'] == register_env.desired_name])}"
|
||||
|
|
@ -416,8 +413,8 @@ async def get_status_env(env: EnvIdentifier):
|
|||
/ total
|
||||
)
|
||||
env_weight = max(
|
||||
0.01, env_weight
|
||||
) # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this
|
||||
MIN_ENV_WEIGHT, env_weight
|
||||
) # Ensure minimum weight to prevent environment starvation
|
||||
|
||||
# Calculate total minimum allocations
|
||||
total_min_allocation = 0.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue