# atropos/atroposlib/api/server.py
import time
import uuid
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, field_validator
from atroposlib.api.utils import (
find_groups_summing_to_target,
grab_batch_with_minimum_allocations,
grab_exact_from_heterogeneous_queue,
)
# Message import removed - using Dict[str, Any] for more flexible validation
# FastAPI application shared by every route handler below; all runtime state
# (rollout queue, registered envs, step counter, ...) lives on ``app.state``.
app = FastAPI(title="AtroposLib API")
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# maximally permissive CORS policy — fine for a local/trusted deployment,
# but confirm this server is never exposed publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Landing endpoint identifying the service."""
    greeting = "AtroposLib API"
    return {"message": greeting}
class Registration(BaseModel):
    """Payload sent by the trainer when it registers via /register.

    Fields mirror what /register stores on ``app.state``.
    """

    wandb_group: str  # wandb group name, exposed via /wandb_info
    wandb_project: str  # wandb project name, exposed via /wandb_info
    batch_size: int  # sequences per training batch, used by /batch packing
    max_token_len: int  # max token length, exposed via /info
    checkpoint_dir: str  # directory handed to environments at registration
    save_checkpoint_interval: int  # steps between checkpoints (relayed to envs)
    starting_step: int  # initial value of the server's step counter
    num_steps: int  # total training steps (relayed to envs)
class RegisterEnv(BaseModel):
    """Payload sent by an environment when it registers via /register-env."""

    max_token_length: int  # env's max context length; feeds the weight calc in /status-env
    desired_name: str  # requested wandb name; server suffixes an index for uniqueness
    weight: float  # relative sampling weight (treated as 1.0 when None at registration)
    group_size: int  # expected number of sequences per scored group
    min_batch_allocation: Optional[float] = (
        None  # Minimum proportion of a batch this env should be allocated (0.0-1.0)
    )
class EnvIdentifier(BaseModel):
    """Identifies a registered environment by its index in ``app.state.envs``."""

    env_id: int  # the "env_id" value returned by /register-env
class ScoredData(BaseModel):
    """A scored rollout group submitted by an environment for training."""

    tokens: List[List[int]]
    masks: List[List[int]]
    scores: List[float]
    advantages: Optional[List[List[float]]] = None
    ref_logprobs: Optional[List[List[float]]] = None
    messages: Optional[List[List[Dict[str, Any]]]] = (
        None  # Changed from Message TypedDict to Dict
    )
    generation_params: Optional[Dict[str, Any]] = None
    inference_logprobs: Optional[List[List[float]]] = None
    overrides: Optional[List[dict]] = None
    group_overrides: Optional[dict] = None
    images: Optional[Any] = None
    env_id: Optional[int] = None  # ID of the environment that generated this data

    @field_validator("messages", mode="before")
    @classmethod
    def validate_messages(cls, v):
        """Reject any message missing 'role' or 'content'.

        Only these two keys are required; anything else (e.g. 'reward') is
        passed through untouched.
        """
        if v is None:
            return None
        for conversation in v:
            for entry in conversation:
                if not ("role" in entry and "content" in entry):
                    raise ValueError("Message must have 'role' and 'content' fields")
        return v
class Status(BaseModel):
    """Status information of the current server (see /status)."""

    current_step: int  # server's current training step
    queue_size: int  # number of groups waiting in the rollout queue
class Info(BaseModel):
    """Basic server configuration (see /info)."""

    batch_size: int = -1  # -1 signals that no trainer has registered yet
@app.post("/register")
async def register(registration: Registration):
    """Register the trainer and initialize server-wide state.

    Resets the per-run state (batch buffers, env list, mixed-size buffer),
    preserves any pre-existing rollout queue, and returns a fresh UUID
    identifying this requester.

    Fix: the original executed ``isinstance(app.state.queue, list)`` purely
    for its AttributeError side effect — the result was discarded. Replaced
    with explicit ``hasattr`` guards (same behavior, clear intent).
    """
    # Create the rollout queue only if it does not exist yet; an existing
    # queue is intentionally preserved across trainer registrations.
    if not hasattr(app.state, "queue"):
        app.state.queue = []
    app.state.group = registration.wandb_group
    app.state.project = registration.wandb_project
    app.state.batchsize = int(registration.batch_size)
    app.state.max_token_len = int(registration.max_token_len)
    app.state.status_dict = {"step": registration.starting_step}
    app.state.checkpoint_dir = registration.checkpoint_dir
    app.state.save_checkpoint_interval = registration.save_checkpoint_interval
    app.state.num_steps = registration.num_steps
    app.state.curr_batch = []
    app.state.started = False
    app.state.envs = []
    app.state.buffer = {}  # Buffer for mixed-size groups per environment
    if not hasattr(app.state, "requesters"):
        # First registration ever: create the requester list.
        app.state.requesters = []
    app.state.requesters.append(uuid.uuid4().int)
    return {"uuid": app.state.requesters[-1]}
