# atropos/environments/infinimath/infinimath_env.py
# Snapshot metadata: 2025-05-12 08:07:39 +10:00 — 545 lines, 20 KiB, Python.
import asyncio
import json
import logging
import random
import re
from typing import Dict, List, Optional, Tuple, Union
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from .curriculum import MathCurriculum
# Module-level logger; DEBUG is forced here so environment tracing is emitted
# regardless of the root logger's configured level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# System prompt sent with every rollout/eval request.  It instructs the model
# to reason inside <think>...</think> and answer as \boxed{...} — the exact
# format the scoring code (check_answer / _extract_boxed_answer) parses.
system_prompt = """You are an expert mathematician that can use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.
The problems will be given in a LaTeX format, so be sure to follow the LaTeX syntax when writing your answer (although no $ delimiters are necessary).
Follow these steps:
1. Understand the problem carefully
2. Plan your approach
3. Execute the calculations step-by-step
4. Verify your solution
5. Express the final answer as \\boxed{your answer here}
You may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
Your answer format should be:
<think>
[Your detailed step-by-step reasoning process here]
</think>
\\boxed{your final answer here}
Remember to format your final answer correctly as this is important for evaluation."""
class InfiniteMathEnvConfig(BaseEnvConfig):
    """Configuration for the InfiniteMath environment."""

    # Curriculum level to start at (levels run 1-7; see MathCurriculum).
    starting_level: int = 1
    # Success rate the curriculum requires before advancing a level.
    progress_threshold: float = 0.8
    # Minimum recorded evaluations before advancement is considered.
    min_evaluations: int = 5
    # Max retries per problem (not referenced in this file — presumably
    # consumed elsewhere; verify before relying on it).
    max_attempts_per_problem: int = 3
    # Base reward added when the final answer is correct.
    correct_reward: float = 1.0
    # Base reward added when the final answer is incorrect.
    incorrect_reward: float = -1.0
    # Bonus for producing a non-empty <think>...</think> block.
    think_block_bonus: float = 0.2
    # Bonus for wrapping the final answer in \boxed{...}.
    boxed_answer_bonus: float = 0.2
    # Length-penalty knobs (not applied in this file — presumably used by the
    # base class or trainer; confirm against BaseEnv).
    apply_length_penalty: bool = True
    length_threshold_ratio: float = 0.5
    # Sampling parameters used when collecting training rollouts.
    temperature: float = 0.7
    top_p: float = 0.9
