feat: add minimum batch allocation support for environments

- Add min_batch_allocation parameter to ensure environments contribute minimum proportion to each batch
- Implement grab_batch_with_minimum_allocations function with proper scaling when allocations exceed 100%
- Add mixed-size group buffering to handle variable-sized data submissions
- Update server to use minimum allocation logic when any env has min_batch_allocation set
- Add comprehensive tests for minimum allocation scenarios
- Update documentation in API README and CONFIG.md
- Update example environments to demonstrate the feature

This feature allows critical environments to guarantee they contribute at least a specified proportion (0.0-1.0) to each training batch, ensuring important data sources are always represented during training.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Dakota 2025-07-07 08:50:28 -05:00
parent 4769eeb4a6
commit 08e14cc745
11 changed files with 1670 additions and 91 deletions

View file

@ -7,7 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, field_validator
from atroposlib.api.utils import grab_exact_from_heterogeneous_queue
from atroposlib.api.utils import (
find_groups_summing_to_target,
grab_batch_with_minimum_allocations,
grab_exact_from_heterogeneous_queue,
)
# Message import removed - using Dict[str, Any] for more flexible validation
@ -42,6 +46,10 @@ class RegisterEnv(BaseModel):
max_token_length: int
desired_name: str
weight: float
group_size: int
min_batch_allocation: Optional[float] = (
None # Minimum proportion of a batch this env should be allocated (0.0-1.0)
)
class EnvIdentifier(BaseModel):
@ -60,6 +68,7 @@ class ScoredData(BaseModel):
overrides: Optional[List[dict]] = None
group_overrides: Optional[dict] = None
images: Optional[Any] = None
env_id: Optional[int] = None # ID of the environment that generated this data
@field_validator("messages", mode="before")
@classmethod
@ -115,6 +124,7 @@ async def register(registration: Registration):
app.state.curr_batch = []
app.state.started = False
app.state.envs = []
app.state.buffer = {} # Buffer for mixed-size groups per environment
try:
app.state.requesters.append(uuid.uuid4().int)
except AttributeError:
@ -157,6 +167,8 @@ async def register_env_url(register_env: RegisterEnv):
"registered_id": registered_id,
"last_update": time.time(),
"connected": True,
"min_batch_allocation": register_env.min_batch_allocation,
"group_size": register_env.group_size,
}
)
return {
@ -207,14 +219,31 @@ async def get_batch():
return {"batch": app.state.curr_batch.pop()}
else:
new_batches = []
batch, app.state.queue = grab_exact_from_heterogeneous_queue(
app.state.queue, app.state.batchsize
# Check if any envs have minimum allocations
has_min_allocations = any(
env.get("min_batch_allocation") is not None
for env in getattr(app.state, "envs", [])
)
while batch is not None:
new_batches.append(batch)
if has_min_allocations:
batch, app.state.queue = grab_batch_with_minimum_allocations(
app.state.queue, app.state.batchsize, app.state.envs
)
else:
batch, app.state.queue = grab_exact_from_heterogeneous_queue(
app.state.queue, app.state.batchsize
)
while batch is not None:
new_batches.append(batch)
if has_min_allocations:
batch, app.state.queue = grab_batch_with_minimum_allocations(
app.state.queue, app.state.batchsize, app.state.envs
)
else:
batch, app.state.queue = grab_exact_from_heterogeneous_queue(
app.state.queue, app.state.batchsize
)
steps_to_take = len(new_batches)
if steps_to_take == 0:
return {"batch": None}
@ -224,7 +253,7 @@ async def get_batch():
app.state.curr_batch.append(batch)
curr_batch = app.state.curr_batch.pop()
# check length before sending
print(f"Sending batch of length {sum(len(x['tokens']) for x in curr_batch)}")
print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences")
return {"batch": curr_batch}
@ -246,20 +275,58 @@ async def get_latest_example():
@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
app.state.queue.append(
{
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"images": scored_data.images,
}
)
app.state.latest = app.state.queue[-1]
data_dict = {
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"images": scored_data.images,
"env_id": scored_data.env_id,
}
# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
expected_group_size = app.state.envs[env_id].get("group_size", 1)
actual_group_size = len(scored_data.tokens)
if actual_group_size != expected_group_size:
# Mixed size group - add to buffer
if env_id not in app.state.buffer:
app.state.buffer[env_id] = []
app.state.buffer[env_id].append(data_dict)
# Try to find groups that sum to expected_group_size
indices = find_groups_summing_to_target(
app.state.buffer[env_id], expected_group_size
)
if indices:
# Add these groups to queue in order
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(app.state.buffer[env_id].pop(idx))
# Add in FIFO order
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
return {
"status": "buffered",
"buffer_size": sum(
len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
),
}
# Normal path - correct size or no env info
app.state.queue.append(data_dict)
app.state.latest = data_dict
return {"status": "received"}
@ -267,24 +334,57 @@ async def scored_data(scored_data: ScoredData):
async def scored_data_list(scored_data_list: List[ScoredData]):
"""Handle a list of ScoredData objects for step-based learning"""
for idx, scored_data in enumerate(scored_data_list):
# Process each scored data item
for scored_data in scored_data_list:
data_dict = {
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"images": scored_data.images,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"env_id": scored_data.env_id,
}
app.state.queue.append(
{
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"images": scored_data.images,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
}
)
# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
expected_group_size = app.state.envs[env_id].get("group_size", 1)
actual_group_size = len(scored_data.tokens)
if scored_data_list:
app.state.latest = app.state.queue[-1]
if actual_group_size != expected_group_size:
# Mixed size group - add to buffer
if env_id not in app.state.buffer:
app.state.buffer[env_id] = []
app.state.buffer[env_id].append(data_dict)
# Try to find groups that sum to expected_group_size
indices = find_groups_summing_to_target(
app.state.buffer[env_id], expected_group_size
)
if indices:
# Add these groups to queue in order
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(app.state.buffer[env_id].pop(idx))
# Add in FIFO order
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
else:
# Normal size - add directly to queue
app.state.queue.append(data_dict)
app.state.latest = data_dict
else:
# No env info or normal path - add directly to queue
app.state.queue.append(data_dict)
app.state.latest = data_dict
return {"status": "received", "groups_processed": len(scored_data_list)}
@ -309,6 +409,7 @@ async def get_status_env(env: EnvIdentifier):
if x["connected"]
]
)
env_group_size = app.state.envs[env.env_id]["group_size"]
env_weight = (
app.state.envs[env.env_id]["max_context_len"]
* app.state.envs[env.env_id]["weight"]
@ -318,13 +419,95 @@ async def get_status_env(env: EnvIdentifier):
0.01, env_weight
) # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this
# Calculate total minimum allocations
total_min_allocation = 0.0
for env_config in app.state.envs:
if (
env_config.get("connected", False)
and env_config.get("min_batch_allocation") is not None
):
total_min_allocation += env_config["min_batch_allocation"]
# Calculate unallocated fraction
unallocated_fraction = 1.0 - min(total_min_allocation, 1.0)
# Find the maximum group size across all items in queue
queue = getattr(app.state, "queue", [])
max_group_size = 1
num_self_sequences_in_queue = 0
for item in queue:
group_size = len(item.get("tokens", []))
if group_size > max_group_size:
max_group_size = group_size
if item.get("env_id") == env.env_id:
# update the group size for the requesting env, handle cases where the group size may be dynamic with max
env_group_size = max(env_group_size, group_size)
num_self_sequences_in_queue += group_size
# update the group size for the requesting env
app.state.envs[env.env_id]["group_size"] = env_group_size
# Calculate minimum sequences allocated to each environment
batch_size = getattr(app.state, "batchsize", 0)
min_sequences_by_env = {}
for env_config in app.state.envs:
if (
env_config.get("connected", False)
and env_config.get("min_batch_allocation") is not None
):
env_id = env_config["registered_id"]
min_sequences = int(batch_size * env_config["min_batch_allocation"])
min_sequences_by_env[env_id] = min_sequences
# Count sequences and calculate packed groups for each environment
import math
sequences_by_env = {}
packed_groups_by_env = {}
curr_env_total_sequences = 0
for item in queue:
env_id = item.get("env_id")
seq_count = len(item.get("tokens", []))
# Special handling for the requesting environment
if env_id == env.env_id:
curr_env_total_sequences += seq_count
else:
if env_id not in sequences_by_env:
sequences_by_env[env_id] = 0
sequences_by_env[env_id] += seq_count
# Calculate packed groups for each environment (excluding the requesting env)
if max_group_size > 1:
for env_id, seq_count in sequences_by_env.items():
packed_groups_by_env[env_id] = math.ceil(seq_count / max_group_size)
# Calculate adjusted queue size
# (curr_env_total_sequences + sum of available sequences from other envs after their minimums)
available_from_others = 0
for env_id in packed_groups_by_env:
packed_sequences = packed_groups_by_env[env_id] * max_group_size
min_sequences = min_sequences_by_env.get(env_id, 0)
available_from_others += max(0, packed_sequences - min_sequences)
env_queue_size = curr_env_total_sequences + available_from_others
try:
ret_dict = {
"current_step": app.state.status_dict["step"],
"queue_size": len(app.state.queue),
"queue_size": env_queue_size // env_group_size,
"unallocated_fraction": unallocated_fraction,
"self_queue_size": num_self_sequences_in_queue // env_group_size,
"max_group_size": max_group_size,
}
except AttributeError:
ret_dict = {"current_step": 0, "queue_size": 0}
ret_dict = {
"current_step": 0,
"queue_size": 0,
"unallocated_fraction": 1.0,
"num_self_sequences_in_queue": 0,
}
ret_dict["env_weight"] = env_weight
return ret_dict
@ -342,6 +525,7 @@ async def reset_data():
app.state.started = False
app.state.requesters = []
app.state.envs = []
app.state.buffer = {}
except KeyError:
pass
return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)