Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-25 17:10:42 +00:00)
feat: add minimum batch allocation support for environments
- Add min_batch_allocation parameter to ensure environments contribute a minimum proportion to each batch
- Implement grab_batch_with_minimum_allocations function with proper scaling when allocations exceed 100%
- Add mixed-size group buffering to handle variable-sized data submissions
- Update server to use minimum allocation logic when any env has min_batch_allocation set
- Add comprehensive tests for minimum allocation scenarios
- Update documentation in API README and CONFIG.md
- Update example environments to demonstrate the feature

This feature allows critical environments to guarantee they contribute at least a specified proportion (0.0-1.0) to each training batch, ensuring important data sources are always represented during training.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
parent 4769eeb4a6
commit 08e14cc745

11 changed files with 1670 additions and 91 deletions
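The scaling behavior described in the commit message (minimum allocations that together exceed 100% of a batch) can be pictured with a small standalone sketch. This is only an illustration of the idea, not the actual grab_batch_with_minimum_allocations implementation; the helper name, signature, and dict-based inputs below are assumptions.

# Illustrative sketch only: a hypothetical helper showing how per-environment
# minimum proportions could be turned into row budgets for one batch.
# The name and signature are assumptions, not the atroposlib API.
from typing import Dict


def compute_minimum_rows(
    min_allocations: Dict[str, float], batch_size: int
) -> Dict[str, int]:
    """Convert per-env minimum proportions (0.0-1.0) into row counts.

    If the requested minimums sum to more than 100% of the batch,
    scale them down proportionally so they fit exactly.
    """
    total = sum(min_allocations.values())
    scale = 1.0 if total <= 1.0 else 1.0 / total  # rescale when over-subscribed
    return {
        env: int(batch_size * proportion * scale)
        for env, proportion in min_allocations.items()
    }


# Example: two envs each ask for 60% of a 1024-row batch; both are scaled to 50%.
print(compute_minimum_rows({"math": 0.6, "code": 0.6}, batch_size=1024))
# -> {'math': 512, 'code': 512}

Under this reading, the min_batch_allocation=0.1 set in the example environment config below would reserve roughly 10% of each 1024-row batch (about 102 rows) for that environment, with proportional rescaling only applying when several environments together request more than the whole batch.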
@@ -22,6 +22,13 @@ from atroposlib.envs.base import (
 )
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

+system_prompt = (
+    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
+    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
+    "solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
+    "</think> tags, and then provide your solution or response to the problem."
+)
+
 problem_format = "{problem}"

 judge_format = """Here is a math problem and a proposed solution:
@@ -85,7 +92,7 @@ class RSConfig(BaseEnvConfig):
     )
     percent_to_judge: float = Field(0.3, description="The percentage of items to judge")
     percent_length_penalty: float = Field(
-        0.0, description="The percentage of items to have length penalty"
+        0.1, description="The percentage of items to have length penalty"
     )


@@ -179,21 +186,24 @@ class MathEnv(BaseEnv):
     @classmethod
     def config_init(self) -> Tuple[RSConfig, List[APIServerConfig]]:
         env_config = RSConfig(
-            tokenizer_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
-            group_size=8,
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+            group_size=16,
             use_wandb=True,
             rollout_server_url="http://localhost:8000",
             total_steps=1000,
             batch_size=1024,
+            max_num_workers_per_node=24,
             steps_per_eval=25,
-            max_token_length=31000,  # 22000 // (2 ** i),
+            max_token_length=8192,  # 22000 // (2 ** i),
             wandb_name="math",
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,
+            inference_weight=4,
+            min_batch_allocation=0.1,
         )
         server_configs = [
             APIServerConfig(
-                model_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                 base_url="http://localhost:9004/v1",
                 api_key="x",
                 num_requests_for_eval=256,  # since evaling only on one...
@@ -306,6 +316,7 @@ class MathEnv(BaseEnv):

         completion = await self.server.chat_completion(
             messages=[
+                {"role": "system", "content": system_prompt},
                 {"role": "user", "content": question},
             ],
             n=1,
@@ -352,11 +363,16 @@ class MathEnv(BaseEnv):
         thinking_len = self.config.max_token_length
         user_prompt = problem_format.format(problem=item[0])
         chat = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt},
         ]
         thinking_len = thinking_len - len(
             self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
         )
+        print(f"thinking_len: {thinking_len}", flush=True)
+        if thinking_len < 1024:
+            print("thinking_len is less than 1024, skipping", flush=True)
+            return None, []
         chat_completions = await self.server.chat_completion(
             messages=chat,
             n=self.config.group_size,
@@ -369,6 +385,7 @@ class MathEnv(BaseEnv):
         to_backlog = list()
         for i, chat_completion in enumerate(chat_completions.choices):
             messages = (
+                {"role": "system", "content": system_prompt},
                 {"role": "user", "content": user_prompt},
                 {"role": "assistant", "content": chat_completion.message.content},
             )
@@ -379,8 +396,9 @@ class MathEnv(BaseEnv):
                     chat_completion.finish_reason,
                 )
             )
-
+        print("scoring normal", flush=True)
         to_postprocess = await self.score_normal(to_score)
+        print("scoring normal done", flush=True)
         if to_postprocess is None:
             return None, to_backlog
         if all(
@@ -712,6 +730,7 @@ class MathEnv(BaseEnv):
         )
         print("Sending to server")
         chat = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt_fwd},
         ]
         max_token_length = self.config.max_token_length - len(
@@ -727,6 +746,7 @@ class MathEnv(BaseEnv):
         print("Sending to server")
         # Should be the same token length as the fwd but tokenizers are cursed
         chat = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt_bwd},
         ]
         max_token_length = self.config.max_token_length - len(
@@ -822,6 +842,7 @@ class MathEnv(BaseEnv):
         )
         to_backlog = list()
         chat = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt},
         ]
         max_token_length = self.config.max_token_length - len(
@@ -862,6 +883,7 @@ class MathEnv(BaseEnv):
             out_dict = tokenize_for_trainer(
                 tokenizer=self.tokenizer,
                 chat=[
+                    {"role": "system", "content": system_prompt},
                     {"role": "user", "content": user_prompt},
                     {"role": "assistant", "content": chat_completion.message.content},
                 ],
@@ -902,6 +924,7 @@ class MathEnv(BaseEnv):
         )
         print("Sending to server")
         retry_messages = [
+            {"role": "system", "content": system_prompt},
             {"role": "user", "content": retry_prompt},
         ]
         max_token_length = self.config.max_token_length - len(