class InfiniteMathEnv(BaseEnv):
    """Environment for procedurally generated math problems with curriculum advancement."""

    def __init__(
        self,
        config: InfiniteMathEnvConfig,
        server_configs: Union[List[OpenaiConfig], OpenaiConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config = config
        # Rolling 0/1 correctness flags, drained on each wandb_log() call.
        self.percent_correct_buffer = []
        # Per-level correctness buffers for curriculum levels 1-7.
        self.level_correct_buffer = {i: [] for i in range(1, 8)}
        # (metric_name, value) pairs accumulated by evaluate(), drained by wandb_log().
        self.eval_metrics = []
        # Created lazily in setup(); None until then.
        self.curriculum = None
        self.system_prompt = system_prompt
async def setup(self):
    """Initialize the environment and curriculum.

    Builds the live curriculum, then pre-generates a fixed pool of up to 10
    evaluation problems per level (1-7) using a throwaway curriculum pinned
    to each level.  Any level whose generators fail completely receives a
    hard-coded fallback problem so evaluate() always has work to do.
    """
    logger.info("Setting up InfiniteMathEnv")
    self.curriculum = MathCurriculum(
        starting_level=self.config.starting_level,
        progress_threshold=self.config.progress_threshold,
        min_evaluations=self.config.min_evaluations,
    )
    # Pre-generate a fixed eval set per level so evaluation is repeatable.
    self.eval_problems = {}
    for level in range(1, 8):
        self.eval_problems[level] = []
        # Temporary curriculum pinned at `level`, used purely for generation.
        temp_curriculum = MathCurriculum(starting_level=level)
        attempts = 0
        max_attempts_per_level = 20
        while (
            len(self.eval_problems[level]) < 10
            and attempts < max_attempts_per_level
        ):
            try:
                problem, solution, generator_id = temp_curriculum.get_problem()
                problem = self._strip_latex_delimiters(problem)
                solution = self._strip_latex_delimiters(solution)
                self.eval_problems[level].append((problem, solution, generator_id))
            except Exception as e:
                logger.warning(
                    f"Error generating evaluation problem for level {level}: {e}"
                )
            # Every attempt (success or failure) counts against the cap.
            attempts += 1
        logger.info(
            f"Generated {len(self.eval_problems[level])} evaluation problems for level {level}"
        )
    # Fallbacks: guarantee at least one eval item per level.  The third tuple
    # element is the generator id credited to the curriculum's history.
    for level in range(1, 8):
        if not self.eval_problems[level]:
            logger.warning(
                f"No valid evaluation problems for level {level}, adding fallback"
            )
            if level == 1:
                self.eval_problems[level].append(("What is 2 + 3?", "5", 0))
            elif level == 2:
                self.eval_problems[level].append(
                    ("What is the square root of 16?", "4", 6)
                )
            elif level == 3:
                self.eval_problems[level].append(
                    (
                        "What is the area of a triangle with base 6 and height 8?",
                        "24",
                        18,
                    )
                )
            elif level == 4:
                self.eval_problems[level].append(
                    ("What is the solution to x + 5 = 12?", "7", 26)
                )
            elif level == 5:
                self.eval_problems[level].append(
                    ("What is the volume of a cube with side length 3?", "27", 33)
                )
            elif level == 6:
                self.eval_problems[level].append(
                    ("What is 5 factorial?", "120", 31)
                )
            else:
                self.eval_problems[level].append(("What is |3 - 10|?", "7", 71))
def _strip_latex_delimiters(self, text: str) -> str:
"""Strip LaTeX delimiters ($...$) from text."""
return re.sub(r"\$(.*?)\$", r"\1", text)
def save_checkpoint(self, step, data=None):
    """Attach curriculum state to the checkpoint payload, then delegate.

    Args:
        step: Training step this checkpoint belongs to.
        data: Optional dict of extra checkpoint data; created when absent.
              Keys of the performance history are stringified for JSON.
    """
    payload = {} if data is None else data
    payload["curriculum_level"] = self.curriculum.get_current_level()
    payload["performance_history"] = {
        str(level): outcomes
        for level, outcomes in self.curriculum.performance_history.items()
    }
    super().save_checkpoint(step, payload)
def load_checkpoint(self):
    """Load curriculum state from checkpoint.

    Restores base-class state first, then re-reads this env's JSON
    checkpoint for the curriculum level and per-generator performance
    history.  JSON object keys come back as strings, so history keys are
    re-cast to int.  A missing or corrupt file is logged and ignored.
    """
    super().load_checkpoint()
    # Path mirrors where save_checkpoint() writes (via the base class).
    checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json"
    try:
        with open(checkpoint_path, "r") as f:
            data = json.load(f)
        if "curriculum_level" in data:
            level = data["curriculum_level"]
            self.curriculum.current_level = level
        if "performance_history" in data:
            self.curriculum.performance_history = {
                int(k): v for k, v in data["performance_history"].items()
            }
    except (FileNotFoundError, json.JSONDecodeError) as e:
        logger.warning(f"Failed to load checkpoint: {e}")
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    """Log training/eval metrics and curriculum progress to wandb.

    Aggregates the rolling correctness buffers, per-level stats, and
    curriculum state into ``wandb_metrics``, drains the eval-metric queue,
    resets all buffers, then hands off to the base class for the actual
    wandb call.
    """
    if wandb_metrics is None:
        wandb_metrics = {}
    # Fix: the divisor is max(1, len(buffer)), so ZeroDivisionError can never
    # be raised here — the previous try/except around this expression was
    # unreachable dead code and has been removed.
    wandb_metrics["train/percent_correct"] = sum(
        self.percent_correct_buffer
    ) / max(1, len(self.percent_correct_buffer))
    for level, buffer in self.level_correct_buffer.items():
        if buffer:
            wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len(
                buffer
            )
            wandb_metrics[f"train/level_{level}_count"] = len(buffer)
    if self.curriculum:
        current_level = self.curriculum.get_current_level()
        max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys())
        wandb_metrics["curriculum/current_level"] = current_level
        wandb_metrics["curriculum/max_level"] = max_level
        wandb_metrics["curriculum/progress_percent"] = (
            current_level / max_level
        ) * 100
        wandb_metrics["curriculum/level_description"] = (
            self.curriculum.get_level_description()
        )
        if current_level in self.curriculum.performance_history:
            history = self.curriculum.performance_history[current_level]
            if history:
                # Success rate over the most recent window the curriculum
                # uses for advancement decisions.
                recent_history = history[
                    -min(len(history), self.curriculum.min_evaluations) :
                ]
                if recent_history:
                    success_rate = sum(recent_history) / len(recent_history)
                    wandb_metrics["curriculum/current_level_success_rate"] = (
                        success_rate
                    )
                    wandb_metrics["curriculum/threshold_to_advance"] = (
                        self.curriculum.progress_threshold
                    )
                    wandb_metrics["curriculum/remaining_to_threshold"] = max(
                        0, self.curriculum.progress_threshold - success_rate
                    )
    for item in self.eval_metrics:
        wandb_metrics[item[0]] = item[1]
    # Reset rolling buffers so each log covers a fresh window.
    self.percent_correct_buffer = []
    for level in self.level_correct_buffer:
        self.level_correct_buffer[level] = []
    self.eval_metrics = []
    await super().wandb_log(wandb_metrics)
async def get_next_item(self):
    """Fetch a fresh problem at the current curriculum level.

    Returns:
        (prompt, solution, generator_id) where the prompt is a hashable
        tuple-of-frozensets encoding a single user message.
    """
    raw_problem, raw_solution, generator_id = self.curriculum.get_problem()
    cleaned_problem = self._strip_latex_delimiters(raw_problem)
    cleaned_solution = self._strip_latex_delimiters(raw_solution)
    user_message = {"role": "user", "content": cleaned_problem}
    prompt = (frozenset(user_message.items()),)
    return (prompt, cleaned_solution, generator_id)
async def evaluate(self, *args, **kwargs):
    """Evaluate the model on test problems at the current curriculum level.

    Runs all pre-generated eval problems concurrently, records accuracy
    metrics, feeds every outcome into the curriculum's performance history,
    and then lets the curriculum decide whether to advance a level.

    Returns:
        The accumulated (name, value) eval metric pairs, or [] when no
        eval problems exist for the current level.
    """
    current_level = self.curriculum.get_current_level()
    logger.info(f"Starting evaluation for curriculum level {current_level}")
    eval_tasks = []
    eval_generator_ids = []
    if current_level in self.eval_problems:
        for problem, solution, generator_id in self.eval_problems[current_level]:
            eval_tasks.append(
                self.evaluate_single_problem(problem, solution, current_level)
            )
            eval_generator_ids.append(generator_id)
    if not eval_tasks:
        logger.warning(
            f"No evaluation problems available for level {current_level}"
        )
        return []
    logger.info(f"Evaluating {len(eval_tasks)} problems at level {current_level}")
    results = await asyncio.gather(*eval_tasks)
    correct_count = sum(1 for _, is_correct in results if is_correct)
    total_count = len(results)
    accuracy = correct_count / total_count if total_count > 0 else 0
    logger.info(
        f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})"
    )
    self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy))
    self.eval_metrics.append(("eval/current_level", current_level))
    # Credit each result to its generator; if the ids list is somehow shorter
    # than the results, fall back to a random generator id at this level.
    for i, (_, is_correct) in enumerate(results):
        if i < len(eval_generator_ids):
            self.curriculum.record_performance(eval_generator_ids[i], is_correct)
        else:
            sample_generator_id = random.choice(
                self.curriculum.DIFFICULTY_LEVELS[current_level]
            )
            self.curriculum.record_performance(sample_generator_id, is_correct)
    # Advancement is decided by the curriculum's own threshold criteria.
    advanced = self.curriculum.advance_difficulty()
    new_level = self.curriculum.get_current_level()
    if advanced:
        logger.info(f"Advanced from level {current_level} to level {new_level}!")
        self.eval_metrics.append(("eval/advanced_level", 1))
    else:
        logger.info(f"Remaining at level {current_level}")
        self.eval_metrics.append(("eval/advanced_level", 0))
    return self.eval_metrics