@app.post("/register-env")
async def register_env_url(register_env: RegisterEnv):
    """Register an environment and return its id plus run configuration.

    Environments are told to wait until the trainer has started (``started``
    is set by the first /batch call).

    Fix: collapsed the duplicated "wait for trainer" return (try/except plus
    falsy check) into one ``getattr`` guard, and replaced the discarded
    ``isinstance`` / ``AttributeError`` probes with ``hasattr``/``getattr``.
    """
    if not getattr(app.state, "started", False):
        return {
            "status": "wait for trainer to start",
        }
    if not hasattr(app.state, "envs"):
        app.state.envs = []
    # Best effort: no checkpoint dir until a trainer has registered.
    checkpoint_dir = getattr(app.state, "checkpoint_dir", "")
    # Disambiguate duplicate names by suffixing the count of same-named envs.
    name_count = len(
        [x for x in app.state.envs if x["desired_name"] == register_env.desired_name]
    )
    real_name = f"{register_env.desired_name}_{name_count}"
    registered_id = len(app.state.envs)
    app.state.envs.append(
        {
            "max_context_len": register_env.max_token_length,
            "weight": register_env.weight if register_env.weight is not None else 1.0,
            "desired_name": register_env.desired_name,
            "real_name": real_name,
            "registered_id": registered_id,
            "last_update": time.time(),
            "connected": True,
            "min_batch_allocation": register_env.min_batch_allocation,
            "group_size": register_env.group_size,
        }
    )
    return {
        "status": "success",
        "env_id": registered_id,
        "wandb_name": real_name,
        "checkpoint_dir": checkpoint_dir,
        "starting_step": app.state.status_dict["step"],
        "checkpoint_interval": app.state.save_checkpoint_interval,
        "num_steps": app.state.num_steps,
    }
@app.post("/disconnect-env")
async def disconnect_env(disconnect_env: EnvIdentifier):
    """Mark an environment as disconnected so it no longer counts toward weights."""
    try:
        target = app.state.envs[disconnect_env.env_id]
    except (AttributeError, IndexError) as e:
        # Either no envs were ever registered or the id is out of range.
        return {"status": "failure", "error": str(e)}
    target["connected"] = False
    return {"status": "success"}
@app.get("/wandb_info")
async def wandb_info():
    """Expose the wandb group/project configured at trainer registration."""
    try:
        group, project = app.state.group, app.state.project
    except AttributeError:
        # No trainer registered yet.
        group = project = None
    return {"group": group, "project": project}
@app.get("/info")
async def info():
    """Expose batch size and max token length; -1 means no trainer yet."""
    try:
        payload = {
            "batch_size": app.state.batchsize,
            "max_token_len": app.state.max_token_len,
        }
    except AttributeError:
        payload = {"batch_size": -1, "max_token_len": -1}
    return payload
@app.get("/batch")
async def get_batch():
    """Pop one training batch, packing new batches from the queue if needed.

    Marks the server as started on first call. Returns ``{"batch": None}``
    when the queue cannot yet fill a whole batch.

    Fix: the min-allocation vs. exact-grab branch was copy-pasted four times;
    it is now a single nested helper (identical behavior).
    """
    if not app.state.started:
        app.state.started = True
    if len(app.state.curr_batch) > 0:
        return {"batch": app.state.curr_batch.pop()}
    # Envs with a minimum batch allocation require the allocation-aware packer.
    has_min_allocations = any(
        env.get("min_batch_allocation") is not None
        for env in getattr(app.state, "envs", [])
    )

    def _grab(queue):
        # One packing step: returns (batch-or-None, remaining queue).
        if has_min_allocations:
            return grab_batch_with_minimum_allocations(
                queue, app.state.batchsize, app.state.envs
            )
        return grab_exact_from_heterogeneous_queue(queue, app.state.batchsize)

    new_batches = []
    batch, app.state.queue = _grab(app.state.queue)
    while batch is not None:
        new_batches.append(batch)
        batch, app.state.queue = _grab(app.state.queue)
    steps_to_take = len(new_batches)
    if steps_to_take == 0:
        return {"batch": None}
    app.state.status_dict["step"] += steps_to_take
    # Stash the formed batches; serve one now, the rest on later calls.
    app.state.curr_batch.extend(new_batches)
    curr_batch = app.state.curr_batch.pop()
    # check length before sending
    print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences")
    return {"batch": curr_batch}
@app.get("/latest_example")
async def get_latest_example():
    """Return the most recently submitted scored group, or an empty template."""
    empty_example = {
        "tokens": [],
        "masks": [],
        "scores": [],
        "advantages": [],
        "ref_logprobs": [],
        "inference_logprobs": [],
        "messages": [],
        "images": [],
    }
    return getattr(app.state, "latest", empty_example)
@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
    """Accept one scored group; buffer it if its size differs from the env's.

    Groups whose size matches the env's registered group_size go straight to
    the queue. Off-size groups are buffered per-env until a subset of buffered
    groups sums exactly to the expected size, then flushed FIFO.
    """
    data_dict = {
        "tokens": scored_data.tokens,
        "masks": scored_data.masks,
        "scores": scored_data.scores,
        "advantages": scored_data.advantages,
        "ref_logprobs": scored_data.ref_logprobs,
        "messages": scored_data.messages,
        "generation_params": scored_data.generation_params,
        "inference_logprobs": scored_data.inference_logprobs,
        "overrides": scored_data.overrides,
        "group_overrides": scored_data.group_overrides,
        "images": scored_data.images,
        "env_id": scored_data.env_id,
    }
    env_id = scored_data.env_id
    if env_id is not None and env_id < len(app.state.envs):
        expected_size = app.state.envs[env_id].get("group_size", 1)
        if len(scored_data.tokens) != expected_size:
            # Mixed-size group: stash it and try to assemble a full group.
            pending = app.state.buffer.setdefault(env_id, [])
            pending.append(data_dict)
            matched = find_groups_summing_to_target(pending, expected_size)
            if matched:
                # Pop from highest index down so earlier indices stay valid,
                # then enqueue in original (FIFO) order.
                flushed = [pending.pop(i) for i in sorted(matched, reverse=True)]
                for grp in reversed(flushed):
                    app.state.queue.append(grp)
                    app.state.latest = grp
            return {
                "status": "buffered",
                "buffer_size": sum(
                    len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
                ),
            }
    # Normal path - correct size or no env info
    app.state.queue.append(data_dict)
    app.state.latest = data_dict
    return {"status": "received"}
@app.post("/scored_data_list")
async def scored_data_list(scored_data_list: List[ScoredData]):
    """Handle a list of ScoredData objects for step-based learning"""
    for item in scored_data_list:
        data_dict = {
            "tokens": item.tokens,
            "masks": item.masks,
            "scores": item.scores,
            "advantages": item.advantages,
            "ref_logprobs": item.ref_logprobs,
            "images": item.images,
            "messages": item.messages,
            "generation_params": item.generation_params,
            "inference_logprobs": item.inference_logprobs,
            "overrides": item.overrides,
            "group_overrides": item.group_overrides,
            "env_id": item.env_id,
        }
        env_id = item.env_id
        if env_id is None or env_id >= len(app.state.envs):
            # No usable env info: enqueue directly.
            app.state.queue.append(data_dict)
            app.state.latest = data_dict
            continue
        expected_size = app.state.envs[env_id].get("group_size", 1)
        if len(item.tokens) == expected_size:
            # Normal size: enqueue directly.
            app.state.queue.append(data_dict)
            app.state.latest = data_dict
            continue
        # Mixed-size group: buffer and try to assemble a full group.
        pending = app.state.buffer.setdefault(env_id, [])
        pending.append(data_dict)
        matched = find_groups_summing_to_target(pending, expected_size)
        if matched:
            # Pop from highest index down so earlier indices stay valid,
            # then enqueue in original (FIFO) order.
            flushed = [pending.pop(i) for i in sorted(matched, reverse=True)]
            for grp in reversed(flushed):
                app.state.queue.append(grp)
                app.state.latest = grp
    return {"status": "received", "groups_processed": len(scored_data_list)}
@app.get("/status")
async def get_status():
    """Report the current step and queue depth (zeros before registration)."""
    try:
        step = app.state.status_dict["step"]
        pending = len(app.state.queue)
    except AttributeError:
        return {"current_step": 0, "queue_size": 0}
    return {"current_step": step, "queue_size": pending}
