mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
667 lines
27 KiB
Python
667 lines
27 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
from dotenv import load_dotenv
|
|
from openai import AsyncOpenAI, NotGiven
|
|
|
|
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
from .curriculum import MathCurriculum
|
|
|
|
# Populate os.environ from a .env file (if one is found) at import time.
load_dotenv()

# Module-level logger; its own threshold is forced down to DEBUG.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
|
|
|
|
system_prompt = """You are an expert mathematician that can use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
|
|
You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.
|
|
|
|
The problems will be given in a LaTeX format, so be sure to follow the LaTeX syntax when writing your answer (although no $ delimiters are necessary).
|
|
|
|
Follow these steps:
|
|
1. Understand the problem carefully
|
|
2. Plan your approach
|
|
3. Execute the calculations step-by-step
|
|
4. Verify your solution
|
|
5. Express the final answer as \\boxed{your answer here}
|
|
|
|
You may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
|
|
|
|
Your answer format should be:
|
|
<think>
|
|
[Your detailed step-by-step reasoning process here]
|
|
</think>
|
|
|
|
\\boxed{your final answer here}
|
|
|
|
Remember to format your final answer correctly as this is important for evaluation. Do not apply any rounding to your final answer, be as exact as possible."""
|
|
|
|
|
|
class InfiniteMathEnvConfig(BaseEnvConfig):
    """Configuration for the InfiniteMath environment."""

    # Curriculum settings, forwarded to MathCurriculum when the environment is set up.
    starting_level: int = 1
    progress_threshold: float = 0.8
    min_evaluations: int = 5

    # NOTE(review): max_attempts_per_problem is not read anywhere in this module.
    max_attempts_per_problem: int = 3
    # Reward shaping used when scoring rollouts: base reward for a correct/incorrect
    # answer plus bonuses for a non-empty <think> block and a \boxed{} answer.
    correct_reward: float = 1.0
    incorrect_reward: float = -1.0
    think_block_bonus: float = 0.2
    boxed_answer_bonus: float = 0.2

    # NOTE(review): neither length-penalty option is read anywhere in this module.
    apply_length_penalty: bool = True
    length_threshold_ratio: float = 0.5

    # Sampling parameters for training rollouts.
    temperature: float = 0.7
    top_p: float = 0.9

    # Model for word problem generation
    word_problem_model_name: Optional[str] = "gpt-4.1-mini"
    # None means the OPENAI_API_KEY environment variable is read when a request
    # is made, instead of snapshotting the secret into the class default at import.
    word_problem_openai_api_key: Optional[str] = None
    word_problem_openai_base_url: Optional[str] = None  # Add for custom server