async def evaluate_single_problem(
    self, problem: str, solution: str, level: int
) -> Tuple[int, bool]:
    """Evaluate a single problem with greedy decoding.

    Samples one completion at temperature 0 (deterministic grading) after
    prefilling the assistant turn with an opening think tag, then grades
    the model output against *solution* via check_answer().

    Returns:
        (level, is_correct); any exception is logged and counted as
        incorrect rather than propagated.
    """
    try:
        logger.debug(f"Evaluating level {level} problem: {problem[:30]}...")
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": problem},
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        # Prefill so the completion continues inside the expected tags.
        prefill = "\n<think>\n"
        prefilled_prompt = prompt + prefill
        logger.debug(f"Requesting completion for problem: {problem[:30]}...")
        completion = await self.server.completion(
            prompt=prefilled_prompt,
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,
            top_p=1.0,
            split="eval",
        )
        # Re-attach the prefill so check_answer sees the full answer; handle
        # both completion-style and chat-style response objects.
        model_answer = prefill + (
            completion.choices[0].text
            if hasattr(completion.choices[0], "text")
            else completion.choices[0].message.content
        )
        is_correct = self.check_answer(model_answer, solution)
        logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")
        return level, is_correct
    except Exception as e:
        logger.error(f"Error evaluating problem: {e}")
        return level, False
