feat: add minimum batch allocation support for environments

- Add min_batch_allocation parameter to ensure environments contribute a minimum proportion to each batch
- Implement grab_batch_with_minimum_allocations function with proper scaling when allocations exceed 100%
- Add mixed-size group buffering to handle variable-sized data submissions (sketched below)
- Update server to use minimum allocation logic when any env has min_batch_allocation set
- Add comprehensive tests for minimum allocation scenarios
- Update documentation in API README and CONFIG.md
- Update example environments to demonstrate the feature

This feature allows critical environments to guarantee they contribute at least a specified proportion (0.0-1.0) to each training batch, ensuring important data sources are always represented during training.
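
To make the allocation rule concrete, here is a minimal sketch of how per-environment minimums could be turned into batch slots, including the proportional scale-down when requested minimums sum past 100%. The helper name and signature below are illustrative only; the actual logic lives in `grab_batch_with_minimum_allocations` and may differ:

```python
# Illustrative sketch only -- compute_min_slots is a hypothetical helper,
# not the grab_batch_with_minimum_allocations implementation from this commit.
from typing import Dict


def compute_min_slots(
    min_allocations: Dict[str, float], batch_size: int
) -> Dict[str, int]:
    """Convert per-env minimum proportions (0.0-1.0) into guaranteed slots.

    When the requested minimums sum past 1.0, scale every minimum down
    proportionally so the guarantees still fit in one batch.
    """
    total = sum(min_allocations.values())
    scale = 1.0 if total <= 1.0 else 1.0 / total
    return {
        env: int(batch_size * proportion * scale)
        for env, proportion in min_allocations.items()
    }


# Two envs asking for 60% each of a 1024-item batch are scaled to 50% each:
print(compute_min_slots({"math": 0.6, "code": 0.6}, 1024))
# {'math': 512, 'code': 512}
```

Slots not covered by a minimum would presumably be filled by the normal first-come batching path.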

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
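
The mixed-size group buffering mentioned in the bullets above can be pictured along these lines; this is an assumed sketch, not the API this commit adds:

```python
# Assumed sketch of mixed-size group buffering -- class and method names
# are hypothetical, not the interface introduced by this commit.
from collections import defaultdict
from typing import Dict, List


class GroupBuffer:
    """Accumulate variable-sized submissions per environment and emit
    fixed-size groups once enough items have arrived."""

    def __init__(self, group_size: int) -> None:
        self.group_size = group_size
        self.pending: Dict[str, List[dict]] = defaultdict(list)

    def submit(self, env_id: str, items: List[dict]) -> List[List[dict]]:
        """Buffer items for env_id; return any complete groups."""
        self.pending[env_id].extend(items)
        groups: List[List[dict]] = []
        while len(self.pending[env_id]) >= self.group_size:
            groups.append(self.pending[env_id][: self.group_size])
            self.pending[env_id] = self.pending[env_id][self.group_size:]
        return groups


# A submission of 6 items with group_size=4 yields one group; 2 stay buffered.
buf = GroupBuffer(group_size=4)
assert len(buf.submit("math", [{"tokens": [i]} for i in range(6)])) == 1
assert len(buf.pending["math"]) == 2
```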
Author: Dakota
Date:   2025-07-07 08:50:28 -05:00
Commit: 08e14cc745
Parent: 4769eeb4a6

11 changed files with 1670 additions and 91 deletions

@@ -22,6 +22,13 @@ from atroposlib.envs.base import (
 )
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
+system_prompt = (
+    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
+    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
+    "solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
+    "</think> tags, and then provide your solution or response to the problem."
+)
+problem_format = "{problem}"
 
 judge_format = """Here is a math problem and a proposed solution:
@@ -85,7 +92,7 @@ class RSConfig(BaseEnvConfig):
     )
     percent_to_judge: float = Field(0.3, description="The percentage of items to judge")
     percent_length_penalty: float = Field(
-        0.0, description="The percentage of items to have length penalty"
+        0.1, description="The percentage of items to have length penalty"
     )
@@ -179,21 +186,24 @@ class MathEnv(BaseEnv):
     @classmethod
     def config_init(self) -> Tuple[RSConfig, List[APIServerConfig]]:
         env_config = RSConfig(
-            tokenizer_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
-            group_size=8,
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+            group_size=16,
             use_wandb=True,
             rollout_server_url="http://localhost:8000",
             total_steps=1000,
             batch_size=1024,
+            max_num_workers_per_node=24,
             steps_per_eval=25,
-            max_token_length=31000, # 22000 // (2 ** i),
+            max_token_length=8192, # 22000 // (2 ** i),
             wandb_name="math",
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,
+            inference_weight=4,
+            min_batch_allocation=0.1,
         )
         server_configs = [
             APIServerConfig(
-                model_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                 base_url="http://localhost:9004/v1",
                 api_key="x",
                 num_requests_for_eval=256, # since evaling only on one...
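
Worth noting about the config hunk above: with batch_size=1024 and min_batch_allocation=0.1, this math env is guaranteed roughly 102 of the 1024 sequences in every training batch, subject to the proportional scale-down when several envs' minimums sum past 100%.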
@@ -306,6 +316,7 @@ class MathEnv(BaseEnv):
         completion = await self.server.chat_completion(
             messages=[
+                {"role": "system", "content": system_prompt},
                 {"role": "user", "content": question},
             ],
             n=1,
@@ -352,11 +363,16 @@ class MathEnv(BaseEnv):
         thinking_len = self.config.max_token_length
         user_prompt = problem_format.format(problem=item[0])
         chat = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt},
         ]
         thinking_len = thinking_len - len(
             self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
         )
+        print(f"thinking_len: {thinking_len}", flush=True)
+        if thinking_len < 1024:
+            print("thinking_len is less than 1024, skipping", flush=True)
+            return None, []
         chat_completions = await self.server.chat_completion(
             messages=chat,
             n=self.config.group_size,
@@ -369,6 +385,7 @@ class MathEnv(BaseEnv):
         to_backlog = list()
         for i, chat_completion in enumerate(chat_completions.choices):
             messages = (
+                {"role": "system", "content": system_prompt},
                 {"role": "user", "content": user_prompt},
                 {"role": "assistant", "content": chat_completion.message.content},
             )
@@ -379,8 +396,9 @@ class MathEnv(BaseEnv):
                     chat_completion.finish_reason,
                 )
             )
+        print("scoring normal", flush=True)
         to_postprocess = await self.score_normal(to_score)
+        print("scoring normal done", flush=True)
         if to_postprocess is None:
             return None, to_backlog
         if all(
@@ -712,6 +730,7 @@ class MathEnv(BaseEnv):
         )
         print("Sending to server")
         chat = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt_fwd},
         ]
         max_token_length = self.config.max_token_length - len(
@@ -727,6 +746,7 @@ class MathEnv(BaseEnv):
         print("Sending to server")
         # Should be the same token length as the fwd but tokenizers are cursed
         chat = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt_bwd},
         ]
         max_token_length = self.config.max_token_length - len(
@@ -822,6 +842,7 @@ class MathEnv(BaseEnv):
         )
         to_backlog = list()
         chat = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt},
         ]
         max_token_length = self.config.max_token_length - len(
@@ -862,6 +883,7 @@ class MathEnv(BaseEnv):
         out_dict = tokenize_for_trainer(
             tokenizer=self.tokenizer,
             chat=[
+                {"role": "system", "content": system_prompt},
                 {"role": "user", "content": user_prompt},
                 {"role": "assistant", "content": chat_completion.message.content},
             ],
@@ -902,6 +924,7 @@ class MathEnv(BaseEnv):
         )
         print("Sending to server")
         retry_messages = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": retry_prompt},
         ]
         max_token_length = self.config.max_token_length - len(