atropos/atroposlib/api/server.py
2026-03-13 11:06:02 -04:00

642 lines
21 KiB
Python

import gzip
import logging
import time
import uuid
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, field_validator
from starlette.datastructures import MutableHeaders
from starlette.types import Receive, Scope, Send
from atroposlib.api.utils import (
find_groups_summing_to_target,
grab_batch_with_minimum_allocations,
grab_exact_from_heterogeneous_queue,
)
# Constants
MIN_ENV_WEIGHT = (
    0.01  # Minimum weight to prevent environments from being completely starved
)
# Message import removed - using Dict[str, Any] for more flexible validation

# The single FastAPI application; every route and middleware in this module
# attaches to it, and all shared server state lives on `app.state`.
app = FastAPI(title="AtroposLib API")
logger = logging.getLogger(__name__)

# CORS is fully open: every origin, method and header is allowed, with
# credentials enabled.
# NOTE(review): presumably acceptable because the API is only reachable from
# trusted hosts — confirm before exposing it publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Response compression; 1000 is the `minimum_size` threshold handed to
# Starlette's GZipMiddleware.
app.add_middleware(GZipMiddleware, minimum_size=1000)
class GZipRequestMiddleware:
    """ASGI middleware that decompresses gzip-encoded HTTP request bodies.

    Requests whose ``Content-Encoding`` header does not mention ``gzip`` (and
    all non-HTTP scopes) are passed through untouched. For gzip requests the
    whole body is read into memory, decompressed, and replayed to the wrapped
    application as a single ``http.request`` message; ``Content-Length`` is
    rewritten and ``Content-Encoding`` is removed so downstream handlers see a
    plain body. A payload that cannot be decompressed is answered with
    ``400 Invalid gzip payload`` without calling the wrapped application.

    NOTE(review): no size limit is enforced on either the compressed or the
    decompressed body, so a very large or highly compressible upload is held
    fully in memory.
    """

    def __init__(self, app):
        # `app` is the next ASGI application in the middleware chain.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        # Function-scope import (the module already uses this pattern) so the
        # extra exception type is available without touching file-level imports.
        import zlib

        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        headers = MutableHeaders(scope=scope)
        content_encoding = headers.get("content-encoding", "")
        if "gzip" not in content_encoding.lower():
            await self.app(scope, receive, send)
            return

        # Drain the full request body; gzip.decompress needs all of it at once.
        body_chunks = []
        more_body = True
        while more_body:
            message = await receive()
            body_chunks.append(message.get("body", b""))
            more_body = message.get("more_body", False)
        body = b"".join(body_chunks)

        if body:
            try:
                decompressed = gzip.decompress(body)
            except (OSError, EOFError, zlib.error):
                # OSError covers gzip.BadGzipFile (bad header/CRC), EOFError a
                # truncated stream, zlib.error corrupt deflate data. All three
                # mean "the client sent a bad payload", i.e. a 400 and not a 500.
                response = PlainTextResponse(
                    "Invalid gzip payload",
                    status_code=status.HTTP_400_BAD_REQUEST,
                )
                await response(scope, receive, send)
                return
        else:
            decompressed = b""

        # Make the scope describe the decompressed body the app will receive.
        headers["content-length"] = str(len(decompressed))
        if "content-encoding" in headers:
            del headers["content-encoding"]

        sent = False

        # The original stream is already consumed, so the body is replayed
        # exactly once; any later receive() call gets an empty, final chunk.
        async def new_receive():
            nonlocal sent
            if sent:
                return {"type": "http.request", "body": b"", "more_body": False}
            sent = True
            return {
                "type": "http.request",
                "body": decompressed,
                "more_body": False,
            }

        await self.app(scope, new_receive, send)
# Install the gzip request-body decompression middleware.
# NOTE(review): Starlette is understood to run the most recently added
# middleware outermost, so this presumably sees requests before the CORS and
# response-GZip layers — confirm against the Starlette middleware docs.
app.add_middleware(GZipRequestMiddleware)
@app.get("/")
async def root():
    """Return a fixed banner identifying this service."""
    banner = {"message": "AtroposLib API"}
    return banner
