This commit is contained in:
Shannon Sands 2025-05-12 07:53:12 +10:00
parent bdcc3cb88f
commit e96970f82e
3 changed files with 84 additions and 323 deletions

View file

@ -5,12 +5,7 @@ import random
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
OpenaiConfig,
ScoredDataGroup,
)
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from .curriculum import MathCurriculum
@ -56,9 +51,7 @@ class InfiniteMathEnvConfig(BaseEnvConfig):
boxed_answer_bonus: float = 0.2
apply_length_penalty: bool = True
length_threshold_ratio: float = (
0.5
)
length_threshold_ratio: float = 0.5
temperature: float = 0.7
top_p: float = 0.9
@ -78,9 +71,7 @@ class InfiniteMathEnv(BaseEnv):
self.config = config
self.percent_correct_buffer = []
self.level_correct_buffer = {
i: [] for i in range(1, 8)
}
self.level_correct_buffer = {i: [] for i in range(1, 8)}
self.eval_metrics = []
self.curriculum = None
@ -416,7 +407,7 @@ class InfiniteMathEnv(BaseEnv):
)
to_score = []
level = None
for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items():
if generator_id in generator_ids:
@ -436,13 +427,13 @@ class InfiniteMathEnv(BaseEnv):
{"role": "user", "content": dict(problem_prompt[0])["content"]},
{"role": "assistant", "content": model_answer},
]
to_score.append((full_messages, solution, generator_id, level))
backlog = []
return to_score, backlog
async def score(self, rollout_group_data) -> ScoredDataGroup:
"""Score the collected trajectories."""
scored_data = ScoredDataGroup()
@ -450,8 +441,10 @@ class InfiniteMathEnv(BaseEnv):
scored_data["masks"] = []
scored_data["scores"] = []
scored_data["messages"] = []
for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data):
for i, (messages, solution, generator_id, level) in enumerate(
rollout_group_data
):
model_answer = messages[-1]["content"]
current_score = 0.0
@ -460,7 +453,7 @@ class InfiniteMathEnv(BaseEnv):
current_score += self.config.correct_reward
else:
current_score += self.config.incorrect_reward
self.percent_correct_buffer.append(1 if is_correct else 0)
if level is not None:
self.level_correct_buffer[level].append(1 if is_correct else 0)
@ -472,26 +465,32 @@ class InfiniteMathEnv(BaseEnv):
if think_content:
current_score += self.config.think_block_bonus
after_think_part = model_answer.split("</think>")[-1].strip() if "</think>" in model_answer else model_answer
after_think_part = (
model_answer.split("</think>")[-1].strip()
if "</think>" in model_answer
else model_answer
)
boxed_answer_content = self._extract_boxed_answer(after_think_part)
if boxed_answer_content is not None:
current_score += self.config.boxed_answer_bonus
logger.info(f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}")
logger.info(
f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}"
)
tokens_dict = tokenize_for_trainer(
self.tokenizer,
messages,
None,
)
scored_data["tokens"].append(tokens_dict["tokens"])
scored_data["masks"].append(tokens_dict["masks"])
scored_data["scores"].append(current_score)
scored_data["messages"].append(messages)
self.curriculum.advance_difficulty()
return scored_data
@classmethod
@ -511,20 +510,19 @@ class InfiniteMathEnv(BaseEnv):
inference_weight=1.0,
wandb_name="infinite_math",
data_path_to_save_groups="data/infinite_math_groups.jsonl",
# InfiniteMathEnvConfig specific fields
starting_level=1,
progress_threshold=0.8,
min_evaluations=10,
max_attempts_per_problem=3, # Default from class, not in old main
correct_reward=1.0, # As in old main
incorrect_reward=-0.5, # As in old main (class default was -1.0)
think_block_bonus=0.2, # As per previous update
boxed_answer_bonus=0.2, # As per previous update
max_attempts_per_problem=3, # Default from class, not in old main
correct_reward=1.0, # As in old main
incorrect_reward=-0.5, # As in old main (class default was -1.0)
think_block_bonus=0.2, # As per previous update
boxed_answer_bonus=0.2, # As per previous update
apply_length_penalty=True, # As in old main
length_threshold_ratio=0.6, # As in old main (class default was 0.5)
temperature=0.7, # As in old main
top_p=0.9 # As in old main
length_threshold_ratio=0.6, # As in old main (class default was 0.5)
temperature=0.7, # As in old main
top_p=0.9, # As in old main
)
server_configs = [