def check_answer(self, model_answer: str, solution: str) -> bool:
    """Decide whether *model_answer* matches *solution*.

    Only the text after the closing </think> tag is considered (the whole
    answer when no tag is present).  The candidate answer is taken from a
    \\boxed{} expression when available, else the last non-blank-split line,
    else the whole region; both sides are normalized before comparison.
    """
    if "</think>" in model_answer:
        candidate_region = model_answer.split("</think>")[-1].strip()
    else:
        candidate_region = model_answer
    extracted = self._extract_boxed_answer(candidate_region)
    if not extracted:
        lines = candidate_region.strip().split("\n")
        if lines:
            extracted = lines[-1].strip()
    target = extracted if extracted else candidate_region
    return self._clean_for_comparison(target) == self._clean_for_comparison(solution)
def _extract_boxed_answer(self, text: str) -> Optional[str]:
"""Extract answer from a LaTeX boxed expression."""
boxed_match = re.search(r"\\boxed{([^}]*)}", text)
if boxed_match:
return boxed_match.group(1)
return None
def _clean_for_comparison(self, text: str) -> str:
"""Clean text for comparison."""
cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
cleaned = re.sub(r"[,\s]", "", cleaned)
cleaned = cleaned.lower()
return cleaned
async def collect_trajectories(self, item) -> Tuple[List, List]:
    """Collect a group of model rollouts for one curriculum item.

    Builds a chat prompt ending with a prefilled "<think>" opener, samples
    ``group_size`` completions, and packages each completed conversation
    with its solution, generator id, and difficulty level for score().

    Args:
        item: (prompt, solution, generator_id) as produced by get_next_item().

    Returns:
        (to_score, backlog) where to_score holds
        (messages, solution, generator_id, level) tuples and backlog is
        always empty.
    """
    problem_prompt, solution, generator_id = item
    prefill = "\n<think>\n"
    messages = [
        {"role": "system", "content": self.system_prompt},
        {"role": "user", "content": dict(problem_prompt[0])["content"]},
        {"role": "assistant", "content": prefill},
    ]
    prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
    completions = await self.server.completion(
        prompt=prompt,
        n=self.config.group_size,
        max_tokens=self.config.max_token_length,
        temperature=self.config.temperature,
        top_p=self.config.top_p,
    )
    # Map the generator id back to its difficulty level (None if unknown).
    level = next(
        (
            lvl
            for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items()
            if generator_id in generator_ids
        ),
        None,
    )
    to_score = []
    for completion in completions.choices:
        # Handle both completion-style and chat-style response objects, and
        # re-attach the prefill so scoring sees the full assistant turn.
        model_answer = prefill + (
            completion.text
            if hasattr(completion, "text")
            else completion.message.content
        )
        # Fix: leftover debug print() replaced with logger.debug so rollouts
        # don't spam stdout during production runs.
        logger.debug("model_answer %s", model_answer)
        full_messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": dict(problem_prompt[0])["content"]},
            {"role": "assistant", "content": model_answer},
        ]
        to_score.append((full_messages, solution, generator_id, level))
    backlog = []
    return to_score, backlog