class Registration(BaseModel):
    """Request body for ``POST /register``.

    The ``register`` handler copies each field onto ``app.state``.
    """

    wandb_group: str  # stored as app.state.group (returned by /wandb_info)
    wandb_project: str  # stored as app.state.project (returned by /wandb_info)
    batch_size: int  # stored as app.state.batchsize; target size passed to the batch-grabbing helpers
    max_token_len: int  # stored as app.state.max_token_len (returned by /info)
    checkpoint_dir: str  # echoed back to environments by /register-env
    save_checkpoint_interval: int  # echoed to environments as "checkpoint_interval"
    starting_step: int  # initial value of app.state.status_dict["step"]
    num_steps: int  # echoed back to environments by /register-env
class RegisterEnv(BaseModel):
    """Request body for ``POST /register-env``.

    Each field is copied into the environment record appended to
    ``app.state.envs``.
    """

    max_token_length: int  # stored in the env record as "max_context_len"
    desired_name: str  # base name; the server appends "_<n>" to build the wandb name
    weight: float  # multiplied by max_context_len when /status-env computes env_weight
    group_size: int  # expected number of sequences per scored-data submission
    min_batch_allocation: Optional[float] = (
        None  # Minimum proportion of a batch this env should be allocated (0.0-1.0)
    )
class EnvIdentifier(BaseModel):
    """Request body that names one environment (used by /disconnect-env and /status-env)."""

    env_id: int  # used as an index into app.state.envs
class ScoredData(BaseModel):
    """One group of scored rollouts submitted by an environment.

    ``tokens`` holds one entry per sequence; the server treats ``len(tokens)``
    as the size of the group. All optional fields default to ``None``.
    """

    tokens: List[List[int]]
    masks: List[List[int]]
    scores: List[float]
    advantages: Optional[List[List[float]]] = None
    ref_logprobs: Optional[List[List[float]]] = None
    messages: Optional[List[List[Dict[str, Any]]]] = (
        None  # Changed from Message TypedDict to Dict
    )
    generation_params: Optional[Dict[str, Any]] = None
    inference_logprobs: Optional[List[List[float]]] = None
    overrides: Optional[List[dict]] = None
    group_overrides: Optional[dict] = None
    images: Optional[Any] = None
    env_id: Optional[int] = None  # ID of the environment that generated this data
    # On-policy distillation (new format): parallel token ids + logprobs.
    # Shape for both: [sequence][position][top_k]
    distill_token_ids: Optional[List[List[List[int]]]] = None
    distill_logprobs: Optional[List[List[List[float]]]] = None

    @field_validator("messages", mode="before")
    @classmethod
    def validate_messages(cls, v):
        """Check that every message carries 'role' and 'content'.

        Runs before pydantic's own type validation, so ``v`` is the raw input.
        The 'reward' field is completely optional and is not checked.

        Raises:
            ValueError: if a message lacks 'role' or 'content', or if the raw
                value is not shaped like a list of lists of messages.
        """
        if v is None:
            return None
        try:
            for message_list in v:
                for msg in message_list:
                    # Ensure the message has the required fields
                    if "role" not in msg or "content" not in msg:
                        raise ValueError(
                            "Message must have 'role' and 'content' fields"
                        )
        except TypeError as exc:
            # Raw input such as `5` or `[[1]]` is not iterable / does not
            # support `in`. Pydantic v2 only turns ValueError/AssertionError
            # into validation errors, so convert instead of letting a
            # TypeError surface as a server error.
            raise ValueError(
                "messages must be a list of lists of message dicts"
            ) from exc
        return v
def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]:
    """Copy the fields of a ``ScoredData`` model into a plain dict.

    Values are taken by attribute access (no deep copy or re-serialisation),
    and the keys appear in the fixed order listed below.
    """
    field_names = (
        "tokens",
        "masks",
        "scores",
        "advantages",
        "ref_logprobs",
        "messages",
        "generation_params",
        "inference_logprobs",
        "overrides",
        "group_overrides",
        "images",
        "env_id",
        "distill_token_ids",
        "distill_logprobs",
    )
    return {name: getattr(scored_data, name) for name in field_names}