|
|
|
|
|
|
class InfiniteMathEnv(BaseEnv):
    """Environment for procedurally generated math problems with curriculum advancement.

    Items are drawn from a ``MathCurriculum`` at its current level and rewritten
    into word problems by an auxiliary OpenAI-compatible model. The policy is
    prompted to reason inside ``<think> ... </think>`` and to put its final answer
    in ``\\boxed{...}``; answers are compared with the curriculum's solution after
    light normalisation (see ``check_answer``).
    """

    # Assistant-turn prefill that starts the model inside its reasoning block.
    _THINK_PREFILL = "\n<think>\n"

    # Evaluation-set generation: target problems per level and attempt budget.
    _EVAL_PROBLEMS_PER_LEVEL = 10
    _EVAL_MAX_ATTEMPTS_PER_LEVEL = 20

    # (problem, solution, generator_id) used only when the curriculum produces no
    # evaluation problem at all for a level.
    _FALLBACK_EVAL_PROBLEMS = {
        1: ("What is 2 + 3?", "5", 0),
        2: ("What is the square root of 16?", "4", 6),
        3: ("What is the area of a triangle with base 6 and height 8?", "24", 18),
        4: ("What is the solution to x + 5 = 12?", "7", 26),
        5: ("What is the volume of a cube with side length 3?", "27", 33),
        6: ("What is 5 factorial?", "120", 31),
        7: ("What is |3 - 10|?", "7", 71),
    }

    def __init__(
        self,
        config: InfiniteMathEnvConfig,
        server_configs: Union[List[OpenaiConfig], OpenaiConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config = config

        # Training accuracy buffers (1 = correct, 0 = incorrect), flushed by wandb_log.
        self.percent_correct_buffer = []
        self.level_correct_buffer = {i: [] for i in range(1, 8)}
        # (metric_name, value) pairs produced by evaluate(), flushed by wandb_log.
        self.eval_metrics = []

        # Both are filled in by setup(); eval_problems starts empty so evaluate()
        # reports "no problems" instead of raising AttributeError before setup.
        self.curriculum = None
        self.eval_problems = {}

        # Lazily created and reused by _convert_to_word_problem.
        self._word_problem_client = None

        self.system_prompt = system_prompt

    async def setup(self):
        """Create the training curriculum and a fixed evaluation set for levels 1-7."""
        logger.info("Setting up InfiniteMathEnv")

        self.curriculum = MathCurriculum(
            starting_level=self.config.starting_level,
            progress_threshold=self.config.progress_threshold,
            min_evaluations=self.config.min_evaluations,
        )

        self.eval_problems = {}
        for level in range(1, 8):
            self.eval_problems[level] = self._generate_eval_problems(level)
            logger.info(
                f"Generated {len(self.eval_problems[level])} evaluation problems for level {level}"
            )

        for level in range(1, 8):
            if not self.eval_problems[level]:
                logger.warning(
                    f"No valid evaluation problems for level {level}, adding fallback"
                )
                self.eval_problems[level].append(self._FALLBACK_EVAL_PROBLEMS[level])

    def _generate_eval_problems(self, level: int) -> List[Tuple]:
        """Draw up to ``_EVAL_PROBLEMS_PER_LEVEL`` problems for *level*.

        Uses a throwaway curriculum pinned to *level*. Each attempt, successful or
        not, counts against ``_EVAL_MAX_ATTEMPTS_PER_LEVEL``; errors are logged.
        Returns a list of ``(problem, solution, generator_id)`` tuples with LaTeX
        ``$`` delimiters removed from problem and solution.
        """
        temp_curriculum = MathCurriculum(starting_level=level)
        problems = []
        for _ in range(self._EVAL_MAX_ATTEMPTS_PER_LEVEL):
            if len(problems) >= self._EVAL_PROBLEMS_PER_LEVEL:
                break
            try:
                problem, solution, generator_id = temp_curriculum.get_problem()
                problems.append(
                    (
                        self._strip_latex_delimiters(problem),
                        self._strip_latex_delimiters(solution),
                        generator_id,
                    )
                )
            except Exception as e:
                logger.warning(
                    f"Error generating evaluation problem for level {level}: {e}"
                )
        return problems

    def _strip_latex_delimiters(self, text: str) -> str:
        """Remove paired ``$`` delimiters, keeping their contents (``"$x+1$"`` -> ``"x+1"``)."""
        return re.sub(r"\$(.*?)\$", r"\1", text)

    def save_checkpoint(self, step, data=None):
        """Add the curriculum level and performance history to the checkpoint data."""
        if data is None:
            data = {}

        data["curriculum_level"] = self.curriculum.get_current_level()
        # JSON object keys must be strings; load_checkpoint converts them back to int.
        data["performance_history"] = {
            str(k): v for k, v in self.curriculum.performance_history.items()
        }

        super().save_checkpoint(step, data)

    def load_checkpoint(self):
        """Restore curriculum level and performance history from the step's JSON checkpoint.

        A missing or unparsable checkpoint file is logged and otherwise ignored.
        """
        super().load_checkpoint()

        checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json"
        try:
            with open(checkpoint_path, "r") as f:
                data = json.load(f)

            if "curriculum_level" in data:
                self.curriculum.current_level = data["curriculum_level"]

            if "performance_history" in data:
                self.curriculum.performance_history = {
                    int(k): v for k, v in data["performance_history"].items()
                }
        except (FileNotFoundError, json.JSONDecodeError) as e:
            logger.warning(f"Failed to load checkpoint: {e}")

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add training, curriculum and evaluation metrics, flush local buffers, then log.

        ``train/percent_correct`` and the per-level ``train/level_*`` metrics are
        only reported when their buffers contain data, so an interval with no
        scored rollouts is not logged as 0% correct.
        """
        if wandb_metrics is None:
            wandb_metrics = {}

        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)

        for level, buffer in self.level_correct_buffer.items():
            if buffer:
                wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len(
                    buffer
                )
                wandb_metrics[f"train/level_{level}_count"] = len(buffer)

        if self.curriculum:
            current_level = self.curriculum.get_current_level()
            max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys())

            wandb_metrics["curriculum/current_level"] = current_level
            wandb_metrics["curriculum/max_level"] = max_level
            wandb_metrics["curriculum/progress_percent"] = (
                current_level / max_level
            ) * 100
            wandb_metrics["curriculum/level_description"] = (
                self.curriculum.get_level_description()
            )

            history = self.curriculum.performance_history.get(current_level)
            if history:
                # Success rate over (at most) the last min_evaluations results.
                recent_history = history[-self.curriculum.min_evaluations :]
                if recent_history:
                    success_rate = sum(recent_history) / len(recent_history)
                    wandb_metrics["curriculum/current_level_success_rate"] = (
                        success_rate
                    )
                    wandb_metrics["curriculum/threshold_to_advance"] = (
                        self.curriculum.progress_threshold
                    )
                    wandb_metrics["curriculum/remaining_to_threshold"] = max(
                        0, self.curriculum.progress_threshold - success_rate
                    )

        for name, value in self.eval_metrics:
            wandb_metrics[name] = value

        self.percent_correct_buffer = []
        for level in self.level_correct_buffer:
            self.level_correct_buffer[level] = []
        self.eval_metrics = []

        await super().wandb_log(wandb_metrics)

    def _get_word_problem_client(
        self, api_key: str, base_url: Optional[str]
    ) -> AsyncOpenAI:
        """Return the word-problem client, creating it on first use.

        The client is reused across calls so each problem does not open (and
        leak) a new HTTP connection pool. The key/base URL seen on the first call
        are the ones the cached client keeps.
        """
        client = getattr(self, "_word_problem_client", None)
        if client is None:
            # base_url=None lets the client use its default endpoint.
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
            self._word_problem_client = client
        return client

    async def _convert_to_word_problem(self, raw_problem_text: str) -> str:
        """Convert a raw math problem string into a word problem using an LLM.

        Returns ``raw_problem_text`` unchanged when no API key is configured, when
        the request raises, or when the model returns empty/missing text.
        """
        system_prompt_word_problem = """You are an expert creative writer. Your task is to transform a given raw mathematical expression into an engaging and imaginative word problem.

**Critical Instructions:**
1. **Strict Preservation:** The core mathematical question, ALL numbers, and ALL operations from the raw problem MUST be EXACTLY preserved in the word problem. Do NOT change the calculation required. For example, if the raw problem is 'A - B', the word problem must represent subtraction of B from A, not any other operation.
2. **Clarity:** The word problem must clearly and unambiguously lead to solving the original mathematical expression.
3. **Conciseness:** Keep the word problem relatively short and to the point.
4. **Output Format:** Output ONLY the word problem text. Do NOT include any preambles, self-references (like 'Here is a word problem:'), special tokens (like '<|start_header_id|>'), or any text other than the word problem itself.

**Examples of Correct Transformation:**
Raw Problem: 5 * 3
Word Problem: Sarah is baking cookies, and each batch requires 3 eggs. If Sarah wants to bake 5 batches, how many eggs will she need in total?

Raw Problem: |10 - 15|
Word Problem: A submarine is 10 meters below sea level. Another submarine is 15 meters below sea level. What is the absolute difference in their depths in meters?

Raw Problem: sqrt(16)
Word Problem: A square piece of land has an area of 16 square units. What is the length of one of its sides in units?

**Example of Incorrect Transformation (Operation Changed):**
Raw Problem: |3 - (-67)| (This is 3 + 67)
Incorrect Word Problem: In a magical forest, there are 3 enchanted trees, and each tree has 67 glowing fruits. How many glowing fruits are there in total? (This became 3 * 67)
Correct Word Problem: A bird watcher is 3 meters up a tree. She spots a rare bird 67 meters below ground level in a cave. What is the total vertical distance between the bird watcher and the rare bird in meters?
"""

        messages = [
            {"role": "system", "content": system_prompt_word_problem},
            {"role": "user", "content": f"Raw Problem: {raw_problem_text}"},
        ]

        try:
            api_key_to_use = self.config.word_problem_openai_api_key or os.environ.get(
                "OPENAI_API_KEY"
            )
            base_url_to_use = self.config.word_problem_openai_base_url
            # Fallback if not set in config somehow
            model_to_use = self.config.word_problem_model_name or "gpt-4.1-mini"

            if not api_key_to_use:
                logger.error(
                    "OpenAI API key for word problem generation is not configured (checked config and OPENAI_API_KEY env var)."
                )
                return raw_problem_text  # Fallback if no API key

            client = self._get_word_problem_client(api_key_to_use, base_url_to_use)
            chat_completions = await client.chat.completions.create(
                model=model_to_use,
                messages=messages,
                n=1,
                max_tokens=512,
                temperature=0.7,
                top_p=1.0,
            )

            # Keep the raw output for logging; content may be None.
            original_llm_output = chat_completions.choices[0].message.content
            cleaned_text = (original_llm_output or "").strip()

            if not cleaned_text:
                logger.warning(
                    f"Word problem conversion for '{raw_problem_text}' resulted in empty string after stripping. Original LLM output was: '{original_llm_output}'. Falling back to raw problem."
                )
                return raw_problem_text

            logger.info(
                f"Converted raw problem '{raw_problem_text}' to word problem: '{cleaned_text}'"
            )
            return cleaned_text
        except Exception as e:
            logger.error(
                f"Error converting to word problem for '{raw_problem_text}': {e}"
            )
            return raw_problem_text

    def _build_generation_prompt(self, user_content: str) -> str:
        """Render the system+user turns, open an assistant turn, and append the think prefill.

        ``add_generation_prompt=True`` makes the chat template emit the assistant
        header, so the completion continues inside the assistant turn right after
        ``_THINK_PREFILL``. Rendering the prefill as a finished assistant message
        would, with typical chat templates, close that turn before generation.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return prompt + self._THINK_PREFILL

    @staticmethod
    def _completion_text(choice) -> str:
        """Return the generated text of a completion/chat choice ("" if missing)."""
        text = choice.text if hasattr(choice, "text") else choice.message.content
        return text or ""

    async def get_next_item(self):
        """Get the next problem at the current curriculum level, phrased as a word problem.

        Returns ``(prompt, solution, generator_id)`` where ``prompt`` is a tuple
        holding one frozenset of the user message's items.
        """
        raw_problem, solution, generator_id = self.curriculum.get_problem()

        # Strip LaTeX delimiters before conversion; the solution is stripped too
        # so it compares consistently with the model's answer.
        raw_problem_stripped = self._strip_latex_delimiters(raw_problem)
        solution_stripped = self._strip_latex_delimiters(solution)

        # The agent solves the word problem, which should map back to the
        # original solution.
        word_problem_text = await self._convert_to_word_problem(raw_problem_stripped)

        prompt = tuple(
            [frozenset({"role": "user", "content": word_problem_text}.items())]
        )
        return (prompt, solution_stripped, generator_id)

    async def evaluate(self, *args, **kwargs):
        """Evaluate the model on the fixed problems for the current curriculum level.

        Every result is recorded with the curriculum and advancement is attempted
        afterwards. Metrics are appended to ``self.eval_metrics`` (flushed later
        by ``wandb_log``) and that list is returned; an empty list is returned
        when the level has no evaluation problems.
        """
        current_level = self.curriculum.get_current_level()
        logger.info(f"Starting evaluation for curriculum level {current_level}")

        level_problems = self.eval_problems.get(current_level, [])
        if not level_problems:
            logger.warning(
                f"No evaluation problems available for level {current_level}"
            )
            return []

        logger.info(
            f"Evaluating {len(level_problems)} problems at level {current_level}"
        )
        results = await asyncio.gather(
            *(
                self.evaluate_single_problem(problem, solution, current_level)
                for problem, solution, _ in level_problems
            )
        )

        correct_count = sum(1 for _, is_correct in results if is_correct)
        total_count = len(results)
        accuracy = correct_count / total_count if total_count > 0 else 0

        logger.info(
            f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})"
        )

        self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy))
        self.eval_metrics.append(("eval/current_level", current_level))

        # gather() returns one result per problem, in order.
        for (_, _, generator_id), (_, is_correct) in zip(level_problems, results):
            self.curriculum.record_performance(generator_id, is_correct)

        advanced = self.curriculum.advance_difficulty()
        new_level = self.curriculum.get_current_level()

        if advanced:
            logger.info(f"Advanced from level {current_level} to level {new_level}!")
            self.eval_metrics.append(("eval/advanced_level", 1))
        else:
            logger.info(f"Remaining at level {current_level}")
            self.eval_metrics.append(("eval/advanced_level", 0))

        return self.eval_metrics

    async def evaluate_single_problem(
        self, problem: str, solution: str, level: int
    ) -> Tuple[int, bool]:
        """Convert one evaluation problem to a word problem, query greedily, and grade it.

        Returns ``(level, is_correct)``; any exception is logged and reported as
        ``(level, False)``.
        """
        try:
            # `problem` was already stripped of LaTeX delimiters in setup().
            word_problem_text = await self._convert_to_word_problem(problem)
            logger.debug(
                f"Evaluating level {level} word problem: {word_problem_text[:50]}... (Original raw: {problem[:30]}...)"
            )

            prefilled_prompt = self._build_generation_prompt(word_problem_text)

            logger.debug(f"Requesting completion for problem: {problem[:30]}...")
            completion = await self.server.completion(
                prompt=prefilled_prompt,
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.0,
                top_p=1.0,
                split="eval",
            )

            model_answer = self._THINK_PREFILL + self._completion_text(
                completion.choices[0]
            )

            is_correct = self.check_answer(model_answer, solution)
            logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")
            return level, is_correct
        except Exception as e:
            logger.error(f"Error evaluating problem: {e}")
            return level, False

    @staticmethod
    def _after_think(text: str) -> str:
        """Return the stripped text after the last ``</think>``, or *text* unchanged if absent."""
        return text.split("</think>")[-1].strip() if "</think>" in text else text

    def check_answer(self, model_answer: str, solution: str) -> bool:
        """Check if the model's answer matches the solution.

        The candidate is the last ``\\boxed{...}`` after the think block; if there
        is none (or it is empty), the last line of that text is used instead. Both
        sides are normalised with ``_clean_for_comparison`` before comparing.
        """
        after_think_part = self._after_think(model_answer)

        answer = self._extract_boxed_answer(after_think_part)
        if not answer:
            answer = after_think_part.strip().split("\n")[-1].strip()

        model_clean = self._clean_for_comparison(answer if answer else after_think_part)
        return model_clean == self._clean_for_comparison(solution)

    def _extract_boxed_answer(self, text: str) -> Optional[str]:
        """Return the contents of the last complete ``\\boxed{...}`` in *text*, or None.

        Braces are matched by depth so nested LaTeX such as ``\\boxed{\\frac{1}{2}}``
        yields ``\\frac{1}{2}`` instead of being cut at the first ``}``. An
        unterminated ``\\boxed{`` is skipped in favour of an earlier one.
        """
        marker = "\\boxed{"
        start = text.rfind(marker)
        while start != -1:
            content_start = start + len(marker)
            depth = 1
            for pos in range(content_start, len(text)):
                char = text[pos]
                if char == "{":
                    depth += 1
                elif char == "}":
                    depth -= 1
                    if depth == 0:
                        return text[content_start:pos]
            # Unterminated: try the previous occurrence.
            start = text.rfind(marker, 0, start)
        return None

    def _clean_for_comparison(self, text: str) -> str:
        """Normalise an answer: drop LaTeX commands, commas and whitespace; lowercase."""
        cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
        cleaned = re.sub(r"[,\s]", "", cleaned)
        return cleaned.lower()

    async def collect_trajectories(self, item) -> Tuple[List, List]:
        """Sample ``group_size`` completions for one item.

        Returns ``(to_score, backlog)`` where each ``to_score`` entry is
        ``(messages, solution, generator_id, level)``; ``level`` is None when the
        generator id is not listed in ``DIFFICULTY_LEVELS``. The backlog is
        always empty.
        """
        problem_prompt, solution, generator_id = item
        user_content = dict(problem_prompt[0])["content"]

        prompt = self._build_generation_prompt(user_content)
        completions = await self.server.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        level = next(
            (
                lvl
                for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items()
                if generator_id in generator_ids
            ),
            None,
        )

        to_score = []
        for completion in completions.choices:
            model_answer = self._THINK_PREFILL + self._completion_text(completion)
            logger.debug(f"model_answer {model_answer}")

            full_messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": model_answer},
            ]
            to_score.append((full_messages, solution, generator_id, level))

        backlog = []
        return to_score, backlog

    async def score(self, rollout_group_data) -> ScoredDataGroup:
        """Score a group of rollouts produced by ``collect_trajectories``.

        Each rollout gets ``correct_reward`` or ``incorrect_reward`` from
        ``check_answer``, plus ``think_block_bonus`` for a non-empty
        ``<think>...</think>`` block and ``boxed_answer_bonus`` when a complete
        ``\\boxed{...}`` follows the think block. Results are also pushed to the
        training buffers and the curriculum, which then gets a chance to advance.

        NOTE(review): ``apply_length_penalty`` / ``length_threshold_ratio`` are
        configured but not applied here.
        """
        scored_data = ScoredDataGroup()
        scored_data["tokens"] = []
        scored_data["masks"] = []
        scored_data["scores"] = []
        scored_data["messages"] = []

        for i, (messages, solution, generator_id, level) in enumerate(
            rollout_group_data
        ):
            model_answer = messages[-1]["content"]
            current_score = 0.0

            is_correct = self.check_answer(model_answer, solution)
            if is_correct:
                current_score += self.config.correct_reward
            else:
                current_score += self.config.incorrect_reward

            self.percent_correct_buffer.append(1 if is_correct else 0)
            if level is not None:
                self.level_correct_buffer.setdefault(level, []).append(
                    1 if is_correct else 0
                )
            self.curriculum.record_performance(generator_id, is_correct)

            think_match = re.search(r"<think>(.*?)</think>", model_answer, re.DOTALL)
            think_bonus = (
                self.config.think_block_bonus
                if think_match and think_match.group(1).strip()
                else 0
            )
            current_score += think_bonus

            boxed_answer_content = self._extract_boxed_answer(
                self._after_think(model_answer)
            )
            boxed_bonus = (
                self.config.boxed_answer_bonus
                if boxed_answer_content is not None
                else 0
            )
            current_score += boxed_bonus

            logger.info(
                f"Item {i}: Correct: {is_correct}, Think Bonus: {think_bonus}, Boxed Bonus: {boxed_bonus}, Final Score: {current_score}"
            )

            tokens_dict = tokenize_for_trainer(
                self.tokenizer,
                messages,
                None,
            )

            scored_data["tokens"].append(tokens_dict["tokens"])
            scored_data["masks"].append(tokens_dict["masks"])
            scored_data["scores"].append(current_score)
            scored_data["messages"].append(messages)

        self.curriculum.advance_difficulty()

        return scored_data

    @classmethod
    def config_init(cls) -> Tuple[InfiniteMathEnvConfig, List[OpenaiConfig]]:
        """Initialize environment and OpenAI configurations with default values."""
        env_config = InfiniteMathEnvConfig(
            # BaseEnvConfig fields
            tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
            group_size=8,
            use_wandb=True,
            max_num_workers=64,
            rollout_server_url="http://localhost:8000",
            total_steps=10000,
            batch_size=1024,
            steps_per_eval=25,
            max_token_length=4096,
            inference_weight=1.0,
            wandb_name="infinite_math",
            data_path_to_save_groups="data/infinite_math_groups.jsonl",
            # InfiniteMathEnvConfig specific fields
            starting_level=1,
            progress_threshold=0.8,
            min_evaluations=10,
            max_attempts_per_problem=3,
            correct_reward=1.0,
            incorrect_reward=-0.5,
            think_block_bonus=0.2,
            boxed_answer_bonus=0.2,
            apply_length_penalty=True,
            length_threshold_ratio=0.6,
            temperature=0.7,
            top_p=0.9,
            # Specify the model and connection details for word problem generation
            word_problem_model_name="gpt-4.1-mini",
            word_problem_openai_api_key=None,  # Default to None (uses OPENAI_API_KEY env var)
            word_problem_openai_base_url=None,  # Default to None (uses official OpenAI endpoint)
        )

        server_configs = [
            OpenaiConfig(
                model_name="NousResearch/Nous-Hermes-2-Yi-34B",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_requests_for_eval=64,
            )
        ]

        return env_config, server_configs

    @classmethod
    def cli(cls):
        """Command Line Interface runner for the environment."""
        super().cli()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand off to the environment's command-line runner.
    InfiniteMathEnv.cli()
|