mirror of https://github.com/NousResearch/atropos.git
synced 2026-05-01 17:45:16 +00:00

first commit (commit 621d00dd80)
89 changed files with 15315 additions and 0 deletions

atroposlib/api/README.md (new file, 182 lines)
# Trajectory Handler API

## Overview

The AtroposLib API is a FastAPI application designed to act as a central buffer and aggregator for reinforcement learning (RL) experience data. Its primary purpose is to decouple RL data generation (by "Rollout Handlers" or "Environments") from RL data consumption (by one or more "Trainers"), particularly in distributed online RL settings.

This service specifically handles the **experience data pathway**:

* Rollout Handlers connect and push trajectories (tokens, masks, scores, etc.).
* The API buffers this data in a queue.
* Trainers connect and pull processed batches of experience data for training updates.

**Important:** This service does *not* handle the distribution of updated policies from the Trainer back to the Rollout Handlers/Inference Servers. That part of the online RL loop is assumed to be handled by a separate mechanism.

## Features

* Centralized, in-memory queue for RL trajectory data.
* Registration endpoints for Trainers and Rollout Handlers.
* Serves batches of aggregated experience data to Trainers.
* Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching).
* Provides status endpoints for monitoring queue size and training step count.
* Basic integration with Weights & Biases (W&B) project/group info.
* Endpoints for Rollout Handlers to disconnect gracefully.
* Debug endpoint to retrieve the latest submitted data sample.

## Architecture Context

This API typically sits within a larger RL system:

1. **Rollout Handlers:** Instances simulating the environment. They interact with Inference Servers to get actions based on the current policy and send the resulting trajectory data (`ScoredData`) to this AtroposLib API (`/scored_data`).
2. **Inference Servers (External):** Serve the current policy (e.g., via an OpenAI-compatible API) and receive policy updates directly from the Trainer. *Not part of this service.*
3. **AtroposLib API (This Service):** Buffers and batches experience data received from Rollout Handlers.
4. **Trainer(s):** Pull batches of experience data from this API (`/batch`), compute gradients, update the policy, and push updated policies directly to the Inference Servers.

## Running the Server

With the repository installed, a helper script is provided to run the server:

```bash
run-api
```

If you need more control over the server, you can run it directly with:

```bash
uvicorn atroposlib.api.server:app --host 0.0.0.0 --port 8000 --reload
```

* `--host 0.0.0.0`: Makes the server accessible on your network.
* `--port 8000`: Specifies the port (change if needed).
* `--reload`: Enables auto-reloading on code changes (for development). Remove for production.

The API documentation (Swagger UI) will be available at `http://<your-server-ip>:8000/docs`.

## API Endpoints

### General

* `GET /`
    * **Description:** Root endpoint for basic health check.
    * **Response:** `{"message": "AtroposLib API"}`

### Trainer Registration & Info

* `POST /register`
    * **Description:** Called once by the Trainer process to initialize the server state for a training run. Resets state if called again.
    * **Request Body:** `Registration` model

      ```python
      class Registration(BaseModel):
          wandb_group: str
          wandb_project: str
          batch_size: int
          max_token_len: int  # Max token length expected in trajectories
          checkpoint_dir: str  # Shared location for checkpoints
          save_checkpoint_interval: int
          starting_step: int
          num_steps: int  # Total expected training steps
      ```

    * **Response:** `{"uuid": <generated_uuid_int>}`
* `GET /wandb_info`
    * **Description:** Retrieve W&B group and project info set during registration.
    * **Response:** `{"group": <group_name_or_null>, "project": <project_name_or_null>}`
* `GET /info`
    * **Description:** Retrieve batch size and max token length set during registration.
    * **Response:** `{"batch_size": <size_or_-1>, "max_token_len": <len_or_-1>}`
* `GET /status`
    * **Description:** Get the current training step (based on batches served) and queue size.
    * **Response:** `{"current_step": <step_count>, "queue_size": <queue_length>}`
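For illustration, a Trainer-side client for `POST /register` might look like the following. This is a minimal standard-library sketch: `base_url`, the example field values, and both helper names are ours, not part of atroposlib.

```python
import json
import urllib.request


def build_registration(wandb_group, wandb_project, batch_size, max_token_len,
                       checkpoint_dir, save_checkpoint_interval=100,
                       starting_step=0, num_steps=1000):
    """Assemble the JSON body expected by POST /register."""
    return {
        "wandb_group": wandb_group,
        "wandb_project": wandb_project,
        "batch_size": batch_size,
        "max_token_len": max_token_len,
        "checkpoint_dir": checkpoint_dir,
        "save_checkpoint_interval": save_checkpoint_interval,
        "starting_step": starting_step,
        "num_steps": num_steps,
    }


def register_trainer(base_url, payload):
    """POST the registration payload and return the parsed response dict."""
    req = urllib.request.Request(
        f"{base_url}/register",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


if __name__ == "__main__":
    payload = build_registration("my-group", "my-project", 64, 2048, "/tmp/ckpts")
    # register_trainer("http://localhost:8000", payload)  # needs a running server
```

The returned `uuid` can then be stored by the Trainer, though nothing in this README requires it for later calls.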
### Rollout Handler Registration & Info

* `POST /register-env`
    * **Description:** Called by each Rollout Handler instance to register itself.
    * **Request Body:** `RegisterEnv` model

      ```python
      class RegisterEnv(BaseModel):
          max_token_length: int  # Max length this env produces
          desired_name: str  # Base name for identification/logging
          weight: float  # Weight for sampling/batching (e.g., 1.0)
      ```

    * **Response:** Provides the assigned ID, a unique W&B name, and checkpoint info.

      ```json
      {
          "status": "success",
          "env_id": <assigned_env_id_int>,
          "wandb_name": <generated_unique_name>,
          "checkpoint_dir": <checkpoint_dir_from_registration>,
          "starting_step": <current_server_step>,
          "checkpoint_interval": <interval_from_registration>,
          "num_steps": <num_steps_from_registration>
      }
      ```

* `POST /disconnect-env`
    * **Description:** Allows a Rollout Handler to signal that it is disconnecting gracefully.
    * **Request Body:** `EnvIdentifier` model `{"env_id": <registered_env_id_int>}`
    * **Response:** `{"status": "success"}` or `{"status": "failure", "error": ...}`
* `GET /status-env`
    * **Description:** Called by a Rollout Handler to get general status plus its calculated sampling weight relative to other connected environments.
    * **Request:** `EnvIdentifier` model `{"env_id": <registered_env_id_int>}`. **Note:** the code declares `env: EnvIdentifier` as a body parameter on a GET request, which is non-standard (many HTTP clients cannot send a body with GET). This may need adjustment or testing.
    * **Response:** `{"current_step": <step>, "queue_size": <size>, "env_weight": <calculated_weight_float>}`
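The `env_weight` returned by `/status-env` is each environment's share of the total `max_token_length × weight` mass across connected environments, floored at 0.01. A standalone sketch of that calculation, mirroring the server logic (the function name is ours, not part of the library):

```python
def calc_env_weight(envs, env_id):
    """Share of the total (max_context_len * weight) mass held by env_id,
    counting only connected environments, floored at 0.01."""
    total = sum(
        e["max_context_len"] * max(0.0, e["weight"])
        for e in envs
        if e["connected"]
    )
    raw = envs[env_id]["max_context_len"] * envs[env_id]["weight"] / total
    return max(0.01, raw)
```

So an environment with triple the weight of an equally sized peer receives three quarters of the sampling mass, and a negligible environment is still never starved below 1%.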
### Data Handling

* `POST /scored_data`
    * **Description:** Endpoint for Rollout Handlers to push a single chunk of trajectory data.
    * **Request Body:** `ScoredData` model

      ```python
      class ScoredData(BaseModel):
          tokens: List[List[int]]
          masks: List[List[int]]
          scores: List[float]
          ref_logprobs: Optional[List[List[float]]] = None
          overrides: Optional[List[dict]] = None  # Per-item logging overrides
          group_overrides: Optional[dict] = None  # Group logging overrides
      ```

    * **Response:** `{"status": "received"}`
* `POST /scored_data_list`
    * **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks.
    * **Request Body:** `List[ScoredData]`
    * **Response:** `{"status": "received", "groups_processed": <count>}`
* `GET /batch`
    * **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic (`grab_exact_from_heterogeneous_queue`) to form a batch of the configured size from the available data in the queue, potentially respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
    * **Response:**
        * Success: `{"batch": [<data_item_1>, ..., <data_item_N>]}` where each `data_item` matches the structure pushed via `/scored_data`.
        * Not enough data: `{"batch": null}`
* `GET /latest_example`
    * **Description:** Debug endpoint to retrieve the most recently added `ScoredData` item.
    * **Response:** The last `ScoredData` dictionary pushed, or empty lists if none yet.
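For instance, a minimal `/scored_data` body for a two-sequence group could be built like this. This is a sketch: the token values are made up, and the 0/1 prompt/completion mask convention shown is an illustrative assumption, not something this README specifies.

```python
import json

# Minimal ScoredData payload: two sequences with per-sequence masks and scores.
scored_data = {
    "tokens": [[101, 42, 7], [101, 42, 9, 11]],
    "masks": [[0, 1, 1], [0, 1, 1, 1]],  # assumed convention: 0 = prompt, 1 = completion
    "scores": [0.5, -0.25],              # one scalar score per sequence
    "ref_logprobs": None,                # optional reference-policy logprobs
    "overrides": None,                   # optional per-item logging overrides
    "group_overrides": None,             # optional group logging overrides
}

body = json.dumps(scored_data)  # the JSON string POSTed to /scored_data
```

Note that `tokens`, `masks`, and `scores` are parallel lists: one entry per sequence in the group.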
### Debugging

* `GET /reset_data`
    * **Description:** **Warning:** Resets all server state, including the queue, configuration, registered environments, and step count. Use with caution during development/debugging.
    * **Response:** Plain text `Reset successful` with HTTP status 200.

## Common Workflow Example

1. **Start Server:** Launch the `AtroposLib` API server.
2. **Trainer Initialization:** The main Trainer process sends a `POST /register` request with run parameters.
3. **Rollout Handler Initialization:** Each Rollout Handler starts and sends `POST /register-env`.
4. **Data Generation:** Handlers run simulations, collect data, and send `POST /scored_data` or `POST /scored_data_list` periodically.
5. **Training Loop:**
    * The Trainer (e.g., Rank 0 in a distributed setup) enters a loop:
        * Calls `GET /batch`.
        * If `batch` is not `null`:
            * (Distribute the batch to other ranks if applicable.)
            * Perform a training step.
            * Optionally call `GET /status` for monitoring.
        * If `batch` is `null`:
            * Wait briefly (`time.sleep`) and retry `GET /batch`.
        * A Mermaid diagram of how a trainer interacts with the API is located [here](trainer_interaction.md).
    * (In distributed setups, other ranks (1..N-1) might poll `GET /status` to wait for the step counter to increment before expecting the broadcasted batch from Rank 0.)
    * The envs periodically poll `GET /status-env` to check their status and sampling weight.
        * In asynchronous setups, they may stop at a maximum off-policy step count.
        * A Mermaid diagram of how a rollout handler interacts with the API is located [here](env_interaction.md).
6. **Shutdown:** Handlers may call `POST /disconnect-env`.
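Step 5's poll-and-retry logic can be sketched as follows. Here `fetch_batch` stands in for a `GET /batch` call (injected so the sketch stays self-contained) and `train_step` is a placeholder; both names are ours.

```python
import time


def training_loop(fetch_batch, train_step, total_steps, poll_interval=0.0):
    """Poll for batches and run training steps until total_steps is reached.

    fetch_batch() -> dict shaped like {"batch": [...]} or {"batch": None}
    """
    step = 0
    while step < total_steps:
        result = fetch_batch()
        if result["batch"] is None:
            time.sleep(poll_interval)  # not enough data yet; retry
            continue
        train_step(result["batch"])
        step += 1
    return step
```

In a real client, `fetch_batch` would issue the HTTP request and `poll_interval` would be a fraction of a second or more to avoid hammering the API.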
## Limitations & TODOs

* **In-Memory State:** The primary limitation is that all queues, configuration, and state are stored in the FastAPI application's memory (`app.state`).
* **No Persistence:** Data is lost if the server restarts.
* **Scalability Bottleneck:** The API cannot easily scale beyond a single server instance.
atroposlib/api/__init__.py (new file, 3 lines)

```python
from .server import app

__all__ = ["app"]
```
atroposlib/api/env_interaction.md (new file, 70 lines)

```mermaid
sequenceDiagram
    participant RH as Rollout Handler
    participant API as AtroposLib API

    %% --- Initialization ---
    RH->>API: POST /register-env (Send env details)
    activate API
    API-->>RH: Response (env_id, starting_step, wandb_name, ...)
    deactivate API
    Note over RH: Store env_id and the unique wandb_name.

    Note over RH: Fetch W&B configuration (assumes Trainer already called /register)
    RH->>API: GET /wandb_info
    activate API
    API-->>RH: Response {"group": wb_group, "project": wb_project}
    deactivate API
    Note over RH: Initialize wandb logging (e.g., wandb.init) using group=wb_group, project=wb_project, name=wandb_name.

    Note over RH: Know target batch_size (from config). Set off_policy_tolerance (e.g., 3). Set internal state = 'Running'.

    loop Simulation Loop

        %% --- Check Pause State & Generate/Send Data ---
        alt State is 'Running'
            Note over RH: Generating data using internal environment logic...
            %% (Internal simulation steps, action selection, etc., happen here - details are opaque to the API)
            Note over RH: Trajectory chunk collected (contains tokens, masks, scores...). Log env-specific metrics to wandb (e.g., episode reward, length).

            %% --- Send Data ---
            RH->>API: POST /scored_data or /scored_data_list (Send collected chunk)
            activate API
            API-->>RH: Ack {"status": "received", ...}
            deactivate API
        else State is 'Paused'
            Note over RH: Currently paused, skipping data generation and sending. Will check status again.
            %% Implement delay/sleep here to avoid busy-checking status when paused
        end

        %% --- Periodic Queue Size Check (Pause/Resume Logic) ---
        Note over RH: Checking API queue status to decide pause/resume state.
        RH->>API: GET /status-env (using stored env_id)
        activate API
        API-->>RH: Response {"current_step": T_current, "queue_size": Q, "env_weight": W}
        deactivate API
        Note over RH: T_current may be logged or used internally by the handler. Log queue size Q.

        Note over RH: Calculate threshold = off_policy_tolerance * batch_size
        alt Queue size exceeds threshold (Q > threshold)
            Note over RH: Queue size Q > threshold. Setting internal state to 'Paused'.
            opt State was 'Running'
                Note over RH: Stopping data generation. Log pause event to wandb.
            end
        else Queue size is acceptable (Q <= threshold)
            Note over RH: Queue size Q <= threshold. Ensuring state is 'Running'.
            opt State was 'Paused'
                Note over RH: Resuming data generation. Log resume event to wandb.
            end
        end

    end %% End Simulation Loop

    %% --- Optional Shutdown ---
    RH->>API: POST /disconnect-env (using stored env_id)
    activate API
    API-->>RH: Ack {"status": "success"}
    deactivate API
    Note over RH: Finalize wandb logging (wandb.finish).
```
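The pause/resume rule in the diagram above reduces to one comparison against `off_policy_tolerance * batch_size`; a minimal sketch (the function name is ours, not part of the library):

```python
def should_pause(queue_size, batch_size, off_policy_tolerance=3):
    """Pause data generation when the API queue already holds more than
    off_policy_tolerance full batches worth of groups."""
    return queue_size > off_policy_tolerance * batch_size
```

The tolerance bounds how stale (off-policy) queued data can get: with a tolerance of 3 and batch size 100, a handler stops generating once more than 300 groups are waiting.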
atroposlib/api/server.py (new file, 305 lines)

```python
import time
import uuid
from typing import Any, List, Optional

from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel

from atroposlib.api.utils import grab_exact_from_heterogeneous_queue

app = FastAPI(title="AtroposLib API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    return {"message": "AtroposLib API"}


class Registration(BaseModel):
    wandb_group: str
    wandb_project: str
    batch_size: int
    max_token_len: int
    checkpoint_dir: str
    save_checkpoint_interval: int
    starting_step: int
    num_steps: int


class RegisterEnv(BaseModel):
    max_token_length: int
    desired_name: str
    weight: float


class EnvIdentifier(BaseModel):
    env_id: int


class ScoredData(BaseModel):
    tokens: List[List[int]]
    masks: List[List[int]]
    scores: List[float]
    ref_logprobs: Optional[List[List[float]]] = None
    overrides: Optional[List[dict]] = None
    group_overrides: Optional[dict] = None
    images: Optional[Any] = None


class Status(BaseModel):
    """
    Status information for the current server.
    """

    current_step: int
    queue_size: int


class Info(BaseModel):
    """
    Useful server configuration information.
    """

    batch_size: int = -1


@app.post("/register")
async def register(registration: Registration):
    try:
        isinstance(app.state.queue, list)  # probe whether the queue exists yet
    except AttributeError:
        app.state.queue = []
    app.state.group = registration.wandb_group
    app.state.project = registration.wandb_project
    app.state.batchsize = int(registration.batch_size)
    app.state.max_token_len = int(registration.max_token_len)
    app.state.status_dict = {"step": registration.starting_step}
    app.state.checkpoint_dir = registration.checkpoint_dir
    app.state.save_checkpoint_interval = registration.save_checkpoint_interval
    app.state.num_steps = registration.num_steps
    app.state.curr_batch = []
    app.state.started = False
    app.state.envs = []
    try:
        app.state.requesters.append(uuid.uuid4().int)
    except AttributeError:
        # If requesters doesn't exist yet, create it
        app.state.requesters = [uuid.uuid4().int]
    return {"uuid": app.state.requesters[-1]}


@app.post("/register-env")
async def register_env_url(register_env: RegisterEnv):
    try:
        isinstance(app.state.envs, list)  # probe whether the env list exists yet
    except AttributeError:
        app.state.envs = []
    checkpoint_dir = ""
    try:
        checkpoint_dir = app.state.checkpoint_dir
    except AttributeError:
        pass
    real_name = (
        f"{register_env.desired_name}_"
        f"{len([x for x in app.state.envs if x['desired_name'] == register_env.desired_name])}"
    )
    registered_id = len(app.state.envs)
    app.state.envs.append(
        {
            "max_context_len": register_env.max_token_length,
            "weight": register_env.weight if register_env.weight is not None else 1.0,
            "desired_name": register_env.desired_name,
            "real_name": real_name,
            "registered_id": registered_id,
            "last_update": time.time(),
            "connected": True,
        }
    )
    return {
        "status": "success",
        "env_id": registered_id,
        "wandb_name": real_name,
        "checkpoint_dir": checkpoint_dir,
        "starting_step": app.state.status_dict["step"],
        "checkpoint_interval": app.state.save_checkpoint_interval,
        "num_steps": app.state.num_steps,
    }


@app.post("/disconnect-env")
async def disconnect_env(disconnect_env: EnvIdentifier):
    try:
        app.state.envs[disconnect_env.env_id]["connected"] = False
        return {"status": "success"}
    except (AttributeError, IndexError) as e:
        return {"status": "failure", "error": str(e)}


@app.get("/wandb_info")
async def wandb_info():
    try:
        return {"group": app.state.group, "project": app.state.project}
    except AttributeError:
        return {"group": None, "project": None}


@app.get("/info")
async def info():
    try:
        return {
            "batch_size": app.state.batchsize,
            "max_token_len": app.state.max_token_len,
        }
    except AttributeError:
        return {"batch_size": -1, "max_token_len": -1}


@app.get("/batch")
async def get_batch():
    if not app.state.started:
        app.state.started = True

    if len(app.state.curr_batch) > 0:
        return {"batch": app.state.curr_batch.pop()}
    else:
        new_batches = []
        batch, app.state.queue = grab_exact_from_heterogeneous_queue(
            app.state.queue, app.state.batchsize
        )
        while batch is not None:
            new_batches.append(batch)
            batch, app.state.queue = grab_exact_from_heterogeneous_queue(
                app.state.queue, app.state.batchsize
            )
        steps_to_take = len(new_batches)
        if steps_to_take == 0:
            return {"batch": None}
        app.state.status_dict["step"] += steps_to_take
        # chunk it
        for batch in new_batches:
            app.state.curr_batch.append(batch)
        curr_batch = app.state.curr_batch.pop()
        # check length before sending
        print(f"Sending batch of length {sum(len(x['tokens']) for x in curr_batch)}")
        return {"batch": curr_batch}


@app.get("/latest_example")
async def get_latest_example():
    try:
        return app.state.latest
    except AttributeError:
        return {
            "tokens": [],
            "masks": [],
            "scores": [],
            "ref_logprobs": [],
            "images": [],
        }


@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
    app.state.queue.append(
        {
            "tokens": scored_data.tokens,
            "masks": scored_data.masks,
            "scores": scored_data.scores,
            "ref_logprobs": scored_data.ref_logprobs,
            "overrides": scored_data.overrides,
            "group_overrides": scored_data.group_overrides,
            "images": scored_data.images,
        }
    )
    app.state.latest = app.state.queue[-1]
    return {"status": "received"}


@app.post("/scored_data_list")
async def scored_data_list(scored_data_list: List[ScoredData]):
    """Handle a list of ScoredData objects for step-based learning"""
    for scored_data in scored_data_list:
        app.state.queue.append(
            {
                "tokens": scored_data.tokens,
                "masks": scored_data.masks,
                "scores": scored_data.scores,
                "ref_logprobs": scored_data.ref_logprobs,
                "images": scored_data.images,
            }
        )

    if scored_data_list:
        app.state.latest = app.state.queue[-1]

    return {"status": "received", "groups_processed": len(scored_data_list)}


@app.get("/status")
async def get_status():
    try:
        return {
            "current_step": app.state.status_dict["step"],
            "queue_size": len(app.state.queue),
        }
    except AttributeError:
        return {"current_step": 0, "queue_size": 0}


@app.get("/status-env")
async def get_status_env(env: EnvIdentifier):
    # NOTE: declaring a body model on a GET endpoint is non-standard; see the README.
    total = sum(
        [
            x["max_context_len"] * max(0.0, x["weight"])
            for x in app.state.envs
            if x["connected"]
        ]
    )
    env_weight = (
        app.state.envs[env.env_id]["max_context_len"]
        * app.state.envs[env.env_id]["weight"]
        / total
    )
    env_weight = max(
        0.01, env_weight
    )  # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this

    try:
        ret_dict = {
            "current_step": app.state.status_dict["step"],
            "queue_size": len(app.state.queue),
        }
    except AttributeError:
        ret_dict = {"current_step": 0, "queue_size": 0}
    ret_dict["env_weight"] = env_weight
    return ret_dict


@app.get("/reset_data")
async def reset_data():
    try:
        del app.state.queue
        app.state.group = None
        app.state.project = None
        app.state.batchsize = -1
        app.state.num_steps = -1
        app.state.status_dict = {"step": 0}
        app.state.curr_batch = []
        app.state.started = False
        app.state.requesters = []
        app.state.envs = []
    except AttributeError:
        # `del app.state.queue` raises AttributeError (not KeyError) when the
        # queue was never created; nothing to reset in that case.
        pass
    return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)
```
atroposlib/api/trainer_interaction.md (new file, 58 lines)

```mermaid
sequenceDiagram
    participant R0 as Trainer Rank 0
    participant R1N as Trainer Rank 1..N-1
    participant API as AtroposLib API

    R0->>API: POST /register (send Registration data)
    activate API
    API-->>R0: Respond with {'uuid': trainer_uuid}
    deactivate API
    Note over R0, R1N: Initialization complete. Trainer begins requesting data.

    loop Training Steps
        %% --- Phase 2: Rank 0 fetches batch, others wait/poll ---
        par Fetch vs Poll
            loop While Batch is Null
                R0->>API: GET /batch
                activate API

                Note over API: Checks queue; increments step counter if a batch is formed.

                alt Batch Available
                    API-->>R0: {'batch': [data_item_1, ...]}
                    Note over R0: Received batch for step S+1. Breaking loop.
                else No Batch Available
                    API-->>R0: {'batch': null}
                    Note over R0: No batch ready yet. Will retry.
                end
                deactivate API
            end
        and
            Note over R1N: Poll status until step increments from S.
            loop While Server Step is S
                R1N->>API: GET /status
                activate API
                API-->>R1N: {'current_step': S_new, 'queue_size': Q_new}
                deactivate API
                Note over R1N: Checking if S_new > S...
                %% In implementation, add delay here if S_new == S to avoid busy-wait
            end
            Note over R1N: Detected step incremented (S_new > S). Ready for broadcast.
        end

        %% --- Phase 3: Handle result ---
        Note over R0: Broadcasts received batch data to Ranks 1..N-1 (external mechanism)
        Note over R1N: Receives broadcasted data from Rank 0.
        Note over R0, R1N: All ranks now have the same batch for step S+1.

        %% --- Phase 4: Perform Training Step ---
        par Perform Training
            R0->>R0: Perform training step with batch data
        and
            R1N->>R1N: Perform training step with batch data
        end
        Note over R0, R1N: Training step S+1 complete.

    end %% End Training Steps Loop
```
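The polling loop for Ranks 1..N-1 in the diagram above can be sketched with an injected `get_status` callable standing in for `GET /status` (the function names are ours, not from atroposlib):

```python
import time


def wait_for_step(get_status, last_step, poll_interval=0.0):
    """Block until the server's current_step exceeds last_step, then return it.

    get_status() -> dict shaped like {"current_step": int, "queue_size": int}
    """
    while True:
        current = get_status()["current_step"]
        if current > last_step:
            return current
        time.sleep(poll_interval)  # avoid busy-waiting against the API
```

Once this returns, the rank expects the matching batch via the out-of-band broadcast from Rank 0.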
atroposlib/api/utils.py (new file, 61 lines)

```python
from typing import Dict, List, Optional, Tuple


def grab_exact_from_heterogeneous_queue(
    queue: List[Dict[str, List]], batch_size: int
) -> Tuple[Optional[List], List]:
    """
    Grabs a batch of exactly batch_size sequences from a queue of
    differently sized groups, e.g.

        queue = [{"tokens": [[1, 2, 3], [4, 5, 6, 7, 8]]}, {"tokens": [[9, 10]]}]

    without going over batch_size. Returns the batch and the new queue.

    Because all group sizes are a common denominator of the batch size, and all
    group sizes are a power of 2, we can simplify by packing groups of groups
    up to the maximum group size. Note that we cannot drop items from groups,
    so we must grab an entire group if we grab it at all.

    There may be a more efficient clearing mechanism that combines the smaller
    groups heterogeneously, but forcing them all into power-of-two packs is a
    simple way to ensure we can grab a batch of the correct size.

    :param queue: list of groups, each a dict with a "tokens" list of sequences
    :param batch_size: exact number of sequences to grab
    :return: (batch, new_queue); batch is None if no exact batch can be formed
    """
    # check if we can even potentially grab a batch
    if sum(len(item["tokens"]) for item in queue) < batch_size:
        return None, queue
    # Get max group size
    max_group_size = max(len(group["tokens"]) for group in queue)
    group_sizes = set(len(group["tokens"]) for group in queue)
    group_batching_storage = {i: [] for i in group_sizes}
    # pack the groups into [max_group_size // group_size] packs
    potential_batch = []
    for i, item in enumerate(queue):
        key = len(item["tokens"])
        group_batching_storage[key].append({"group": item, "indx": i})
        if len(group_batching_storage[key]) * key == max_group_size:
            potential_batch.extend(group_batching_storage[key])
            group_batching_storage[key] = []
    if (
        sum(len(grouped_items["group"]["tokens"]) for grouped_items in potential_batch)
        < batch_size
    ):
        return None, queue
    # we have a batch
    batch = []
    indxes_to_remove_from_queue = []
    for item in potential_batch:
        group = item["group"]
        indx = item["indx"]
        batch.append(group)
        indxes_to_remove_from_queue.append(indx)
        if sum(len(item["tokens"]) for item in batch) == batch_size:
            break
    if sum(len(item["tokens"]) for item in batch) != batch_size:
        return None, queue
    # remove the items from the queue
    new_queue = [
        item for i, item in enumerate(queue) if i not in indxes_to_remove_from_queue
    ]
    return batch, new_queue
```
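To make the packing idea concrete, here is a deliberately simplified greedy variant: it takes whole groups until the batch holds exactly `batch_size` sequences, without the power-of-two pack step the real `grab_exact_from_heterogeneous_queue` performs (so it can fail to find a batch in cases the real function handles). The function name is ours.

```python
def grab_batch_greedy(queue, batch_size):
    """Greedy sketch: take whole groups until exactly batch_size sequences
    are collected; return (None, queue) if that exact total is never hit."""
    batch, taken, count = [], set(), 0
    for i, group in enumerate(queue):
        n = len(group["tokens"])
        if count + n <= batch_size:
            batch.append(group)
            taken.add(i)
            count += n
        if count == batch_size:
            new_queue = [g for j, g in enumerate(queue) if j not in taken]
            return batch, new_queue
    return None, queue
```

The key shared invariants are visible even in the sketch: groups are taken whole, the batch is exact in sequence count, and untaken groups survive in the returned queue.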