def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]:
    """Queue one scored-data submission, buffering it if its size is off.

    If the submission names a registered environment and its number of
    sequences differs from that environment's ``group_size``, it is parked in
    the per-environment buffer. Whenever ``find_groups_summing_to_target``
    reports buffered groups that add up to the expected size, those groups are
    moved to the main queue. Everything else goes straight onto the queue.

    Returns:
        ``{"status": "buffered", "buffer_size": <sequences still buffered>}``
        for a size-mismatched submission, otherwise ``{"status": "received"}``.
    """
    # State may not exist yet if data arrives before /register.
    if not hasattr(app.state, "queue"):
        app.state.queue = []
    if not hasattr(app.state, "buffer"):
        app.state.buffer = {}

    data_dict = _scored_data_to_dict(scored_data)
    env_id = data_dict.get("env_id")
    envs = getattr(app.state, "envs", [])

    # `0 <=` matters: a negative id would otherwise index from the end of the
    # list and silently borrow another environment's group size. Unknown ids
    # (negative or too large) are treated alike and bypass buffering.
    if env_id is not None and 0 <= env_id < len(envs):
        expected_group_size = envs[env_id].get("group_size", 1)
        actual_group_size = len(scored_data.tokens)

        if actual_group_size != expected_group_size:
            buffer = app.state.buffer.setdefault(env_id, [])
            buffer.append(data_dict)

            indices = find_groups_summing_to_target(buffer, expected_group_size)
            if indices:
                # Pop from the highest index down so earlier indices stay valid,
                # then reverse to enqueue in the original buffer order.
                groups_to_add = [
                    buffer.pop(idx) for idx in sorted(indices, reverse=True)
                ]
                for group in reversed(groups_to_add):
                    app.state.queue.append(group)
                    app.state.latest = group

            return {
                "status": "buffered",
                "buffer_size": sum(
                    len(group["tokens"]) for group in app.state.buffer.get(env_id, [])
                ),
            }

    app.state.queue.append(data_dict)
    app.state.latest = data_dict
    return {"status": "received"}
class Status(BaseModel):
    """
    Status information about the current server.

    The field names match the keys returned by the /status endpoint.
    NOTE(review): no endpoint in the visible part of this file references this
    model (e.g. as a response model) — confirm whether it is still used.
    """

    current_step: int  # training step counter
    queue_size: int  # number of groups waiting in the queue
class Info(BaseModel):
    """
    General server information.

    NOTE(review): no endpoint in the visible part of this file references this
    model — confirm whether it is still used.
    """

    batch_size: int = -1  # -1 is also what /info reports before a trainer registers
@app.post("/register")
async def register(registration: Registration):
    """Register a trainer and (re)initialise the shared server state.

    An existing queue is kept; every other trainer-related attribute on
    ``app.state`` is overwritten from the request. A fresh random integer id
    is appended to ``app.state.requesters`` and returned as ``uuid``.
    """
    state = app.state

    # Initialize app state if not already done
    if not hasattr(state, "queue"):
        state.queue = []

    fresh_state = {
        "group": registration.wandb_group,
        "project": registration.wandb_project,
        "batchsize": int(registration.batch_size),
        "max_token_len": int(registration.max_token_len),
        "status_dict": {"step": registration.starting_step},
        "checkpoint_dir": registration.checkpoint_dir,
        "save_checkpoint_interval": registration.save_checkpoint_interval,
        "num_steps": registration.num_steps,
        "curr_batch": [],
        "started": False,
        "envs": [],
        "buffer": {},  # Buffer for mixed-size groups per environment
    }
    for attr_name, attr_value in fresh_state.items():
        setattr(state, attr_name, attr_value)

    # Initialize requesters list if not already done
    if not hasattr(state, "requesters"):
        state.requesters = []

    requester_id = uuid.uuid4().int
    state.requesters.append(requester_id)
    return {"uuid": requester_id}
