# Mirror of https://github.com/NousResearch/atropos.git
# (synced 2026-04-19 12:57:58 +00:00; 642 lines, 21 KiB, Python)
import gzip
import logging
import time
import uuid
import zlib
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, field_validator
from starlette.datastructures import MutableHeaders
from starlette.types import Receive, Scope, Send

from atroposlib.api.utils import (
    find_groups_summing_to_target,
    grab_batch_with_minimum_allocations,
    grab_exact_from_heterogeneous_queue,
)
|
|
|
|
# Constants
# Floor applied to an environment's computed weight so that no environment
# is ever completely starved.
MIN_ENV_WEIGHT = 0.01

# Message import removed - using Dict[str, Any] for more flexible validation

logger = logging.getLogger(__name__)
app = FastAPI(title="AtroposLib API")

# Fully permissive CORS policy: any origin, method and header, with
# credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Compress responses with gzip; minimum_size=1000 is the size threshold
# handed to the middleware.
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
|
|
|
|
|
class GZipRequestMiddleware:
    """ASGI middleware that decompresses gzip-encoded request bodies.

    For HTTP requests whose ``Content-Encoding`` header mentions ``gzip``,
    the whole request body is read, decompressed, and replayed to the
    wrapped application as a single ``http.request`` message. The
    ``Content-Length`` header is rewritten to the decompressed size and
    ``Content-Encoding`` is removed. Every other request is forwarded
    untouched.
    """

    def __init__(self, app):
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Handle one ASGI connection.

        Args:
            scope: ASGI connection scope.
            receive: ASGI receive callable of the incoming connection.
            send: ASGI send callable of the incoming connection.

        A body that cannot be decompressed is answered directly with
        ``400 Invalid gzip payload``; the wrapped application is not called.
        """
        # Only HTTP requests carry a body to decompress.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        headers = MutableHeaders(scope=scope)
        content_encoding = headers.get("content-encoding", "")
        if "gzip" not in content_encoding.lower():
            await self.app(scope, receive, send)
            return

        # Drain the complete (possibly chunked) request body.
        body_chunks = []
        more_body = True
        while more_body:
            message = await receive()
            body_chunks.append(message.get("body", b""))
            more_body = message.get("more_body", False)

        body = b"".join(body_chunks)
        if body:
            try:
                decompressed = gzip.decompress(body)
            except (OSError, EOFError, zlib.error):
                # gzip.decompress signals bad input in three ways: OSError
                # (incl. BadGzipFile) for a bad header/trailer, EOFError for
                # a truncated stream, and zlib.error for corrupt deflate
                # data. All of them are client errors, not server faults.
                response = PlainTextResponse(
                    "Invalid gzip payload",
                    status_code=status.HTTP_400_BAD_REQUEST,
                )
                await response(scope, receive, send)
                return
        else:
            decompressed = b""

        # Make the headers describe the body the application will see.
        mutable_headers = MutableHeaders(scope=scope)
        mutable_headers["content-length"] = str(len(decompressed))
        if "content-encoding" in mutable_headers:
            del mutable_headers["content-encoding"]

        sent = False

        # needed some odd logic here to handle gzip stream so just returning an empty body
        async def new_receive():
            # First call delivers the whole decompressed body; any later
            # call yields an empty, final chunk.
            nonlocal sent
            if sent:
                return {"type": "http.request", "body": b"", "more_body": False}
            sent = True
            return {
                "type": "http.request",
                "body": decompressed,
                "more_body": False,
            }

        await self.app(scope, new_receive, send)
|
|
|
|
|
|
# Install the request-body gzip decompression middleware on the application.
app.add_middleware(GZipRequestMiddleware)
|
|
|
|
|
|
@app.get("/")
async def root():
    """Return a static banner identifying the service."""
    banner = {"message": "AtroposLib API"}
    return banner
|
|
|
|
|
|
class Registration(BaseModel):
    """Request body of ``POST /register`` (trainer registration)."""

    # Reported back verbatim by /wandb_info.
    wandb_group: str
    wandb_project: str
    # Target size used when batches are assembled from the queue.
    batch_size: int
    # Reported back by /info.
    max_token_len: int
    # The remaining fields are handed to environments on /register-env.
    checkpoint_dir: str
    save_checkpoint_interval: int
    starting_step: int
    num_steps: int
|
|
|
|
|
|
class RegisterEnv(BaseModel):
    """Request body of ``POST /register-env`` (environment registration)."""

    # Stored on the environment record as "max_context_len".
    max_token_length: int
    # Base name; the server appends a per-name counter to build the real name.
    desired_name: str
    # Relative weight of this environment (used by /status-env).
    weight: float
    # Expected number of sequences per submitted group.
    group_size: int
    # Minimum proportion of a batch this env should be allocated (0.0-1.0)
    min_batch_allocation: Optional[float] = None
|
|
|
|
|
|
class EnvIdentifier(BaseModel):
    """Request body that identifies a registered environment by its id."""

    # Index of the environment in the server's list of registered envs.
    env_id: int
|
|
|
|
|
|
class ScoredData(BaseModel):
    """One scored group of sequences submitted by an environment.

    ``tokens``, ``masks`` and ``scores`` are required; every other field is
    optional and defaults to ``None``.
    """

    tokens: List[List[int]]
    masks: List[List[int]]
    scores: List[float]
    advantages: Optional[List[List[float]]] = None
    ref_logprobs: Optional[List[List[float]]] = None
    # Changed from Message TypedDict to Dict
    messages: Optional[List[List[Dict[str, Any]]]] = None
    generation_params: Optional[Dict[str, Any]] = None
    inference_logprobs: Optional[List[List[float]]] = None
    overrides: Optional[List[dict]] = None
    group_overrides: Optional[dict] = None
    images: Optional[Any] = None
    env_id: Optional[int] = None  # ID of the environment that generated this data
    # On-policy distillation (new format): parallel token ids + logprobs.
    # Shape for both: [sequence][position][top_k]
    distill_token_ids: Optional[List[List[List[int]]]] = None
    distill_logprobs: Optional[List[List[List[float]]]] = None

    @field_validator("messages", mode="before")
    @classmethod
    def validate_messages(cls, v):
        """Validate messages field to ensure required fields are present.

        This validator only checks that messages have 'role' and 'content' fields.
        The 'reward' field is completely optional.

        Args:
            v: Raw (pre-validation) value supplied for ``messages``.

        Returns:
            ``v`` unchanged when it is ``None`` or every message passes.

        Raises:
            ValueError: If a message lacks 'role' or 'content', or if the
                value is not shaped like a list of lists of messages.
        """
        if v is None:
            return None

        try:
            for message_list in v:
                for msg in message_list:
                    # Ensure the message has the required fields
                    if "role" not in msg or "content" not in msg:
                        raise ValueError(
                            "Message must have 'role' and 'content' fields"
                        )
        except TypeError as exc:
            # This runs before type validation, so the raw input may not be
            # iterable at all. Report that as a validation failure instead
            # of letting a bare TypeError escape the validator.
            raise ValueError(
                "messages must be a list of lists of message dicts"
            ) from exc

        return v
|
|
|
|
|
|
def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]:
    """Return the fields of a `ScoredData` model as a plain dictionary.

    The values are taken as-is from the model (no copying or conversion);
    the keys appear in the order listed below.
    """
    field_names = (
        "tokens",
        "masks",
        "scores",
        "advantages",
        "ref_logprobs",
        "messages",
        "generation_params",
        "inference_logprobs",
        "overrides",
        "group_overrides",
        "images",
        "env_id",
        "distill_token_ids",
        "distill_logprobs",
    )
    return {name: getattr(scored_data, name) for name in field_names}
|
|
|
|
|
|
def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]:
    """Normalize buffering/queueing logic for scored data submissions.

    A submission from a registered environment whose number of sequences
    differs from that environment's ``group_size`` is parked in a
    per-environment buffer. Whenever `find_groups_summing_to_target` returns
    indices for that buffer, the selected groups are moved to the queue in
    their original order. Every other submission goes straight to the queue.

    Args:
        scored_data: The validated submission.

    Returns:
        ``{"status": "buffered", "buffer_size": <sequences still buffered>}``
        when the submission went through the buffer, otherwise
        ``{"status": "received"}``.
    """
    # Scored data may arrive before anything else created these containers.
    if not hasattr(app.state, "queue"):
        app.state.queue = []
    if not hasattr(app.state, "buffer"):
        app.state.buffer = {}

    data_dict = _scored_data_to_dict(scored_data)
    env_id = data_dict.get("env_id")
    envs = getattr(app.state, "envs", [])

    # The lower bound matters: a negative id would otherwise index `envs`
    # from the end and silently borrow another environment's group size.
    if env_id is not None and 0 <= env_id < len(envs):
        expected_group_size = envs[env_id].get("group_size", 1)
        actual_group_size = len(scored_data.tokens)

        if actual_group_size != expected_group_size:
            buffer = app.state.buffer.setdefault(env_id, [])
            buffer.append(data_dict)

            # Presumably the indices of buffered groups whose sizes add up
            # to the expected group size - confirm against the helper.
            indices = find_groups_summing_to_target(buffer, expected_group_size)

            if indices:
                # Pop from the highest index down so the remaining indices
                # stay valid, then restore the original order for queueing.
                popped = [buffer.pop(idx) for idx in sorted(indices, reverse=True)]
                for group in reversed(popped):
                    app.state.queue.append(group)
                    app.state.latest = group

            return {
                "status": "buffered",
                "buffer_size": sum(len(group["tokens"]) for group in buffer),
            }

    app.state.queue.append(data_dict)
    app.state.latest = data_dict
    return {"status": "received"}
|
|
|
|
|
|
class Status(BaseModel):
    """Status information about the current server."""

    # Training step counter.
    current_step: int
    # Number of entries in the queue.
    queue_size: int
|
|
|
|
|
|
class Info(BaseModel):
    """General server information."""

    # Defaults to -1 when no value is supplied.
    batch_size: int = -1
|
|
|
|
|
|
@app.post("/register")
async def register(registration: Registration):
    """Register a trainer and hand back a freshly generated requester id.

    The shared server state is initialised from ``registration`` only when
    no queue exists yet; later registrations just obtain a new id.

    Returns:
        ``{"uuid": <int>}`` - the integer form of a random UUID4.
    """
    state = app.state

    # Initialize app state if not already done
    if not hasattr(state, "queue"):
        state.queue = []
        state.curr_batch = []
        state.envs = []
        state.buffer = {}  # Buffer for mixed-size groups per environment
        state.started = False
        state.status_dict = {"step": registration.starting_step}
        state.group = registration.wandb_group
        state.project = registration.wandb_project
        state.batchsize = int(registration.batch_size)
        state.max_token_len = int(registration.max_token_len)
        state.checkpoint_dir = registration.checkpoint_dir
        state.save_checkpoint_interval = registration.save_checkpoint_interval
        state.num_steps = registration.num_steps

    # Initialize requesters list if not already done
    if not hasattr(state, "requesters"):
        state.requesters = []

    requester_id = uuid.uuid4().int
    state.requesters.append(requester_id)
    return {"uuid": requester_id}
|
|
|
|
|
|
@app.post("/register-env")
async def register_env_url(register_env: RegisterEnv):
    """Register an environment and return its id and run configuration.

    Until the trainer has started (the ``started`` flag is set by ``/batch``)
    the environment is told to wait. Otherwise a record for the environment
    is appended to the server's env list; its position in that list becomes
    its ``env_id``.
    """
    state = app.state

    # Check if trainer has started
    if not getattr(state, "started", False):
        return {"status": "wait for trainer to start"}

    # Initialize envs list if not already done
    if not hasattr(state, "envs"):
        state.envs = []

    # Get checkpoint directory safely
    checkpoint_dir = getattr(state, "checkpoint_dir", "")

    # The real name is the desired name plus the number of environments
    # already registered under that same desired name.
    desired_name = register_env.desired_name
    same_name_count = sum(
        1 for existing in state.envs if existing["desired_name"] == desired_name
    )
    real_name = f"{desired_name}_{same_name_count}"

    registered_id = len(state.envs)
    weight = register_env.weight if register_env.weight is not None else 1.0
    env_record = {
        "max_context_len": register_env.max_token_length,
        "weight": weight,
        "desired_name": desired_name,
        "real_name": real_name,
        "registered_id": registered_id,
        "last_update": time.time(),
        "connected": True,
        "min_batch_allocation": register_env.min_batch_allocation,
        "group_size": register_env.group_size,
    }
    state.envs.append(env_record)

    return {
        "status": "success",
        "env_id": registered_id,
        "wandb_name": real_name,
        "checkpoint_dir": checkpoint_dir,
        "starting_step": state.status_dict["step"],
        "checkpoint_interval": state.save_checkpoint_interval,
        "num_steps": state.num_steps,
    }
|
|
|
|
|
|
@app.post("/disconnect-env")
async def disconnect_env(disconnect_env: EnvIdentifier):
    """Mark a registered environment as disconnected.

    Returns:
        ``{"status": "success"}`` on success, or ``{"status": "failure",
        "error": <message>}`` when no environments are registered or the id
        is out of range.
    """
    try:
        env_record = app.state.envs[disconnect_env.env_id]
    except (AttributeError, IndexError) as e:
        return {"status": "failure", "error": str(e)}

    env_record["connected"] = False
    return {"status": "success"}
|
|
|
|
|
|
@app.get("/wandb_info")
async def wandb_info():
    """Return the registered wandb group and project.

    Both values are ``None`` when either of them has not been set yet.
    """
    try:
        group, project = app.state.group, app.state.project
    except AttributeError:
        group = project = None
    return {"group": group, "project": project}
|
|
|
|
|
|
@app.get("/info")
async def info():
    """Return the registered batch size and maximum token length.

    Both values are ``-1`` when either of them has not been set yet.
    """
    try:
        batch_size = app.state.batchsize
        max_token_len = app.state.max_token_len
    except AttributeError:
        batch_size = max_token_len = -1
    return {"batch_size": batch_size, "max_token_len": max_token_len}
|
|
|
|
|
|
@app.get("/batch")
async def get_batch(request: Request):
    """Hand one training batch to the caller.

    The first call after registration marks the trainer as started. If
    batches built by an earlier call are still stored, the most recently
    stored one is returned. Otherwise as many batches as possible are carved
    out of the queue, the step counter advances by that number, and one of
    the new batches is returned while the rest are kept for later calls.

    Returns:
        ``{"batch": <list of groups>}``, ``{"batch": None}`` when no full
        batch can be built yet, or an error payload when no trainer has
        registered.
    """
    state = app.state

    # Check if trainer has registered first
    if not hasattr(state, "started"):
        return {
            "status": "error",
            "message": "Trainer not registered. Call /register first.",
            "batch": [],
        }

    if not state.started:
        state.started = True

    # Caller identification, used only for the log lines below.
    client = request.client
    client_addr = (
        "unknown-client" if client is None else f"{client.host}:{client.port}"
    )
    client_tag = request.headers.get("x-atropos-client", "unknown")
    client_pid = request.headers.get("x-atropos-pid", "unknown")

    if state.curr_batch:
        prebuilt = state.curr_batch.pop()
        logger.warning(
            "API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
            "groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
            client_tag,
            client_pid,
            client_addr,
            len(prebuilt),
            sum(len(x["tokens"]) for x in prebuilt),
            len(state.curr_batch),
            len(state.queue),
        )
        return {"batch": prebuilt}

    # Check if any envs have minimum allocations
    has_min_allocations = any(
        env.get("min_batch_allocation") is not None
        for env in getattr(state, "envs", [])
    )

    def grab_next():
        # Returns (batch or None, remaining queue) from the matching helper.
        if has_min_allocations:
            return grab_batch_with_minimum_allocations(
                state.queue, state.batchsize, state.envs
            )
        return grab_exact_from_heterogeneous_queue(state.queue, state.batchsize)

    # Keep carving batches out of the queue until no full batch is left.
    new_batches = []
    while True:
        batch, state.queue = grab_next()
        if batch is None:
            break
        new_batches.append(batch)

    steps_to_take = len(new_batches)
    if steps_to_take == 0:
        # Rate-limit the "nothing ready" log line to one per 30 seconds.
        now = time.time()
        if now - getattr(state, "_last_empty_batch_log", 0.0) > 30:
            logger.warning(
                "API /batch no full batch ready for client=%s pid=%s addr=%s: "
                "queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
                client_tag,
                client_pid,
                client_addr,
                len(state.queue),
                sum(len(x.get("tokens", [])) for x in state.queue),
                len(state.curr_batch),
                getattr(state, "batchsize", -1),
            )
            state._last_empty_batch_log = now
        return {"batch": None}

    state.status_dict["step"] += steps_to_take
    # chunk it
    state.curr_batch.extend(new_batches)
    outgoing = state.curr_batch.pop()
    # check length before sending
    logger.warning(
        "API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
        "with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
        steps_to_take,
        client_tag,
        client_pid,
        client_addr,
        len(outgoing),
        sum(len(x["tokens"]) for x in outgoing),
        len(state.curr_batch),
        len(state.queue),
        state.status_dict["step"],
    )
    return {"batch": outgoing}
|
|
|
|
|
|
@app.get("/latest_example")
async def get_latest_example():
    """Return the most recently queued group.

    When nothing has been queued yet, a placeholder with an empty list for
    each of the keys below is returned instead.
    """
    placeholder_keys = (
        "tokens",
        "masks",
        "scores",
        "advantages",
        "ref_logprobs",
        "generation_params",
        "inference_logprobs",
        "messages",
        "images",
    )
    try:
        latest = app.state.latest
    except AttributeError:
        latest = {key: [] for key in placeholder_keys}
    return latest
|
|
|
|
|
|
@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
    """Accept one scored group and return the result of processing it."""
    outcome = _process_scored_data(scored_data)
    return outcome
|
|
|
|
|
|
@app.post("/scored_data_list")
async def scored_data_list(scored_data_list: List[ScoredData]):
    """Handle a list of ScoredData objects for step-based learning

    Each item is processed in order. The response always carries
    ``status`` and ``groups_processed``; ``buffered`` (number of buffered
    items) and ``last_buffer_size`` are added only when at least one item
    was buffered.
    """
    buffered_count = 0
    last_buffer_size: Optional[int] = None

    # Process each scored data item
    for item in scored_data_list:
        outcome = _process_scored_data(item)
        if outcome.get("status") != "buffered":
            continue
        buffered_count += 1
        last_buffer_size = outcome.get("buffer_size", last_buffer_size)

    response: Dict[str, Any] = {
        "status": "received",
        "groups_processed": len(scored_data_list),
    }

    if buffered_count:
        response["buffered"] = buffered_count
        if last_buffer_size is not None:
            response["last_buffer_size"] = last_buffer_size

    return response
|
|
|
|
|
|
@app.get("/status")
async def get_status():
    """Return the current step and the number of queued groups.

    Both values are ``0`` when the server state has not been set up yet.
    """
    try:
        current_step = app.state.status_dict["step"]
        queue_size = len(app.state.queue)
    except AttributeError:
        current_step, queue_size = 0, 0
    return {"current_step": current_step, "queue_size": queue_size}
|
|
|
|
|
|
@app.get("/status-env")
async def get_status_env(env: EnvIdentifier):
    """Report trainer progress and queue pressure as seen by one environment.

    Side effect: the stored ``group_size`` of the requesting environment is
    raised to the largest group of its own currently in the queue.

    Args:
        env: Identifies the requesting environment.

    Returns:
        A dict with ``current_step``, ``queue_size`` (own sequences plus the
        sequences of other environments that exceed their minimum
        allocation, expressed in groups of the requesting environment),
        ``unallocated_fraction``, ``self_queue_size``, ``max_group_size``
        and ``env_weight``.

    Raises:
        AttributeError: If no environment list exists yet.
        IndexError: If ``env.env_id`` is out of range.
    """
    import math

    envs = app.state.envs
    this_env = envs[env.env_id]

    # Share of the requesting env in the weighted context budget of all
    # connected envs, floored at MIN_ENV_WEIGHT to prevent starvation.
    total = sum(
        x["max_context_len"] * max(0.0, x["weight"]) for x in envs if x["connected"]
    )
    raw_share = this_env["max_context_len"] * this_env["weight"]
    # total is 0 when no connected env has a positive weight (for example
    # right after the only env disconnected); avoid dividing by zero.
    env_weight = raw_share / total if total else 0.0
    env_weight = max(MIN_ENV_WEIGHT, env_weight)

    # Minimum batch allocations of connected environments.
    min_allocations = [
        (cfg["registered_id"], cfg["min_batch_allocation"])
        for cfg in envs
        if cfg.get("connected", False)
        and cfg.get("min_batch_allocation") is not None
    ]
    total_min_allocation = sum((alloc for _, alloc in min_allocations), 0.0)

    # Calculate unallocated fraction
    unallocated_fraction = 1.0 - min(total_min_allocation, 1.0)

    # Single pass over the queue: largest group overall, the requesting
    # env's own sequences, and per-env sequence counts for everyone else.
    queue = getattr(app.state, "queue", [])
    env_group_size = this_env["group_size"]
    max_group_size = 1
    num_self_sequences_in_queue = 0
    sequences_by_env = {}
    for item in queue:
        seq_count = len(item.get("tokens", []))
        max_group_size = max(max_group_size, seq_count)
        item_env_id = item.get("env_id")
        if item_env_id == env.env_id:
            # group size may be dynamic, so track the largest one seen
            env_group_size = max(env_group_size, seq_count)
            num_self_sequences_in_queue += seq_count
        else:
            sequences_by_env[item_env_id] = (
                sequences_by_env.get(item_env_id, 0) + seq_count
            )

    # update the group size for the requesting env
    this_env["group_size"] = env_group_size

    # Calculate minimum sequences allocated to each environment
    batch_size = getattr(app.state, "batchsize", 0)
    min_sequences_by_env = {
        env_id: int(batch_size * alloc) for env_id, alloc in min_allocations
    }

    # Sequences of other envs, rounded up to whole groups of max_group_size,
    # minus what those envs are guaranteed as their minimum allocation.
    # Other envs are only counted when the queue holds a group larger than 1.
    available_from_others = 0
    if max_group_size > 1:
        for other_id, seq_count in sequences_by_env.items():
            packed_sequences = math.ceil(seq_count / max_group_size) * max_group_size
            min_sequences = min_sequences_by_env.get(other_id, 0)
            available_from_others += max(0, packed_sequences - min_sequences)

    env_queue_size = num_self_sequences_in_queue + available_from_others

    # A group size of 0 must not turn the status poll into a crash.
    group_divisor = max(1, env_group_size)

    try:
        ret_dict = {
            "current_step": app.state.status_dict["step"],
            "queue_size": env_queue_size // group_divisor,
            "unallocated_fraction": unallocated_fraction,
            "self_queue_size": num_self_sequences_in_queue // group_divisor,
            "max_group_size": max_group_size,
        }
    except AttributeError:
        # No step counter yet. Same keys as the normal response, plus the
        # legacy "num_self_sequences_in_queue" key kept for compatibility.
        ret_dict = {
            "current_step": 0,
            "queue_size": 0,
            "unallocated_fraction": 1.0,
            "num_self_sequences_in_queue": 0,
            "self_queue_size": 0,
            "max_group_size": max_group_size,
        }
    ret_dict["env_weight"] = env_weight
    return ret_dict
|
|
|
|
|
|
@app.get("/reset_data")
async def reset_data():
    """Drop the queue and restore the run state to its defaults.

    The queue attribute is deleted outright. If that deletion raises
    ``KeyError`` nothing else is touched. The response is always
    ``200 Reset successful``.
    """
    try:
        del app.state.queue
    except KeyError:
        pass
    else:
        # Built fresh on every call so no list or dict is shared between resets.
        defaults = {
            "group": None,
            "project": None,
            "batchsize": -1,
            "num_steps": -1,
            "status_dict": {"step": 0},
            "curr_batch": [],
            "started": False,
            "requesters": [],
            "envs": [],
            "buffer": {},
        }
        for attr_name, default in defaults.items():
            setattr(app.state, attr_name, default)
    return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)
|