@app.get("/status-env")
async def get_status_env(env: EnvIdentifier):
    """Report per-environment status: step, effective queue sizes, and weight.

    The queue size is expressed in groups of the requesting env's group size,
    after reserving other envs' minimum batch allocations.

    Fixes: (1) the AttributeError fallback dict used the key
    "num_self_sequences_in_queue" instead of "self_queue_size" and omitted
    "max_group_size", so the error-path response shape disagreed with the
    success path — now consistent; (2) ``import math`` was buried mid-function,
    now hoisted to the top of the function.

    NOTE(review): ``total`` can be 0 if no connected env has positive weight,
    which would raise ZeroDivisionError below — confirm callers guarantee at
    least one connected weighted env.
    """
    import math

    # Weight denominator: context-length-weighted sum over connected envs.
    total = sum(
        x["max_context_len"] * max(0.0, x["weight"])
        for x in app.state.envs
        if x["connected"]
    )
    env_group_size = app.state.envs[env.env_id]["group_size"]
    env_weight = (
        app.state.envs[env.env_id]["max_context_len"]
        * app.state.envs[env.env_id]["weight"]
        / total
    )
    env_weight = max(
        0.01, env_weight
    )  # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this
    # Calculate total minimum allocations across connected envs.
    total_min_allocation = 0.0
    for env_config in app.state.envs:
        if (
            env_config.get("connected", False)
            and env_config.get("min_batch_allocation") is not None
        ):
            total_min_allocation += env_config["min_batch_allocation"]
    # Fraction of the batch not reserved by minimum allocations.
    unallocated_fraction = 1.0 - min(total_min_allocation, 1.0)
    # Scan the queue once: largest group size overall, and this env's sequences.
    queue = getattr(app.state, "queue", [])
    max_group_size = 1
    num_self_sequences_in_queue = 0
    for item in queue:
        group_size = len(item.get("tokens", []))
        if group_size > max_group_size:
            max_group_size = group_size
        if item.get("env_id") == env.env_id:
            # Group size may be dynamic; track the largest seen for this env.
            env_group_size = max(env_group_size, group_size)
            num_self_sequences_in_queue += group_size
    # Persist the (possibly grown) group size for the requesting env.
    app.state.envs[env.env_id]["group_size"] = env_group_size
    # Minimum sequences reserved for each env with a min_batch_allocation.
    batch_size = getattr(app.state, "batchsize", 0)
    min_sequences_by_env = {}
    for env_config in app.state.envs:
        if (
            env_config.get("connected", False)
            and env_config.get("min_batch_allocation") is not None
        ):
            env_id = env_config["registered_id"]
            min_sequences_by_env[env_id] = int(
                batch_size * env_config["min_batch_allocation"]
            )
    # Count queued sequences per env (requesting env tracked separately).
    sequences_by_env = {}
    packed_groups_by_env = {}
    curr_env_total_sequences = 0
    for item in queue:
        env_id = item.get("env_id")
        seq_count = len(item.get("tokens", []))
        if env_id == env.env_id:
            curr_env_total_sequences += seq_count
        else:
            sequences_by_env[env_id] = sequences_by_env.get(env_id, 0) + seq_count
    # Packed group counts for other envs (only meaningful when groups exist).
    if max_group_size > 1:
        for env_id, seq_count in sequences_by_env.items():
            packed_groups_by_env[env_id] = math.ceil(seq_count / max_group_size)
    # Sequences available from other envs after honoring their minimums.
    available_from_others = 0
    for env_id in packed_groups_by_env:
        packed_sequences = packed_groups_by_env[env_id] * max_group_size
        min_sequences = min_sequences_by_env.get(env_id, 0)
        available_from_others += max(0, packed_sequences - min_sequences)
    env_queue_size = curr_env_total_sequences + available_from_others
    try:
        ret_dict = {
            "current_step": app.state.status_dict["step"],
            "queue_size": env_queue_size // env_group_size,
            "unallocated_fraction": unallocated_fraction,
            "self_queue_size": num_self_sequences_in_queue // env_group_size,
            "max_group_size": max_group_size,
        }
    except AttributeError:
        # Keys now mirror the success path exactly (previously the self-queue
        # key was misnamed and max_group_size was missing).
        ret_dict = {
            "current_step": 0,
            "queue_size": 0,
            "unallocated_fraction": 1.0,
            "self_queue_size": 0,
            "max_group_size": 1,
        }
    ret_dict["env_weight"] = env_weight
    return ret_dict
@app.get("/reset_data")
async def reset_data():
    """Clear all server state so a fresh run can register.

    Fix: the original wrapped everything in ``except KeyError``, but
    ``del app.state.queue`` raises AttributeError when the queue was never
    created — so calling /reset_data before /register returned a 500 and
    skipped every subsequent reset. The deletion is now guarded on its own
    and the remaining resets always run.
    """
    try:
        del app.state.queue
    except AttributeError:
        pass  # queue never existed; nothing to delete
    app.state.group = None
    app.state.project = None
    app.state.batchsize = -1
    app.state.num_steps = -1
    app.state.status_dict = {"step": 0}
    app.state.curr_batch = []
    app.state.started = False
    app.state.requesters = []
    app.state.envs = []
    app.state.buffer = {}
    return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)