@app.post("/register-env")
async def register_env_url(register_env: RegisterEnv):
    """Register an environment once the trainer has started.

    Until ``app.state.started`` is truthy the caller is told to wait.
    Otherwise a record for the environment is appended to ``app.state.envs``;
    its position in that list becomes the ``env_id`` and its wandb name is the
    desired name plus the count of earlier envs with the same desired name.
    """
    # Check if trainer has started
    if not getattr(app.state, "started", False):
        return {"status": "wait for trainer to start"}

    # Initialize envs list if not already done
    if not hasattr(app.state, "envs"):
        app.state.envs = []
    envs = app.state.envs

    # Get checkpoint directory safely
    checkpoint_dir = getattr(app.state, "checkpoint_dir", "")

    same_name_count = sum(
        1 for existing in envs if existing["desired_name"] == register_env.desired_name
    )
    real_name = f"{register_env.desired_name}_{same_name_count}"
    registered_id = len(envs)

    if register_env.weight is None:
        weight = 1.0
    else:
        weight = register_env.weight

    env_record = {
        "max_context_len": register_env.max_token_length,
        "weight": weight,
        "desired_name": register_env.desired_name,
        "real_name": real_name,
        "registered_id": registered_id,
        "last_update": time.time(),
        "connected": True,
        "min_batch_allocation": register_env.min_batch_allocation,
        "group_size": register_env.group_size,
    }
    envs.append(env_record)

    return {
        "status": "success",
        "env_id": registered_id,
        "wandb_name": real_name,
        "checkpoint_dir": checkpoint_dir,
        "starting_step": app.state.status_dict["step"],
        "checkpoint_interval": app.state.save_checkpoint_interval,
        "num_steps": app.state.num_steps,
    }
@app.post("/disconnect-env")
async def disconnect_env(disconnect_env: EnvIdentifier):
    """Mark the given environment as disconnected.

    Returns a failure status (with the error text) when no environment list
    exists yet or the id is out of range.
    """
    try:
        env_record = app.state.envs[disconnect_env.env_id]
        env_record["connected"] = False
    except (AttributeError, IndexError) as e:
        return {"status": "failure", "error": str(e)}
    return {"status": "success"}
@app.get("/wandb_info")
async def wandb_info():
    """Return the registered wandb group and project, or ``None`` for both."""
    try:
        group = app.state.group
        project = app.state.project
    except AttributeError:
        # Nothing registered yet.
        return {"group": None, "project": None}
    return {"group": group, "project": project}
@app.get("/info")
async def info():
    """Return the registered batch size and max token length, or -1 for both."""
    try:
        batch_size = app.state.batchsize
        max_token_len = app.state.max_token_len
    except AttributeError:
        # Nothing registered yet.
        return {"batch_size": -1, "max_token_len": -1}
    return {"batch_size": batch_size, "max_token_len": max_token_len}
@app.get("/batch")
async def get_batch(request: Request):
    """Hand one training batch to the caller.

    The first call after registration flips ``app.state.started`` to True.
    If prebuilt batches are waiting in ``app.state.curr_batch`` one is popped
    (from the end of the list) and returned. Otherwise as many full batches
    as possible are carved out of the queue, the step counter advances by
    that number, the batches are stashed in ``curr_batch`` and one is
    returned. ``{"batch": None}`` is returned when no full batch is ready.
    """
    state = app.state

    # Check if trainer has registered first
    if not hasattr(state, "started"):
        return {
            "status": "error",
            "message": "Trainer not registered. Call /register first.",
            "batch": [],
        }
    if not state.started:
        state.started = True

    # Caller identification, used for logging only.
    client = request.client
    if client is None:
        client_addr = "unknown-client"
    else:
        client_addr = f"{client.host}:{client.port}"
    client_tag = request.headers.get("x-atropos-client", "unknown")
    client_pid = request.headers.get("x-atropos-pid", "unknown")

    if state.curr_batch:
        prebuilt = state.curr_batch.pop()
        logger.warning(
            "API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
            "groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
            client_tag,
            client_pid,
            client_addr,
            len(prebuilt),
            sum(len(x["tokens"]) for x in prebuilt),
            len(state.curr_batch),
            len(state.queue),
        )
        return {"batch": prebuilt}

    # Check if any envs have minimum allocations
    has_min_allocations = any(
        env.get("min_batch_allocation") is not None
        for env in getattr(state, "envs", [])
    )

    def _grab_next():
        # Returns (batch_or_None, remaining_queue) using the strategy that
        # matches whether any env asked for a minimum allocation.
        if has_min_allocations:
            return grab_batch_with_minimum_allocations(
                state.queue, state.batchsize, state.envs
            )
        return grab_exact_from_heterogeneous_queue(state.queue, state.batchsize)

    new_batches = []
    batch, state.queue = _grab_next()
    while batch is not None:
        new_batches.append(batch)
        batch, state.queue = _grab_next()

    steps_to_take = len(new_batches)
    if steps_to_take == 0:
        # Log the "nothing ready" condition at most once every 30 seconds.
        now = time.time()
        last_empty_log = getattr(state, "_last_empty_batch_log", 0.0)
        if now - last_empty_log > 30:
            logger.warning(
                "API /batch no full batch ready for client=%s pid=%s addr=%s: "
                "queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
                client_tag,
                client_pid,
                client_addr,
                len(state.queue),
                sum(len(x.get("tokens", [])) for x in state.queue),
                len(state.curr_batch),
                getattr(state, "batchsize", -1),
            )
            state._last_empty_batch_log = now
        return {"batch": None}

    state.status_dict["step"] += steps_to_take
    # chunk it
    state.curr_batch.extend(new_batches)
    outgoing = state.curr_batch.pop()
    # check length before sending
    logger.warning(
        "API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
        "with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
        steps_to_take,
        client_tag,
        client_pid,
        client_addr,
        len(outgoing),
        sum(len(x["tokens"]) for x in outgoing),
        len(state.curr_batch),
        len(state.queue),
        state.status_dict["step"],
    )
    return {"batch": outgoing}
