feat: add minimum batch allocation support for environments

- Add min_batch_allocation parameter to ensure environments contribute minimum proportion to each batch
- Implement grab_batch_with_minimum_allocations function with proper scaling when allocations exceed 100%
- Add mixed-size group buffering to handle variable-sized data submissions
- Update server to use minimum allocation logic when any env has min_batch_allocation set
- Add comprehensive tests for minimum allocation scenarios
- Update documentation in API README and CONFIG.md
- Update example environments to demonstrate the feature

This feature allows critical environments to guarantee they contribute at least a specified proportion (0.0-1.0) to each training batch, ensuring important data sources are always represented during training.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Dakota 2025-07-07 08:50:28 -05:00
parent 4769eeb4a6
commit 08e14cc745
11 changed files with 1670 additions and 91 deletions

View file

@ -18,6 +18,7 @@ This service specifically handles the **experience data pathway**:
* Registration endpoints for Trainers and Rollout Handlers.
* Serves batches of aggregated experience data to Trainers.
* Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching).
* Minimum batch allocation support to guarantee certain environments contribute a minimum proportion to each batch.
* Provides status endpoints for monitoring queue size and training step count.
* Basic integration with Weights & Biases (W&B) project/group info.
* Endpoints for Rollout Handlers to disconnect gracefully.
@ -97,6 +98,8 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
max_token_length: int # Max length this env produces
desired_name: str # Base name for identification/logging
weight: float # Weight for sampling/batching (e.g., 1.0)
group_size: int # Expected number of sequences per data submission
min_batch_allocation: Optional[float] = None # Minimum proportion of batch (0.0-1.0)
```
* **Response:** Provides assigned ID, unique W&B name, checkpoint info.
```json
@ -135,14 +138,17 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
overrides: Optional[List[dict]] = None # Per-item logging overrides
group_overrides: Optional[dict] = None # Group logging overrides
images: Optional[Any] = None # Image data (if applicable)
env_id: Optional[int] = None # ID of the environment that generated this data
```
* **Response:** `{"status": "received"}`
* **Response:**
* Normal submission: `{"status": "received"}`
* Mixed-size group buffered: `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
* `POST /scored_data_list`
* **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks.
* **Request Body:** `List[ScoredData]`
* **Response:** `{"status": "received", "groups_processed": <count>}`
* `GET /batch`
* **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic (`grab_exact_from_heterogeneous_queue`) to form a batch of the configured size from the available data in the queue, potentially respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
* **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic to form a batch of the configured size from the available data in the queue. If any environments have minimum batch allocations specified, it uses `grab_batch_with_minimum_allocations` to ensure each environment gets at least its minimum proportion of the batch. Otherwise, it uses `grab_exact_from_heterogeneous_queue` to form batches respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
* **Response:**
* Success: `{"batch": [<data_item_1>, ..., <data_item_N>]}` where each `data_item` matches the structure pushed via `/scored_data`.
* Not enough data: `{"batch": null}`
@ -178,6 +184,57 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
* mermaid diagram of how a rollout handler interacts with the api is located [here](env_interaction.md).
6. **Shutdown:** Handlers may call `POST /disconnect-env`.
## Minimum Batch Allocation Feature
The API supports ensuring minimum batch allocations for specific environments. This feature is useful when you want to guarantee that certain environments contribute at least a minimum proportion of sequences to each training batch.
### How It Works
1. **Environment Registration**: When registering an environment via `/register-env`, you can specify:
- `min_batch_allocation` (Optional[float]): A value between 0.0 and 1.0 representing the minimum proportion of the batch this environment should contribute
- `group_size` (int): The expected number of sequences per data submission from this environment
2. **Batch Formation**: When the trainer requests a batch via `/batch`:
- If any environment has a `min_batch_allocation` specified, the system uses special logic to ensure minimums are met
- The system attempts to allocate at least `min_batch_allocation * batch_size` sequences from each environment with a minimum
- If the sum of all minimum allocations exceeds 1.0, they are proportionally scaled down
- If an environment with a minimum allocation has no data available, the batch formation fails (returns null)
3. **Mixed-Size Group Handling**: When an environment submits data with a different number of sequences than its declared `group_size`:
- The data is buffered separately for that environment
- The system attempts to combine buffered groups to match the expected `group_size`
- Once combined, the data is added to the main queue
- Response includes `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
### Example Configuration
```python
# Environment 1: Requires at least 30% of each batch
{
"max_token_length": 512,
"desired_name": "critical_env",
"weight": 1.0,
"group_size": 4,
"min_batch_allocation": 0.3 # 30% minimum
}
# Environment 2: No minimum requirement
{
"max_token_length": 512,
"desired_name": "standard_env",
"weight": 1.0,
"group_size": 2,
"min_batch_allocation": None # No minimum
}
```
### Important Notes
- Minimum allocations are enforced per batch, not globally
- If minimum allocations cannot be satisfied (e.g., not enough data from a required environment), batch formation fails
- Environments without `min_batch_allocation` fill the remaining batch space after minimums are satisfied
- The feature respects heterogeneous packing constraints when forming batches
## Limitations & TODOs
* **In-Memory State:** The primary limitation is that all queues, configurations, and states are stored in the FastAPI application's memory (`app.state`).