async def score(self, rollout_group_data) -> ScoredDataGroup:
    """Score the collected trajectories.

    Per rollout: base reward for correct/incorrect, plus a bonus for a
    non-empty <think> block and a bonus for a \\boxed{} final answer.
    Also updates the rolling correctness buffers and the curriculum's
    performance history, then gives the curriculum a chance to advance.
    """
    scored_data = ScoredDataGroup()
    scored_data["tokens"] = []
    scored_data["masks"] = []
    scored_data["scores"] = []
    scored_data["messages"] = []
    for i, (messages, solution, generator_id, level) in enumerate(
        rollout_group_data
    ):
        model_answer = messages[-1]["content"]
        current_score = 0.0
        is_correct = self.check_answer(model_answer, solution)
        if is_correct:
            current_score += self.config.correct_reward
        else:
            current_score += self.config.incorrect_reward
        # Track rolling accuracy (overall and per level) for wandb logging.
        self.percent_correct_buffer.append(1 if is_correct else 0)
        if level is not None:
            self.level_correct_buffer[level].append(1 if is_correct else 0)
        self.curriculum.record_performance(generator_id, is_correct)
        # Format bonus 1: non-empty <think>...</think> reasoning block.
        think_match = re.search(r"<think>(.*?)</think>", model_answer, re.DOTALL)
        if think_match:
            think_content = think_match.group(1).strip()
            if think_content:
                current_score += self.config.think_block_bonus
        # Format bonus 2: \boxed{} answer after the reasoning block.
        after_think_part = (
            model_answer.split("</think>")[-1].strip()
            if "</think>" in model_answer
            else model_answer
        )
        boxed_answer_content = self._extract_boxed_answer(after_think_part)
        if boxed_answer_content is not None:
            current_score += self.config.boxed_answer_bonus
        logger.info(
            f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}"
        )
        tokens_dict = tokenize_for_trainer(
            self.tokenizer,
            messages,
            None,
        )
        scored_data["tokens"].append(tokens_dict["tokens"])
        scored_data["masks"].append(tokens_dict["masks"])
        scored_data["scores"].append(current_score)
        scored_data["messages"].append(messages)
    # Scoring also feeds the curriculum, so give it a chance to advance here.
    self.curriculum.advance_difficulty()
    return scored_data
@classmethod
def config_init(cls) -> Tuple[InfiniteMathEnvConfig, List[OpenaiConfig]]:
    """Initialize environment and OpenAI configurations with default values.

    Returns:
        (env_config, server_configs): a ready-to-use environment config and
        a single-server OpenAI-compatible endpoint config.  Several values
        here deliberately override the class defaults (see inline notes).
    """
    env_config = InfiniteMathEnvConfig(
        # BaseEnvConfig fields
        tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
        group_size=8,
        use_wandb=True,
        max_num_workers=64,
        rollout_server_url="http://localhost:8000",
        total_steps=10000,
        batch_size=1024,
        steps_per_eval=25,
        max_token_length=4096,
        inference_weight=1.0,
        wandb_name="infinite_math",
        data_path_to_save_groups="data/infinite_math_groups.jsonl",
        # InfiniteMathEnvConfig specific fields
        starting_level=1,
        progress_threshold=0.8,
        min_evaluations=10,
        max_attempts_per_problem=3,  # Default from class, not in old main
        correct_reward=1.0,  # As in old main
        incorrect_reward=-0.5,  # As in old main (class default was -1.0)
        think_block_bonus=0.2,  # As per previous update
        boxed_answer_bonus=0.2,  # As per previous update
        apply_length_penalty=True,  # As in old main
        length_threshold_ratio=0.6,  # As in old main (class default was 0.5)
        temperature=0.7,  # As in old main
        top_p=0.9,  # As in old main
    )
    server_configs = [
        OpenaiConfig(
            model_name="NousResearch/Nous-Hermes-2-Yi-34B",
            base_url="http://localhost:9004/v1",
            api_key="x",
            num_requests_for_eval=64,
        )
    ]
    return env_config, server_configs
@classmethod
def cli(cls):
    """Command Line Interface runner for the environment.

    Thin override that just delegates to the base class; zero-argument
    super() resolves correctly because this is defined in the class body.
    """
    super().cli()
# Script entry point: launch the environment's CLI runner.
if __name__ == "__main__":
    InfiniteMathEnv.cli()