@app.get("/latest_example")
async def get_latest_example():
    """Return the most recently queued group, or an all-empty placeholder."""
    if hasattr(app.state, "latest"):
        return app.state.latest

    # Nothing has been queued yet: answer with the same keys, each empty.
    placeholder_keys = (
        "tokens",
        "masks",
        "scores",
        "advantages",
        "ref_logprobs",
        "generation_params",
        "inference_logprobs",
        "messages",
        "images",
    )
    return {key: [] for key in placeholder_keys}
@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
    """Accept one scored group and queue or buffer it."""
    outcome = _process_scored_data(scored_data)
    return outcome
@app.post("/scored_data_list")
async def scored_data_list(scored_data_list: List[ScoredData]):
    """Handle a list of ScoredData objects for step-based learning.

    Items are processed in order. The response always reports how many were
    processed; ``buffered`` and ``last_buffer_size`` are included only when at
    least one item was buffered.
    """
    # Process each scored data item
    outcomes = [_process_scored_data(item) for item in scored_data_list]
    buffered_outcomes = [
        outcome for outcome in outcomes if outcome.get("status") == "buffered"
    ]

    last_buffer_size: Optional[int] = None
    for outcome in buffered_outcomes:
        last_buffer_size = outcome.get("buffer_size", last_buffer_size)

    response: Dict[str, Any] = {
        "status": "received",
        "groups_processed": len(scored_data_list),
    }
    if buffered_outcomes:
        response["buffered"] = len(buffered_outcomes)
        if last_buffer_size is not None:
            response["last_buffer_size"] = last_buffer_size
    return response
@app.get("/status")
async def get_status():
    """Return the current step and queue length (zeros before registration)."""
    try:
        step = app.state.status_dict["step"]
        pending_groups = len(app.state.queue)
    except AttributeError:
        # State not initialised yet.
        return {"current_step": 0, "queue_size": 0}
    return {"current_step": step, "queue_size": pending_groups}
@app.get("/status-env")
async def get_status_env(env: EnvIdentifier):
    """Report queue status from the point of view of one environment.

    The response contains the current step, the environment's share of the
    total weight (``env_weight``, never below ``MIN_ENV_WEIGHT``), the batch
    fraction not reserved by minimum allocations, and two queue sizes
    expressed in units of the environment's group size:

    * ``self_queue_size`` counts only this environment's queued sequences.
    * ``queue_size`` also counts other environments' queued sequences, rounded
      up to whole groups of the largest queued group size, minus each
      environment's reserved minimum.

    As a side effect the environment's stored ``group_size`` is raised to the
    largest group it currently has in the queue.

    Raises:
        AttributeError / IndexError: if no environments are registered or
            ``env.env_id`` is out of range (unchanged from before).
    """
    import math

    envs = app.state.envs
    total = sum(
        x["max_context_len"] * max(0.0, x["weight"]) for x in envs if x["connected"]
    )
    requesting_env = envs[env.env_id]
    env_group_size = requesting_env["group_size"]

    # `total` is 0 when no connected env has a positive weight (for example
    # after every env disconnected). Fall back to the floor value instead of
    # dividing by zero.
    if total > 0:
        env_weight = (
            requesting_env["max_context_len"] * requesting_env["weight"] / total
        )
    else:
        env_weight = 0.0
    env_weight = max(
        MIN_ENV_WEIGHT, env_weight
    )  # Ensure minimum weight to prevent environment starvation

    # Minimum allocations of connected envs: their sum, and the per-env
    # sequence count each one reserves out of a batch.
    batch_size = getattr(app.state, "batchsize", 0)
    total_min_allocation = 0.0
    min_sequences_by_env = {}
    for env_config in envs:
        min_allocation = env_config.get("min_batch_allocation")
        if env_config.get("connected", False) and min_allocation is not None:
            total_min_allocation += min_allocation
            min_sequences_by_env[env_config["registered_id"]] = int(
                batch_size * min_allocation
            )
    unallocated_fraction = 1.0 - min(total_min_allocation, 1.0)

    # One pass over the queue: largest group size overall, this env's sequence
    # count, and the sequence count of every other env.
    queue = getattr(app.state, "queue", [])
    max_group_size = 1
    num_self_sequences_in_queue = 0
    sequences_by_env = {}
    for item in queue:
        group_size = len(item.get("tokens", []))
        if group_size > max_group_size:
            max_group_size = group_size
        item_env_id = item.get("env_id")
        if item_env_id == env.env_id:
            # Group sizes may be dynamic; track the largest one seen.
            env_group_size = max(env_group_size, group_size)
            num_self_sequences_in_queue += group_size
        else:
            sequences_by_env[item_env_id] = (
                sequences_by_env.get(item_env_id, 0) + group_size
            )

    # update the group size for the requesting env
    requesting_env["group_size"] = env_group_size

    # Sequences other envs can contribute beyond their reserved minimums.
    # NOTE(review): other envs are only counted when max_group_size > 1; with
    # all groups of size 1 they contribute nothing. Preserved as-is — confirm
    # this is intended.
    available_from_others = 0
    if max_group_size > 1:
        for other_env_id, seq_count in sequences_by_env.items():
            packed_sequences = math.ceil(seq_count / max_group_size) * max_group_size
            min_sequences = min_sequences_by_env.get(other_env_id, 0)
            available_from_others += max(0, packed_sequences - min_sequences)

    env_queue_size = num_self_sequences_in_queue + available_from_others
    # An env registered with group_size 0 must not crash the endpoint.
    group_divisor = max(1, env_group_size)

    try:
        ret_dict = {
            "current_step": app.state.status_dict["step"],
            "queue_size": env_queue_size // group_divisor,
            "unallocated_fraction": unallocated_fraction,
            "self_queue_size": num_self_sequences_in_queue // group_divisor,
            "max_group_size": max_group_size,
        }
    except AttributeError:
        # Step counter not initialised. The fallback now carries the same keys
        # as the normal response, plus the legacy key for existing clients.
        ret_dict = {
            "current_step": 0,
            "queue_size": 0,
            "unallocated_fraction": 1.0,
            "self_queue_size": 0,
            "max_group_size": max_group_size,
            "num_self_sequences_in_queue": 0,
        }
    ret_dict["env_weight"] = env_weight
    return ret_dict
@app.get("/reset_data")
async def reset_data():
    """Reset the shared server state to its unregistered defaults.

    The queue attribute is removed entirely (so the next /register or scored
    data submission recreates it); trainer, environment and buffer state are
    overwritten with empty/sentinel values. Always answers
    ``200 Reset successful``.
    """
    # Only the deletion can fail (the queue may already be gone), so only it
    # is guarded. Previously the whole reset sat inside the try block and a
    # missing queue silently skipped every assignment below.
    try:
        del app.state.queue
    except (KeyError, AttributeError):
        # KeyError is what the original handler guarded against;
        # AttributeError is tolerated too in case the state object reports a
        # missing attribute that way.
        pass

    app.state.group = None
    app.state.project = None
    app.state.batchsize = -1
    app.state.num_steps = -1
    app.state.status_dict = {"step": 0}
    app.state.curr_batch = []
    app.state.started = False
    app.state.requesters = []
    app.state.envs = []
    app.state.buffer = {}
    return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)