import gzip import logging import time import uuid from typing import Any, Dict, List, Optional from fastapi import FastAPI, status from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.responses import PlainTextResponse from pydantic import BaseModel, field_validator from starlette.datastructures import MutableHeaders from starlette.types import Receive, Scope, Send from atroposlib.api.utils import ( find_groups_summing_to_target, grab_batch_with_minimum_allocations, grab_exact_from_heterogeneous_queue, ) # Constants MIN_ENV_WEIGHT = ( 0.01 # Minimum weight to prevent environments from being completely starved ) # Message import removed - using Dict[str, Any] for more flexible validation app = FastAPI(title="AtroposLib API") logger = logging.getLogger(__name__) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) app.add_middleware(GZipMiddleware, minimum_size=1000) class GZipRequestMiddleware: def __init__(self, app): self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send): if scope["type"] != "http": await self.app(scope, receive, send) return headers = MutableHeaders(scope=scope) content_encoding = headers.get("content-encoding", "") if "gzip" not in content_encoding.lower(): await self.app(scope, receive, send) return body_chunks = [] more_body = True while more_body: message = await receive() body_chunks.append(message.get("body", b"")) more_body = message.get("more_body", False) body = b"".join(body_chunks) if body: try: decompressed = gzip.decompress(body) except OSError: response = PlainTextResponse( "Invalid gzip payload", status_code=status.HTTP_400_BAD_REQUEST, ) await response(scope, receive, send) return else: decompressed = b"" mutable_headers = MutableHeaders(scope=scope) mutable_headers["content-length"] = str(len(decompressed)) if "content-encoding" in mutable_headers: del mutable_headers["content-encoding"] sent = False # needed some odd logic here to handle gzip stream so just returning an empty body async def new_receive(): nonlocal sent if sent: return {"type": "http.request", "body": b"", "more_body": False} sent = True return { "type": "http.request", "body": decompressed, "more_body": False, } await self.app(scope, new_receive, send) app.add_middleware(GZipRequestMiddleware) @app.get("/") async def root(): return {"message": "AtroposLib API"} class Registration(BaseModel): wandb_group: str wandb_project: str batch_size: int max_token_len: int checkpoint_dir: str save_checkpoint_interval: int starting_step: int num_steps: int class RegisterEnv(BaseModel): max_token_length: int desired_name: str weight: float group_size: int min_batch_allocation: Optional[float] = ( None # Minimum proportion of a batch this env should be allocated (0.0-1.0) ) class EnvIdentifier(BaseModel): env_id: int class ScoredData(BaseModel): tokens: List[List[int]] masks: List[List[int]] scores: List[float] advantages: Optional[List[List[float]]] = None ref_logprobs: Optional[List[List[float]]] = None messages: Optional[List[List[Dict[str, Any]]]] = ( None # Changed from Message TypedDict to Dict ) generation_params: Optional[Dict[str, Any]] = None inference_logprobs: Optional[List[List[float]]] = None overrides: Optional[List[dict]] = None group_overrides: Optional[dict] = None images: Optional[Any] = None env_id: Optional[int] = None # ID of the environment that generated this data # On-policy distillation (new format): parallel token ids + logprobs. # Shape for both: [sequence][position][top_k] distill_token_ids: Optional[List[List[List[int]]]] = None distill_logprobs: Optional[List[List[List[float]]]] = None @field_validator("messages", mode="before") @classmethod def validate_messages(cls, v): """Validate messages field to ensure required fields are present. This validator only checks that messages have 'role' and 'content' fields. The 'reward' field is completely optional. """ if v is None: return None for message_list in v: for msg in message_list: # Ensure the message has the required fields if "role" not in msg or "content" not in msg: raise ValueError("Message must have 'role' and 'content' fields") return v def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]: """Convert a `ScoredData` pydantic model into a plain dictionary.""" return { "tokens": scored_data.tokens, "masks": scored_data.masks, "scores": scored_data.scores, "advantages": scored_data.advantages, "ref_logprobs": scored_data.ref_logprobs, "messages": scored_data.messages, "generation_params": scored_data.generation_params, "inference_logprobs": scored_data.inference_logprobs, "overrides": scored_data.overrides, "group_overrides": scored_data.group_overrides, "images": scored_data.images, "env_id": scored_data.env_id, "distill_token_ids": scored_data.distill_token_ids, "distill_logprobs": scored_data.distill_logprobs, } def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]: """Normalize buffering/queueing logic for scored data submissions.""" if not hasattr(app.state, "queue"): app.state.queue = [] if not hasattr(app.state, "buffer"): app.state.buffer = {} data_dict = _scored_data_to_dict(scored_data) env_id = data_dict.get("env_id") envs = getattr(app.state, "envs", []) if env_id is not None and env_id < len(envs): expected_group_size = envs[env_id].get("group_size", 1) actual_group_size = len(scored_data.tokens) if actual_group_size != expected_group_size: buffer = app.state.buffer.setdefault(env_id, []) buffer.append(data_dict) indices = find_groups_summing_to_target(buffer, expected_group_size) if indices: groups_to_add = [] for idx in sorted(indices, reverse=True): groups_to_add.append(buffer.pop(idx)) for group in reversed(groups_to_add): app.state.queue.append(group) app.state.latest = group return { "status": "buffered", "buffer_size": sum( len(group["tokens"]) for group in app.state.buffer.get(env_id, []) ), } app.state.queue.append(data_dict) app.state.latest = data_dict return {"status": "received"} class Status(BaseModel): """ basemodel for status information of the current server """ current_step: int queue_size: int class Info(BaseModel): """ basemodel for useful information """ batch_size: int = -1 @app.post("/register") async def register(registration: Registration): # Initialize app state if not already done if not hasattr(app.state, "queue"): app.state.queue = [] app.state.group = registration.wandb_group app.state.project = registration.wandb_project app.state.batchsize = int(registration.batch_size) app.state.max_token_len = int(registration.max_token_len) app.state.status_dict = {"step": registration.starting_step} app.state.checkpoint_dir = registration.checkpoint_dir app.state.save_checkpoint_interval = registration.save_checkpoint_interval app.state.num_steps = registration.num_steps app.state.curr_batch = [] app.state.started = False app.state.envs = [] app.state.buffer = {} # Buffer for mixed-size groups per environment # Initialize requesters list if not already done if not hasattr(app.state, "requesters"): app.state.requesters = [] app.state.requesters.append(uuid.uuid4().int) return {"uuid": app.state.requesters[-1]} @app.post("/register-env") async def register_env_url(register_env: RegisterEnv): # Check if trainer has started if not hasattr(app.state, "started") or not app.state.started: return { "status": "wait for trainer to start", } # Initialize envs list if not already done if not hasattr(app.state, "envs"): app.state.envs = [] # Get checkpoint directory safely checkpoint_dir = getattr(app.state, "checkpoint_dir", "") real_name = ( f"{register_env.desired_name}_" f"{len([x for x in app.state.envs if x['desired_name'] == register_env.desired_name])}" ) registered_id = len(app.state.envs) app.state.envs.append( { "max_context_len": register_env.max_token_length, "weight": register_env.weight if register_env.weight is not None else 1.0, "desired_name": register_env.desired_name, "real_name": real_name, "registered_id": registered_id, "last_update": time.time(), "connected": True, "min_batch_allocation": register_env.min_batch_allocation, "group_size": register_env.group_size, } ) return { "status": "success", "env_id": registered_id, "wandb_name": real_name, "checkpoint_dir": checkpoint_dir, "starting_step": app.state.status_dict["step"], "checkpoint_interval": app.state.save_checkpoint_interval, "num_steps": app.state.num_steps, } @app.post("/disconnect-env") async def disconnect_env(disconnect_env: EnvIdentifier): try: app.state.envs[disconnect_env.env_id]["connected"] = False return {"status": "success"} except (AttributeError, IndexError) as e: return {"status": "failure", "error": str(e)} @app.get("/wandb_info") async def wandb_info(): try: return {"group": app.state.group, "project": app.state.project} except AttributeError: return {"group": None, "project": None} @app.get("/info") async def info(): try: return { "batch_size": app.state.batchsize, "max_token_len": app.state.max_token_len, } except AttributeError: return {"batch_size": -1, "max_token_len": -1} @app.get("/batch") async def get_batch(): # Check if trainer has registered first if not hasattr(app.state, "started"): return { "status": "error", "message": "Trainer not registered. Call /register first.", "batch": [], } if not app.state.started: app.state.started = True if len(app.state.curr_batch) > 0: return {"batch": app.state.curr_batch.pop()} else: new_batches = [] # Check if any envs have minimum allocations has_min_allocations = any( env.get("min_batch_allocation") is not None for env in getattr(app.state, "envs", []) ) if has_min_allocations: batch, app.state.queue = grab_batch_with_minimum_allocations( app.state.queue, app.state.batchsize, app.state.envs ) else: batch, app.state.queue = grab_exact_from_heterogeneous_queue( app.state.queue, app.state.batchsize ) while batch is not None: new_batches.append(batch) if has_min_allocations: batch, app.state.queue = grab_batch_with_minimum_allocations( app.state.queue, app.state.batchsize, app.state.envs ) else: batch, app.state.queue = grab_exact_from_heterogeneous_queue( app.state.queue, app.state.batchsize ) steps_to_take = len(new_batches) if steps_to_take == 0: return {"batch": None} app.state.status_dict["step"] += steps_to_take # chunk it for batch in new_batches: app.state.curr_batch.append(batch) curr_batch = app.state.curr_batch.pop() # check length before sending logger.info( "Sending batch of %s sequences", sum(len(x["tokens"]) for x in curr_batch), ) return {"batch": curr_batch} @app.get("/latest_example") async def get_latest_example(): try: return app.state.latest except AttributeError: return { "tokens": [], "masks": [], "scores": [], "advantages": [], "ref_logprobs": [], "generation_params": [], "inference_logprobs": [], "messages": [], "images": [], } @app.post("/scored_data") async def scored_data(scored_data: ScoredData): return _process_scored_data(scored_data) @app.post("/scored_data_list") async def scored_data_list(scored_data_list: List[ScoredData]): """Handle a list of ScoredData objects for step-based learning""" # Process each scored data item buffered_count = 0 last_buffer_size: Optional[int] = None for scored_data in scored_data_list: result = _process_scored_data(scored_data) if result.get("status") == "buffered": buffered_count += 1 last_buffer_size = result.get("buffer_size", last_buffer_size) response: Dict[str, Any] = { "status": "received", "groups_processed": len(scored_data_list), } if buffered_count: response["buffered"] = buffered_count if last_buffer_size is not None: response["last_buffer_size"] = last_buffer_size return response @app.get("/status") async def get_status(): try: return { "current_step": app.state.status_dict["step"], "queue_size": len(app.state.queue), } except AttributeError: return {"current_step": 0, "queue_size": 0} @app.get("/status-env") async def get_status_env(env: EnvIdentifier): total = sum( [ x["max_context_len"] * max(0.0, x["weight"]) for x in app.state.envs if x["connected"] ] ) env_group_size = app.state.envs[env.env_id]["group_size"] env_weight = ( app.state.envs[env.env_id]["max_context_len"] * app.state.envs[env.env_id]["weight"] / total ) env_weight = max( MIN_ENV_WEIGHT, env_weight ) # Ensure minimum weight to prevent environment starvation # Calculate total minimum allocations total_min_allocation = 0.0 for env_config in app.state.envs: if ( env_config.get("connected", False) and env_config.get("min_batch_allocation") is not None ): total_min_allocation += env_config["min_batch_allocation"] # Calculate unallocated fraction unallocated_fraction = 1.0 - min(total_min_allocation, 1.0) # Find the maximum group size across all items in queue queue = getattr(app.state, "queue", []) max_group_size = 1 num_self_sequences_in_queue = 0 for item in queue: group_size = len(item.get("tokens", [])) if group_size > max_group_size: max_group_size = group_size if item.get("env_id") == env.env_id: # update the group size for the requesting env, handle cases where the group size may be dynamic with max env_group_size = max(env_group_size, group_size) num_self_sequences_in_queue += group_size # update the group size for the requesting env app.state.envs[env.env_id]["group_size"] = env_group_size # Calculate minimum sequences allocated to each environment batch_size = getattr(app.state, "batchsize", 0) min_sequences_by_env = {} for env_config in app.state.envs: if ( env_config.get("connected", False) and env_config.get("min_batch_allocation") is not None ): env_id = env_config["registered_id"] min_sequences = int(batch_size * env_config["min_batch_allocation"]) min_sequences_by_env[env_id] = min_sequences # Count sequences and calculate packed groups for each environment import math sequences_by_env = {} packed_groups_by_env = {} curr_env_total_sequences = 0 for item in queue: env_id = item.get("env_id") seq_count = len(item.get("tokens", [])) # Special handling for the requesting environment if env_id == env.env_id: curr_env_total_sequences += seq_count else: if env_id not in sequences_by_env: sequences_by_env[env_id] = 0 sequences_by_env[env_id] += seq_count # Calculate packed groups for each environment (excluding the requesting env) if max_group_size > 1: for env_id, seq_count in sequences_by_env.items(): packed_groups_by_env[env_id] = math.ceil(seq_count / max_group_size) # Calculate adjusted queue size # (curr_env_total_sequences + sum of available sequences from other envs after their minimums) available_from_others = 0 for env_id in packed_groups_by_env: packed_sequences = packed_groups_by_env[env_id] * max_group_size min_sequences = min_sequences_by_env.get(env_id, 0) available_from_others += max(0, packed_sequences - min_sequences) env_queue_size = curr_env_total_sequences + available_from_others try: ret_dict = { "current_step": app.state.status_dict["step"], "queue_size": env_queue_size // env_group_size, "unallocated_fraction": unallocated_fraction, "self_queue_size": num_self_sequences_in_queue // env_group_size, "max_group_size": max_group_size, } except AttributeError: ret_dict = { "current_step": 0, "queue_size": 0, "unallocated_fraction": 1.0, "num_self_sequences_in_queue": 0, } ret_dict["env_weight"] = env_weight return ret_dict @app.get("/reset_data") async def reset_data(): try: del app.state.queue app.state.group = None app.state.project = None app.state.batchsize = -1 app.state.num_steps = -1 app.state.status_dict = {"step": 0} app.state.curr_batch = [] app.state.started = False app.state.requesters = [] app.state.envs = [] app.state.buffer = {} except KeyError: pass return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)