mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
linting
This commit is contained in:
parent
bdcc3cb88f
commit
e96970f82e
3 changed files with 84 additions and 323 deletions
|
|
@ -5,12 +5,7 @@ import random
|
|||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
OpenaiConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
from .curriculum import MathCurriculum
|
||||
|
|
@ -56,9 +51,7 @@ class InfiniteMathEnvConfig(BaseEnvConfig):
|
|||
boxed_answer_bonus: float = 0.2
|
||||
|
||||
apply_length_penalty: bool = True
|
||||
length_threshold_ratio: float = (
|
||||
0.5
|
||||
)
|
||||
length_threshold_ratio: float = 0.5
|
||||
|
||||
temperature: float = 0.7
|
||||
top_p: float = 0.9
|
||||
|
|
@ -78,9 +71,7 @@ class InfiniteMathEnv(BaseEnv):
|
|||
self.config = config
|
||||
|
||||
self.percent_correct_buffer = []
|
||||
self.level_correct_buffer = {
|
||||
i: [] for i in range(1, 8)
|
||||
}
|
||||
self.level_correct_buffer = {i: [] for i in range(1, 8)}
|
||||
self.eval_metrics = []
|
||||
|
||||
self.curriculum = None
|
||||
|
|
@ -416,7 +407,7 @@ class InfiniteMathEnv(BaseEnv):
|
|||
)
|
||||
|
||||
to_score = []
|
||||
|
||||
|
||||
level = None
|
||||
for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items():
|
||||
if generator_id in generator_ids:
|
||||
|
|
@ -436,13 +427,13 @@ class InfiniteMathEnv(BaseEnv):
|
|||
{"role": "user", "content": dict(problem_prompt[0])["content"]},
|
||||
{"role": "assistant", "content": model_answer},
|
||||
]
|
||||
|
||||
|
||||
to_score.append((full_messages, solution, generator_id, level))
|
||||
|
||||
backlog = []
|
||||
|
||||
return to_score, backlog
|
||||
|
||||
|
||||
async def score(self, rollout_group_data) -> ScoredDataGroup:
|
||||
"""Score the collected trajectories."""
|
||||
scored_data = ScoredDataGroup()
|
||||
|
|
@ -450,8 +441,10 @@ class InfiniteMathEnv(BaseEnv):
|
|||
scored_data["masks"] = []
|
||||
scored_data["scores"] = []
|
||||
scored_data["messages"] = []
|
||||
|
||||
for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data):
|
||||
|
||||
for i, (messages, solution, generator_id, level) in enumerate(
|
||||
rollout_group_data
|
||||
):
|
||||
model_answer = messages[-1]["content"]
|
||||
current_score = 0.0
|
||||
|
||||
|
|
@ -460,7 +453,7 @@ class InfiniteMathEnv(BaseEnv):
|
|||
current_score += self.config.correct_reward
|
||||
else:
|
||||
current_score += self.config.incorrect_reward
|
||||
|
||||
|
||||
self.percent_correct_buffer.append(1 if is_correct else 0)
|
||||
if level is not None:
|
||||
self.level_correct_buffer[level].append(1 if is_correct else 0)
|
||||
|
|
@ -472,26 +465,32 @@ class InfiniteMathEnv(BaseEnv):
|
|||
if think_content:
|
||||
current_score += self.config.think_block_bonus
|
||||
|
||||
after_think_part = model_answer.split("</think>")[-1].strip() if "</think>" in model_answer else model_answer
|
||||
after_think_part = (
|
||||
model_answer.split("</think>")[-1].strip()
|
||||
if "</think>" in model_answer
|
||||
else model_answer
|
||||
)
|
||||
boxed_answer_content = self._extract_boxed_answer(after_think_part)
|
||||
if boxed_answer_content is not None:
|
||||
current_score += self.config.boxed_answer_bonus
|
||||
|
||||
logger.info(f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}")
|
||||
|
||||
logger.info(
|
||||
f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}"
|
||||
)
|
||||
|
||||
tokens_dict = tokenize_for_trainer(
|
||||
self.tokenizer,
|
||||
messages,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
scored_data["tokens"].append(tokens_dict["tokens"])
|
||||
scored_data["masks"].append(tokens_dict["masks"])
|
||||
scored_data["scores"].append(current_score)
|
||||
scored_data["messages"].append(messages)
|
||||
|
||||
|
||||
self.curriculum.advance_difficulty()
|
||||
|
||||
|
||||
return scored_data
|
||||
|
||||
@classmethod
|
||||
|
|
@ -511,20 +510,19 @@ class InfiniteMathEnv(BaseEnv):
|
|||
inference_weight=1.0,
|
||||
wandb_name="infinite_math",
|
||||
data_path_to_save_groups="data/infinite_math_groups.jsonl",
|
||||
|
||||
# InfiniteMathEnvConfig specific fields
|
||||
starting_level=1,
|
||||
progress_threshold=0.8,
|
||||
min_evaluations=10,
|
||||
max_attempts_per_problem=3, # Default from class, not in old main
|
||||
correct_reward=1.0, # As in old main
|
||||
incorrect_reward=-0.5, # As in old main (class default was -1.0)
|
||||
think_block_bonus=0.2, # As per previous update
|
||||
boxed_answer_bonus=0.2, # As per previous update
|
||||
max_attempts_per_problem=3, # Default from class, not in old main
|
||||
correct_reward=1.0, # As in old main
|
||||
incorrect_reward=-0.5, # As in old main (class default was -1.0)
|
||||
think_block_bonus=0.2, # As per previous update
|
||||
boxed_answer_bonus=0.2, # As per previous update
|
||||
apply_length_penalty=True, # As in old main
|
||||
length_threshold_ratio=0.6, # As in old main (class default was 0.5)
|
||||
temperature=0.7, # As in old main
|
||||
top_p=0.9 # As in old main
|
||||
length_threshold_ratio=0.6, # As in old main (class default was 0.5)
|
||||
temperature=0.7, # As in old main
|
||||
top_p=0.9, # As in old main
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue