import asyncio
import json
import logging
import os
import random
import re
from typing import Dict, List, Optional, Tuple, Union

from dotenv import load_dotenv
from openai import AsyncOpenAI

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

from .curriculum import MathCurriculum

load_dotenv()

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

system_prompt = (
    "You are an expert mathematician that can use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering.\n"
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.\n\n"
    "The problems will be given in a LaTeX format, so be sure to follow the LaTeX "
    "syntax when writing your answer (although no $ delimiters are necessary).\n\n"
    "Follow these steps:\n"
    "1. Understand the problem carefully\n"
    "2. Plan your approach\n"
    "3. Execute the calculations step-by-step\n"
    "4. Verify your solution\n"
    "5. Express the final answer as \\boxed{your answer here}\n\n"
    "You may use extremely long chains of thought to deeply consider the problem "
    "and deliberate with yourself via systematic reasoning processes to help come "
    "to a correct solution prior to answering.\n\n"
    "Your answer format should be:\n"
    "<think>\n"
    "[Your detailed step-by-step reasoning process here]\n"
    "</think>\n\n"
    "\\boxed{your final answer here}\n\n"
    "Remember to format your final answer correctly as this is important for evaluation. "
    "Do not apply any rounding to your final answer, be as exact as possible."
)


class InfiniteMathEnvConfig(BaseEnvConfig):
    """Configuration for the InfiniteMath environment."""

    # Curriculum progression
    starting_level: int = 1
    progress_threshold: float = 0.8
    min_evaluations: int = 5

    # Reward shaping
    max_attempts_per_problem: int = 3
    correct_reward: float = 1.0
    incorrect_reward: float = -1.0
    think_block_bonus: float = 0.2
    boxed_answer_bonus: float = 0.2

    # Length penalty
    apply_length_penalty: bool = True
    length_threshold_ratio: float = 0.5

    # Sampling parameters for rollouts
    temperature: float = 0.7
    top_p: float = 0.9

    # LLM used to rewrite raw expressions as word problems
    word_problem_model_name: Optional[str] = "gpt-4.1-mini"
    word_problem_openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
    word_problem_openai_base_url: Optional[str] = None


class InfiniteMathEnv(BaseEnv):
    """Environment for procedurally generated math problems with curriculum advancement."""

    def __init__(
        self,
        config: InfiniteMathEnvConfig,
        server_configs: Union[List[APIServerConfig], APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config = config

        # Rolling buffers for wandb metrics; levels 1-7 mirror MathCurriculum.
        self.percent_correct_buffer = []
        self.level_correct_buffer = {i: [] for i in range(1, 8)}
        self.eval_metrics = []

        # Created in setup().
        self.curriculum = None

        self.system_prompt = system_prompt

    async def setup(self):
        """Initialize the environment and curriculum."""
        logger.info("Setting up InfiniteMathEnv")

        self.curriculum = MathCurriculum(
            starting_level=self.config.starting_level,
            progress_threshold=self.config.progress_threshold,
            min_evaluations=self.config.min_evaluations,
        )
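
        # Pre-generate a fixed pool of evaluation problems (up to 10 per level)
        # so evaluate() measures progress against a stable problem set rather
        # than freshly sampled ones.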
        self.eval_problems = {}
        for level in range(1, 8):
            self.eval_problems[level] = []
            temp_curriculum = MathCurriculum(starting_level=level)
            attempts = 0
            max_attempts_per_level = 20

            while (
                len(self.eval_problems[level]) < 10
                and attempts < max_attempts_per_level
            ):
                try:
                    problem, solution, generator_id = temp_curriculum.get_problem()
                    problem = self._strip_latex_delimiters(problem)
                    solution = self._strip_latex_delimiters(solution)
                    self.eval_problems[level].append((problem, solution, generator_id))
                except Exception as e:
                    logger.warning(
                        f"Error generating evaluation problem for level {level}: {e}"
                    )
                attempts += 1

            logger.info(
                f"Generated {len(self.eval_problems[level])} evaluation problems for level {level}"
            )

        # Hard-coded fallbacks of (problem, solution, generator_id), used only
        # if generation failed for an entire level.
        for level in range(1, 8):
            if not self.eval_problems[level]:
                logger.warning(
                    f"No valid evaluation problems for level {level}, adding fallback"
                )
                if level == 1:
                    self.eval_problems[level].append(("What is 2 + 3?", "5", 0))
                elif level == 2:
                    self.eval_problems[level].append(
                        ("What is the square root of 16?", "4", 6)
                    )
                elif level == 3:
                    self.eval_problems[level].append(
                        (
                            "What is the area of a triangle with base 6 and height 8?",
                            "24",
                            18,
                        )
                    )
                elif level == 4:
                    self.eval_problems[level].append(
                        ("What is the solution to x + 5 = 12?", "7", 26)
                    )
                elif level == 5:
                    self.eval_problems[level].append(
                        ("What is the volume of a cube with side length 3?", "27", 33)
                    )
                elif level == 6:
                    self.eval_problems[level].append(
                        ("What is 5 factorial?", "120", 31)
                    )
                else:
                    self.eval_problems[level].append(("What is |3 - 10|?", "7", 71))

    def _strip_latex_delimiters(self, text: str) -> str:
        """Strip LaTeX delimiters ($...$) from text."""
        return re.sub(r"\$(.*?)\$", r"\1", text)

    def save_checkpoint(self, step, data=None):
        """Save curriculum state in checkpoint."""
        if data is None:
            data = {}

        data["curriculum_level"] = self.curriculum.get_current_level()
        # JSON object keys must be strings; load_checkpoint casts them back to int.
        data["performance_history"] = {
            str(k): v for k, v in self.curriculum.performance_history.items()
        }

        super().save_checkpoint(step, data)

    def load_checkpoint(self):
        """Load curriculum state from checkpoint."""
        super().load_checkpoint()

        checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json"
        try:
            with open(checkpoint_path, "r") as f:
                data = json.load(f)

            if "curriculum_level" in data:
                level = data["curriculum_level"]
                self.curriculum.current_level = level

            if "performance_history" in data:
                self.curriculum.performance_history = {
                    int(k): v for k, v in data["performance_history"].items()
                }
        except (FileNotFoundError, json.JSONDecodeError) as e:
            logger.warning(f"Failed to load checkpoint: {e}")

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # max(1, ...) guards against an empty buffer, so no division-by-zero
        # handling is needed.
        wandb_metrics["train/percent_correct"] = sum(
            self.percent_correct_buffer
        ) / max(1, len(self.percent_correct_buffer))

        for level, buffer in self.level_correct_buffer.items():
            if buffer:
                wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len(
                    buffer
                )
                wandb_metrics[f"train/level_{level}_count"] = len(buffer)

        if self.curriculum:
            current_level = self.curriculum.get_current_level()
            max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys())

            wandb_metrics["curriculum/current_level"] = current_level
            wandb_metrics["curriculum/max_level"] = max_level
            wandb_metrics["curriculum/progress_percent"] = (
                current_level / max_level
            ) * 100

            wandb_metrics["curriculum/level_description"] = (
                self.curriculum.get_level_description()
            )

            if current_level in self.curriculum.performance_history:
                history = self.curriculum.performance_history[current_level]
                if history:
                    recent_history = history[
                        -min(len(history), self.curriculum.min_evaluations) :
                    ]
                    if recent_history:
                        success_rate = sum(recent_history) / len(recent_history)
                        wandb_metrics["curriculum/current_level_success_rate"] = (
                            success_rate
                        )
                        wandb_metrics["curriculum/threshold_to_advance"] = (
                            self.curriculum.progress_threshold
                        )
                        wandb_metrics["curriculum/remaining_to_threshold"] = max(
                            0, self.curriculum.progress_threshold - success_rate
                        )

        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]

        # Reset the per-interval buffers after logging.
        self.percent_correct_buffer = []
        for level in self.level_correct_buffer:
            self.level_correct_buffer[level] = []
        self.eval_metrics = []

        await super().wandb_log(wandb_metrics)

    async def _convert_to_word_problem(self, raw_problem_text: str) -> str:
        """Converts a raw math problem string into a word problem using an LLM."""
        system_prompt_word_problem = (
            "You are an expert creative writer. Your task is to transform a given raw "
            "mathematical expression into an engaging and imaginative word problem.\n\n"
            "**Critical Instructions:**\n"
            "1. **Strict Preservation:** The core mathematical question, ALL numbers, and ALL operations "
            "from the raw problem MUST be EXACTLY preserved in the word problem. Do NOT change the calculation "
            "required. For example, if the raw problem is 'A - B', the word problem must represent subtraction "
            "of B from A, not any other operation.\n"
            "2. **Clarity:** The word problem must clearly and unambiguously lead to solving the original "
            "mathematical expression.\n"
            "3. **Conciseness:** Keep the word problem relatively short and to the point.\n"
            "4. **Output Format:** Output ONLY the word problem text. Do NOT include any preambles, "
            "self-references (like 'Here is a word problem:'), special tokens (like '<|start_header_id|>'), "
            "or any text other than the word problem itself.\n\n"
            "**Examples of Correct Transformation:**\n"
            "Raw Problem: 5 * 3\n"
            "Word Problem: Sarah is baking cookies, and each batch requires 3 eggs. If Sarah wants to bake 5 batches, "
            "how many eggs will she need in total?\n\n"
            "Raw Problem: |10 - 15|\n"
            "Word Problem: A submarine is 10 meters below sea level. Another submarine is 15 meters below sea level. "
            "What is the absolute difference in their depths in meters?\n\n"
            "Raw Problem: sqrt(16)\n"
            "Word Problem: A square piece of land has an area of 16 square units. What is the length of one of its "
            "sides in units?\n\n"
            "**Example of Incorrect Transformation (Operation Changed):**\n"
            "Raw Problem: |3 - (-67)| (This is 3 + 67)\n"
            "Incorrect Word Problem: In a magical forest, there are 3 enchanted trees, "
            "and each tree has 67 glowing fruits. "
            "How many glowing fruits are there in total? (This became 3 * 67)\n"
            "Correct Word Problem: A bird watcher is 3 meters up a tree. "
            "She spots a rare bird 67 meters below ground level "
            "in a cave. What is the total vertical distance between the bird watcher and the rare bird in meters?"
        )

        messages = [
            {"role": "system", "content": system_prompt_word_problem},
            {"role": "user", "content": f"Raw Problem: {raw_problem_text}"},
        ]
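
        # Every failure path below falls back to returning the raw expression,
        # so training never stalls on word-problem generation.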
        try:
            api_key_to_use = self.config.word_problem_openai_api_key or os.environ.get(
                "OPENAI_API_KEY"
            )
            base_url_to_use = self.config.word_problem_openai_base_url
            model_to_use = self.config.word_problem_model_name or "gpt-4.1-mini"

            if not api_key_to_use:
                logger.error(
                    "OpenAI API key for word problem generation is not configured "
                    "(checked config and OPENAI_API_KEY env var)."
                )
                return raw_problem_text

            client = AsyncOpenAI(
                api_key=api_key_to_use,
                base_url=base_url_to_use,
            )

            chat_completions = await client.chat.completions.create(
                model=model_to_use,
                messages=messages,
                n=1,
                max_tokens=512,
                temperature=0.7,
                top_p=1.0,
            )

            generated_text = chat_completions.choices[0].message.content
            # content may be None, so guard before stripping.
            cleaned_text = (generated_text or "").strip()

            if not cleaned_text:
                logger.warning(
                    f"Word problem conversion for '{raw_problem_text}' resulted in an empty string after stripping. "
                    f"Original LLM output was: '{generated_text}'. Falling back to raw problem."
                )
                return raw_problem_text

            logger.info(
                f"Converted raw problem '{raw_problem_text}' to word problem: '{cleaned_text}'"
            )
            return cleaned_text
        except Exception as e:
            logger.error(
                f"Error converting to word problem for '{raw_problem_text}': {e}"
            )
            return raw_problem_text

    async def get_next_item(self):
        """Get the next problem based on current curriculum level."""
        raw_problem, solution, generator_id = self.curriculum.get_problem()

        raw_problem_stripped = self._strip_latex_delimiters(raw_problem)
        solution_stripped = self._strip_latex_delimiters(solution)

        word_problem_text = await self._convert_to_word_problem(raw_problem_stripped)
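
        # Package the prompt as a tuple of frozensets so the item stays
        # hashable; collect_trajectories converts it back with
        # dict(problem_prompt[0]).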
        prompt = tuple(
            [frozenset({"role": "user", "content": word_problem_text}.items())]
        )

        return (prompt, solution_stripped, generator_id)

    async def evaluate(self, *args, **kwargs):
        """Evaluate the model on test problems at the current curriculum level."""
        current_level = self.curriculum.get_current_level()
        logger.info(f"Starting evaluation for curriculum level {current_level}")

        eval_tasks = []
        eval_generator_ids = []
        if current_level in self.eval_problems:
            for problem, solution, generator_id in self.eval_problems[current_level]:
                eval_tasks.append(
                    self.evaluate_single_problem(problem, solution, current_level)
                )
                eval_generator_ids.append(generator_id)

        if not eval_tasks:
            logger.warning(
                f"No evaluation problems available for level {current_level}"
            )
            return []

        logger.info(f"Evaluating {len(eval_tasks)} problems at level {current_level}")
        results = await asyncio.gather(*eval_tasks)

        correct_count = sum(1 for _, is_correct in results if is_correct)
        total_count = len(results)
        accuracy = correct_count / total_count if total_count > 0 else 0

        logger.info(
            f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})"
        )

        self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy))
        self.eval_metrics.append(("eval/current_level", current_level))
        for i, (_, is_correct) in enumerate(results):
            if i < len(eval_generator_ids):
                self.curriculum.record_performance(eval_generator_ids[i], is_correct)
            else:
                sample_generator_id = random.choice(
                    self.curriculum.DIFFICULTY_LEVELS[current_level]
                )
                self.curriculum.record_performance(sample_generator_id, is_correct)

        advanced = self.curriculum.advance_difficulty()
        new_level = self.curriculum.get_current_level()

        if advanced:
            logger.info(f"Advanced from level {current_level} to level {new_level}!")
            self.eval_metrics.append(("eval/advanced_level", 1))
        else:
            logger.info(f"Remaining at level {current_level}")
            self.eval_metrics.append(("eval/advanced_level", 0))

        return self.eval_metrics

    async def evaluate_single_problem(
        self, problem: str, solution: str, level: int
    ) -> Tuple[int, bool]:
        """Evaluate a single problem."""
        try:
            word_problem_text = await self._convert_to_word_problem(problem)
            logger.debug(
                f"Evaluating level {level} word problem: {word_problem_text[:50]}... "
                f"(Original raw: {problem[:30]}...)"
            )

            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": word_problem_text},
            ]
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
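
            # Pre-fill an opening <think> tag so the completion starts inside
            # the reasoning block the system prompt asks for.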
            prefill = "\n<think>\n"
            prefilled_prompt = prompt + prefill

            logger.debug(f"Requesting completion for problem: {problem[:30]}...")
            completion = await self.server.completion(
                prompt=prefilled_prompt,
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.0,
                top_p=1.0,
                split="eval",
            )

            # Handle both completion-style and chat-style response objects.
            model_answer = prefill + (
                completion.choices[0].text
                if hasattr(completion.choices[0], "text")
                else completion.choices[0].message.content
            )

            is_correct = self.check_answer(model_answer, solution)
            logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")

            return level, is_correct
        except Exception as e:
            logger.error(f"Error evaluating problem: {e}")
            return level, False

    def check_answer(self, model_answer: str, solution: str) -> bool:
        """Check if the model's answer matches the solution."""
        after_think_part = (
            model_answer.split("</think>")[-1].strip()
            if "</think>" in model_answer
            else model_answer
        )
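
        # Prefer the \boxed{} answer; if the model skipped the box, fall back
        # to the last line of the post-</think> text.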
        boxed_answer = self._extract_boxed_answer(after_think_part)
        if not boxed_answer:
            lines = after_think_part.strip().split("\n")
            if lines:
                boxed_answer = lines[-1].strip()

        model_clean = self._clean_for_comparison(
            boxed_answer if boxed_answer else after_think_part
        )
        solution_clean = self._clean_for_comparison(solution)

        return model_clean == solution_clean

    def _extract_boxed_answer(self, text: str) -> Optional[str]:
        """Extract answer from a LaTeX boxed expression."""
        boxed_match = re.search(r"\\boxed{([^}]*)}", text)
        if boxed_match:
            return boxed_match.group(1)
        return None

    def _clean_for_comparison(self, text: str) -> str:
        """Clean text for comparison."""
        cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
        cleaned = re.sub(r"[,\s]", "", cleaned)
        cleaned = cleaned.lower()
        return cleaned

    async def collect_trajectories(self, item) -> Tuple[List, List]:
        """Collect trajectories for the current item."""
        problem_prompt, solution, generator_id = item
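
        # Seed the assistant turn with an opening <think> tag so every sampled
        # completion starts inside the reasoning block.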
        prefill = "\n<think>\n"
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": dict(problem_prompt[0])["content"]},
            {"role": "assistant", "content": prefill},
        ]

        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        completions = await self.server.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        to_score = []

        # Map the generator id back to its difficulty level for per-level metrics.
        level = None
        for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items():
            if generator_id in generator_ids:
                level = lvl
                break

        for completion in completions.choices:
            model_answer = prefill + (
                completion.text
                if hasattr(completion, "text")
                else completion.message.content
            )
            logger.debug(f"model_answer: {model_answer}")

            full_messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": dict(problem_prompt[0])["content"]},
                {"role": "assistant", "content": model_answer},
            ]

            to_score.append((full_messages, solution, generator_id, level))

        backlog = []

        return to_score, backlog

    async def score(self, rollout_group_data) -> ScoredDataGroup:
        """Score the collected trajectories."""
        scored_data = ScoredDataGroup()
        scored_data["tokens"] = []
        scored_data["masks"] = []
        scored_data["scores"] = []
        scored_data["messages"] = []
        for i, (messages, solution, generator_id, level) in enumerate(
            rollout_group_data
        ):
            model_answer = messages[-1]["content"]
            current_score = 0.0

            is_correct = self.check_answer(model_answer, solution)
            if is_correct:
                current_score += self.config.correct_reward
            else:
                current_score += self.config.incorrect_reward

            self.percent_correct_buffer.append(1 if is_correct else 0)
            if level is not None:
                self.level_correct_buffer[level].append(1 if is_correct else 0)
            self.curriculum.record_performance(generator_id, is_correct)

            think_match = re.search(r"<think>(.*?)</think>", model_answer, re.DOTALL)
            think_bonus_awarded = bool(think_match and think_match.group(1).strip())
            if think_bonus_awarded:
                current_score += self.config.think_block_bonus

            after_think_part = (
                model_answer.split("</think>")[-1].strip()
                if "</think>" in model_answer
                else model_answer
            )
            boxed_answer_content = self._extract_boxed_answer(after_think_part)
            if boxed_answer_content is not None:
                current_score += self.config.boxed_answer_bonus

            logger.info(
                f"Item {i}: Correct: {is_correct}, "
                f"Think Bonus: {self.config.think_block_bonus if think_bonus_awarded else 0}, "
                f"Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, "
                f"Final Score: {current_score}"
            )

            tokens_dict = tokenize_for_trainer(
                self.tokenizer,
                messages,
                None,
            )

            scored_data["tokens"].append(tokens_dict["tokens"])
            scored_data["masks"].append(tokens_dict["masks"])
            scored_data["scores"].append(current_score)
            scored_data["messages"].append(messages)

        self.curriculum.advance_difficulty()

        return scored_data

    @classmethod
    def config_init(cls) -> Tuple[InfiniteMathEnvConfig, List[APIServerConfig]]:
        """Initialize environment and OpenAI configurations with default values."""
        env_config = InfiniteMathEnvConfig(
            tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
            group_size=8,
            use_wandb=True,
            max_num_workers=64,
            rollout_server_url="http://localhost:8000",
            total_steps=10000,
            batch_size=1024,
            steps_per_eval=25,
            max_token_length=4096,
            inference_weight=1.0,
            wandb_name="infinimath",
            data_path_to_save_groups="data/infinite_math_groups.jsonl",
            starting_level=1,
            progress_threshold=0.8,
            min_evaluations=10,
            max_attempts_per_problem=3,
            correct_reward=1.0,
            incorrect_reward=-0.5,
            think_block_bonus=0.2,
            boxed_answer_bonus=0.2,
            apply_length_penalty=True,
            length_threshold_ratio=0.6,
            temperature=0.7,
            top_p=0.9,
            word_problem_model_name="gpt-4.1-mini",
            word_problem_openai_api_key=None,
            word_problem_openai_base_url=None,
        )
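
        # A single local OpenAI-compatible server; the "x" api_key is a
        # placeholder for endpoints that don't enforce authentication.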
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/Nous-Hermes-2-Yi-34B",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_requests_for_eval=64,
            )
        ]
        return env_config, server_configs


if __name__ == "__main__":
    InfiniteMathEnv.cli()