Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
feat: add minimum batch allocation support for environments

- Add min_batch_allocation parameter to ensure environments contribute minimum proportion to each batch
- Implement grab_batch_with_minimum_allocations function with proper scaling when allocations exceed 100%
- Add mixed-size group buffering to handle variable-sized data submissions
- Update server to use minimum allocation logic when any env has min_batch_allocation set
- Add comprehensive tests for minimum allocation scenarios
- Update documentation in API README and CONFIG.md
- Update example environments to demonstrate the feature

This feature allows critical environments to guarantee they contribute at least a specified proportion (0.0-1.0) to each training batch, ensuring important data sources are always represented during training.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
parent 4769eeb4a6
commit 08e14cc745
11 changed files with 1670 additions and 91 deletions
|
|
@ -18,6 +18,7 @@ This service specifically handles the **experience data pathway**:
|
|||
* Registration endpoints for Trainers and Rollout Handlers.
|
||||
* Serves batches of aggregated experience data to Trainers.
|
||||
* Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching).
|
||||
* Supports minimum batch allocations, guaranteeing that certain environments contribute a minimum proportion of each batch.
|
||||
* Provides status endpoints for monitoring queue size and training step count.
|
||||
* Basic integration with Weights & Biases (W&B) project/group info.
|
||||
* Endpoints for Rollout Handlers to disconnect gracefully.
|
||||
|
|
@ -97,6 +98,8 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
|
|||
max_token_length: int # Max length this env produces
|
||||
desired_name: str # Base name for identification/logging
|
||||
weight: float # Weight for sampling/batching (e.g., 1.0)
|
||||
group_size: int # Expected number of sequences per data submission
|
||||
min_batch_allocation: Optional[float] = None # Minimum proportion of batch (0.0-1.0)
|
||||
```
|
||||
* **Response:** Provides assigned ID, unique W&B name, checkpoint info.
|
||||
```json
|
||||
|
|
@ -135,14 +138,17 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
|
|||
overrides: Optional[List[dict]] = None # Per-item logging overrides
|
||||
group_overrides: Optional[dict] = None # Group logging overrides
|
||||
images: Optional[Any] = None # Image data (if applicable)
|
||||
env_id: Optional[int] = None # ID of the environment that generated this data
|
||||
```
|
||||
* **Response:** `{"status": "received"}`
|
||||
* **Response:**
|
||||
* Normal submission: `{"status": "received"}`
|
||||
* Mixed-size group buffered: `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
|
||||
* `POST /scored_data_list`
|
||||
* **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks.
|
||||
* **Request Body:** `List[ScoredData]`
|
||||
* **Response:** `{"status": "received", "groups_processed": <count>}`
|
||||
* `GET /batch`
|
||||
* **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic (`grab_exact_from_heterogeneous_queue`) to form a batch of the configured size from the available data in the queue, potentially respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
|
||||
* **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic to form a batch of the configured size from the available data in the queue. If any environments have minimum batch allocations specified, it uses `grab_batch_with_minimum_allocations` to ensure each environment gets at least its minimum proportion of the batch. Otherwise, it uses `grab_exact_from_heterogeneous_queue` to form batches respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
|
||||
* **Response:**
|
||||
* Success: `{"batch": [<data_item_1>, ..., <data_item_N>]}` where each `data_item` matches the structure pushed via `/scored_data`.
|
||||
* Not enough data: `{"batch": null}`
|
||||
|
|
@ -178,6 +184,57 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
|
|||
* A Mermaid diagram of how a Rollout Handler interacts with the API is located [here](env_interaction.md).
|
||||
6. **Shutdown:** Handlers may call `POST /disconnect-env`.
|
||||
|
||||
## Minimum Batch Allocation Feature
|
||||
|
||||
The API can guarantee that specific environments contribute at least a minimum proportion of the sequences in each training batch. This is useful when certain data sources must always be represented during training.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Environment Registration**: When registering an environment via `/register-env`, you can specify:
|
||||
- `min_batch_allocation` (Optional[float]): A value between 0.0 and 1.0 representing the minimum proportion of the batch this environment should contribute
|
||||
- `group_size` (int): The expected number of sequences per data submission from this environment
|
||||
|
||||
2. **Batch Formation**: When the trainer requests a batch via `/batch`:
|
||||
- If any environment has a `min_batch_allocation` specified, the system uses special logic to ensure minimums are met
|
||||
- The system attempts to allocate at least `min_batch_allocation * batch_size` sequences from each environment with a minimum
|
||||
- If the sum of all minimum allocations exceeds 1.0, they are proportionally scaled down
|
||||
- If an environment with a minimum allocation has no data available, batch formation fails and `/batch` returns `{"batch": null}`
|
||||
|
||||
3. **Mixed-Size Group Handling**: When an environment submits data with a different number of sequences than its declared `group_size`:
|
||||
- The data is buffered separately for that environment
|
||||
- The system attempts to combine buffered groups to match the expected `group_size`
|
||||
- Once combined, the data is added to the main queue
|
||||
- The response to such a submission is `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
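
The mixed-size handling in step 3 can be sketched as follows. This is a minimal illustration, not the server's actual code: the names `find_combination` and `submit` are hypothetical, the buffer is a plain list of groups, and the brute-force subset search is an assumption chosen for clarity. The response returned after a successful merge is also an assumption.

```python
from itertools import combinations
from typing import List, Optional


def find_combination(sizes: List[int], target: int) -> Optional[List[int]]:
    """Return indices of buffered groups whose sizes sum to target, or None.

    Brute-force search over subsets; fine for a small buffer, exponential in general.
    """
    for r in range(1, len(sizes) + 1):
        for idxs in combinations(range(len(sizes)), r):
            if sum(sizes[i] for i in idxs) == target:
                return list(idxs)
    return None


def submit(
    buffer: List[List[str]],
    group: List[str],
    group_size: int,
    queue: List[List[str]],
) -> dict:
    """Hypothetical handler: queue full-size groups, buffer everything else."""
    if len(group) == group_size:
        queue.append(group)
        return {"status": "received"}

    buffer.append(group)
    idxs = find_combination([len(g) for g in buffer], group_size)
    if idxs is not None:
        # Merge the matching buffered groups into one group of the declared size.
        merged = [seq for i in idxs for seq in buffer[i]]
        for i in sorted(idxs, reverse=True):
            del buffer[i]
        queue.append(merged)
    # Assumption: report the number of sequences still waiting in the buffer.
    return {"status": "buffered", "buffer_size": sum(len(g) for g in buffer)}
```

With `group_size=4`, submitting a group of 1 and then a group of 3 leaves the buffer empty and places one merged group of 4 on the queue.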
|
||||
|
||||
### Example Configuration
|
||||
|
||||
```python
|
||||
# Environment 1: Requires at least 30% of each batch
|
||||
{
|
||||
"max_token_length": 512,
|
||||
"desired_name": "critical_env",
|
||||
"weight": 1.0,
|
||||
"group_size": 4,
|
||||
"min_batch_allocation": 0.3 # 30% minimum
|
||||
}
|
||||
|
||||
# Environment 2: No minimum requirement
|
||||
{
|
||||
"max_token_length": 512,
|
||||
"desired_name": "standard_env",
|
||||
"weight": 1.0,
|
||||
"group_size": 2,
|
||||
"min_batch_allocation": None # No minimum
|
||||
}
|
||||
```
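
A small helper can validate these fields before they are sent to `/register-env`. The sketch below is illustrative only: `build_registration` and `to_json` are hypothetical names, and the payload contains just the fields shown in the example above, so the real endpoint may require more. It also shows that Python's `None` serializes to JSON `null` on the wire.

```python
import json
from typing import Optional


def build_registration(
    desired_name: str,
    max_token_length: int,
    group_size: int,
    weight: float = 1.0,
    min_batch_allocation: Optional[float] = None,
) -> dict:
    """Build a registration payload, rejecting out-of-range values early."""
    if group_size < 1:
        raise ValueError("group_size must be a positive integer")
    if min_batch_allocation is not None and not 0.0 <= min_batch_allocation <= 1.0:
        raise ValueError("min_batch_allocation must be between 0.0 and 1.0")
    return {
        "max_token_length": max_token_length,
        "desired_name": desired_name,
        "weight": weight,
        "group_size": group_size,
        "min_batch_allocation": min_batch_allocation,
    }


def to_json(payload: dict) -> str:
    """Serialize the payload; a None minimum becomes JSON null."""
    return json.dumps(payload)


critical = build_registration("critical_env", 512, 4, min_batch_allocation=0.3)
standard = build_registration("standard_env", 512, 2)
print(to_json(critical))
print(to_json(standard))
```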
|
||||
|
||||
### Important Notes
|
||||
|
||||
- Minimum allocations are enforced per batch, not globally
|
||||
- If minimum allocations cannot be satisfied (e.g., not enough data from a required environment), batch formation fails
|
||||
- Environments without `min_batch_allocation` fill the remaining batch space after minimums are satisfied
|
||||
- The feature respects heterogeneous packing constraints when forming batches
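
The allocation arithmetic described in this section can be sketched as below. This is not the implementation of `grab_batch_with_minimum_allocations`: the function names `scale_minimums` and `plan_batch` are hypothetical, normalizing an over-subscribed set of minimums so that it sums to 1.0 is an assumption about what "proportionally scaled down" means, and rounding up to whole sequences is likewise an assumption. Group-size and packing constraints are ignored.

```python
import math
from typing import Dict, Tuple


def scale_minimums(minimums: Dict[str, float]) -> Dict[str, float]:
    """Scale minimum fractions down proportionally if they sum to more than 1.0."""
    total = sum(minimums.values())
    if total <= 1.0:
        return dict(minimums)
    # Assumption: normalize so the scaled fractions sum to exactly 1.0.
    return {name: frac / total for name, frac in minimums.items()}


def plan_batch(minimums: Dict[str, float], batch_size: int) -> Tuple[Dict[str, int], int]:
    """Return per-environment minimum sequence counts and the leftover space.

    The leftover space is what environments without a minimum could fill.
    """
    scaled = scale_minimums(minimums)
    # Assumption: round up so each environment gets at least its proportion.
    counts = {name: math.ceil(frac * batch_size) for name, frac in scaled.items()}
    remaining = max(0, batch_size - sum(counts.values()))
    return counts, remaining
```

For example, minimums of 0.25 and 0.5 with a batch size of 32 reserve 8 and 16 sequences, leaving 8 for environments without a minimum.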
|
||||
|
||||
## Limitations & TODOs
|
||||
|
||||
* **In-Memory State:** The primary limitation is that all queues, configurations, and states are stored in the FastAPI application's memory (`app.state`).
|
||||
|
|
|
|||