copied from trajectory handler branch

This commit is contained in:
Shannon Sands 2025-05-12 07:26:10 +10:00
parent 101cbdd803
commit 4e7fcd3c9a
8 changed files with 2238 additions and 0 deletions


@ -0,0 +1,105 @@
# InfiniteMath Environment
## Environment Overview
This environment provides procedurally generated math problems with curriculum-based advancement: the agent solves increasingly difficult problems, and the difficulty level adapts to its performance.
**Demonstrates:**
- Procedural content generation (math problems).
- Curriculum learning: The environment automatically adjusts the difficulty (levels 1-7) based on the LLM's success rate.
- Step-by-step reasoning evaluation: Rewards correctness, the presence of reasoning steps (within `<think>` tags), and the final answer format (`\boxed{}`).
- Handling LaTeX formatting for problems and answers.
**Training Goal:**
- To train LLMs to solve mathematical problems accurately.
- To encourage explicit step-by-step reasoning before providing an answer.
- To improve the LLM's ability to follow specific formatting instructions (using `<think>` tags and `\boxed{}`).
- To teach the model to handle progressively more complex problems through the curriculum.
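A compliant response to a problem like "What is $2 + 2$?" follows this shape (illustrative only):

```
<think>
I need to add 2 and 2. 2 + 2 = 4. Verifying: 4 - 2 = 2, so the sum checks out.
</think>
\boxed{4}
```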
## Features
- Progressive difficulty scaling across 7 levels of math problems
- Built-in curriculum system that adapts to agent performance
- Automatic problem generation with solutions
- Reward functions for accuracy, formatting, and boxed answer checking
## Usage
### Running with Default Configuration
To run the InfiniteMath environment with the default configuration:
```bash
python environments/infinite_math/infinimath_local_server.py
```
This will use the default configuration from `configs/envs/infinimath.yaml`.
### Custom Configuration
You can specify a custom configuration file:
```bash
python environments/infinite_math/infinimath_local_server.py --config my_custom_config
```
The `--config` parameter can be:
1. A name (without `.yaml` extension) which will be looked up in `configs/envs/`
2. A relative or absolute path to a YAML file
For example:
```bash
# Using a config in configs/envs/
python environments/infinite_math/infinimath_local_server.py --config infinimath_hard
# Using a config with full path
python environments/infinite_math/infinimath_local_server.py --config /path/to/my/config.yaml
```
## Configuration Structure
The configuration file follows this structure:
```yaml
# Base environment parameters
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
group_size: 1
use_wandb: false
# ... other base parameters
# InfiniteMath specific configuration
infinimath:
# Curriculum parameters
starting_level: 1
progress_threshold: 0.7
# ... other InfiniteMath specific parameters
# Server configuration
server_configs:
- model_name: "gpt-4.1-nano"
api_key: ${OPENAI_API_KEY}
num_requests_for_eval: 70
```
### Important Configuration Parameters
#### Base Parameters
- `tokenizer_name`: The tokenizer to use for encoding/decoding text
- `group_size`: Number of responses to collect per prompt
- `max_token_length`: Maximum token length for generation
- `steps_per_eval`: How often to run evaluations
#### InfiniteMath Specific Parameters
- `starting_level`: Initial difficulty level (1-7)
- `progress_threshold`: Success rate needed to advance levels
- `min_evaluations`: Minimum number of evaluations before level advancement
- `reward_functions`: List of reward functions to apply
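Reward functions can be listed by name under the `infinimath` block; a minimal sketch (values here are illustrative, not defaults):

```yaml
infinimath:
  starting_level: 1
  progress_threshold: 0.7
  min_evaluations: 5
  reward_functions:
    - accuracy
    - format
    - boxed
```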
#### Server Configuration
- `model_name`: LLM model to use
- `api_key`: API key for the model (can use environment variables with ${VAR_NAME} syntax)
- `num_requests_for_eval`: Number of evaluation requests to allocate
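The `${VAR_NAME}` substitution can be implemented as a small regex pass over loaded config strings. This is a minimal sketch of the idea, not necessarily the loader this repository actually uses:

```python
import os
import re

def expand_env_vars(value: str) -> str:
    """Replace ${VAR_NAME} occurrences with the matching environment
    variable; unknown variables are left untouched."""
    def replace(match: re.Match) -> str:
        name = match.group(1)
        return os.environ.get(name, match.group(0))
    return re.sub(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", replace, value)

os.environ["OPENAI_API_KEY"] = "sk-test"
print(expand_env_vars("${OPENAI_API_KEY}"))  # sk-test
print(expand_env_vars("${UNSET_VAR}"))       # left as-is
```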


@ -0,0 +1,8 @@
"""
Infinite Math - A reinforcement learning environment for procedurally generated math problems.
"""
from .curriculum import MathCurriculum
from .infinimath import InfiniteMath
__all__ = ["MathCurriculum", "InfiniteMath"]


@ -0,0 +1,378 @@
import random
from typing import Any, Callable, Dict, List, Optional, Tuple
import mathgenerator
class MathCurriculum:
"""
A curriculum manager for the mathgenerator library.
This class organizes math problems by difficulty and provides methods
to generate problems of appropriate difficulty based on the learner's
performance.
"""
# Define difficulty levels and map generator IDs to each level
DIFFICULTY_LEVELS = {
# Level 1: Basic arithmetic operations
1: [
0,
1,
2,
3,
8,
31,
71,
80,
90,
], # Addition, Subtraction, Multiplication, Division, Square, Factorial, Absolute difference, Percentage, IsPrime
# Level 2: Basic operations with fractions and pre-algebra
2: [
6,
11,
13,
16,
28,
44,
47,
53,
97,
118,
119,
124,
], # Square Root, Basic Algebra, Fraction to Decimal, Fraction Division, Fraction Multiplication, Compare Fractions, Cube Root, Exponentiation, Power of Powers, Percentage difference/error, Is Composite
# Level 3: Basic geometry and more algebra
3: [
18,
19,
22,
24,
25,
49,
58,
75,
96,
104,
108,
112,
115,
], # Area of Triangle, Triangle exists check, Third Angle of Triangle, Distance between 2 points, Pythagorean Theorem, Fourth Angle of Quadrilateral, Sum of Angles of Polygon, Area of a Sector, Perimeter of Polygons, Circumference, Arc length, Area of Circle
# Level 4: More advanced algebra and basic statistics
4: [
9,
10,
20,
21,
23,
26,
40,
41,
45,
50,
76,
78,
105,
], # LCM, GCD, Midpoint, Factoring Quadratic, System of Equations, Linear Equations, Common Factors, Intersection of Two Lines, Simple Interest, Quadratic Equation, Mean and Median, Compound Interest, Combine Like terms
# Level 5: Vectors, matrices, and solid geometry
5: [
17,
32,
33,
34,
35,
36,
37,
38,
39,
43,
46,
60,
61,
70,
72,
77,
95,
113,
117,
122,
123,
], # Matrix Multiplication, Surface Areas, Volumes, Vector operations, etc.
# Level 6: Advanced topics (calculus, statistics, computer science)
6: [
4,
5,
7,
12,
14,
15,
27,
30,
42,
48,
52,
54,
55,
56,
59,
62,
64,
73,
79,
84,
88,
89,
91,
103,
107,
110,
], # Binary operations, Calculus, Combinatorics, Probability, etc.
# Level 7: Most complex topics
7: [
65,
66,
67,
68,
69,
74,
85,
92,
93,
94,
98,
99,
100,
101,
106,
109,
111,
121,
], # Complex numbers, Advanced operations, etc.
}
def __init__(
self,
starting_level: int = 1,
progress_threshold: float = 0.8,
min_evaluations: int = 5,
):
"""
Initialize the curriculum manager.
Args:
starting_level: The difficulty level to start with (default: 1)
progress_threshold: The success rate required to advance to the next level (default: 0.8)
min_evaluations: Minimum number of evaluations needed before considering level advancement (default: 5)
"""
self.current_level = starting_level
self.progress_threshold = progress_threshold
self.min_evaluations = min_evaluations
# Performance tracking
self.performance_history = {
level: [] for level in self.DIFFICULTY_LEVELS.keys()
}
# Ensure starting level is valid
if starting_level not in self.DIFFICULTY_LEVELS:
raise ValueError(
f"Invalid starting level: {starting_level}. Available levels: {list(self.DIFFICULTY_LEVELS.keys())}"
)
def get_problem(self) -> Tuple[str, str, int]:
"""
Generate a math problem at the current difficulty level.
Returns:
Tuple containing (problem_text, solution_text, generator_id)
"""
        # Copy the generator IDs for the current level so that removing
        # failing generators below does not mutate the shared class-level list
        available_generators = self.DIFFICULTY_LEVELS[self.current_level].copy()
# Try generators until one works
max_attempts = 5 # Limit the number of attempts to avoid infinite loops
attempts = 0
while attempts < max_attempts:
# Get a random generator ID from the current level
generator_id = random.choice(available_generators)
try:
# Generate the problem
problem, solution = mathgenerator.genById(generator_id)
return problem, solution, generator_id
except Exception as e:
# Log the error and try another generator
print(f"Error with generator {generator_id}: {str(e)}")
attempts += 1
# Remove the problematic generator from the available list for this session
if generator_id in available_generators:
available_generators.remove(generator_id)
# If we've exhausted all generators in this level, move to an adjacent level
if not available_generators:
fallback_level = max(
1, min(7, self.current_level + random.choice([-1, 1]))
)
available_generators = self.DIFFICULTY_LEVELS[fallback_level].copy()
# If all attempts fail, return a simple addition problem as fallback
return "What is $2 + 2$?", "4", 0
def record_performance(self, generator_id: int, is_correct: bool) -> None:
"""
Record the performance on a specific problem.
Args:
generator_id: The ID of the generator used
is_correct: Whether the answer was correct
"""
# Find which level this generator belongs to
level = None
for lvl, generator_ids in self.DIFFICULTY_LEVELS.items():
if generator_id in generator_ids:
level = lvl
break
if level is not None:
# Add the result to the performance history
self.performance_history[level].append(is_correct)
def get_success_rate(self, level: int) -> Optional[float]:
"""
Calculate the success rate for a specific level.
Args:
level: The difficulty level
Returns:
Success rate as a float between 0 and 1, or None if not enough data
"""
history = self.performance_history[level]
if len(history) < self.min_evaluations:
return None
# Calculate success rate from recent evaluations
recent_history = history[-self.min_evaluations :]
return sum(recent_history) / len(recent_history)
def should_advance(self) -> bool:
"""
Determine if the learner should advance to the next level.
Returns:
Boolean indicating whether to advance
"""
success_rate = self.get_success_rate(self.current_level)
# If not enough data or below threshold, don't advance
if success_rate is None or success_rate < self.progress_threshold:
return False
# Check if there's a next level to advance to
return self.current_level < max(self.DIFFICULTY_LEVELS.keys())
def advance_difficulty(self) -> bool:
"""
Advance to the next difficulty level if appropriate.
Returns:
Boolean indicating whether advancement occurred
"""
if self.should_advance():
self.current_level += 1
return True
return False
def get_current_level(self) -> int:
"""
Get the current difficulty level.
Returns:
Current level as an integer
"""
return self.current_level
def get_num_levels(self) -> int:
"""
Get the total number of difficulty levels.
Returns:
Total number of levels
"""
return len(self.DIFFICULTY_LEVELS)
def get_level_description(self, level: Optional[int] = None) -> str:
"""
Get a description of the specified difficulty level.
Args:
level: The level to describe (default: current level)
Returns:
String description of the level
"""
if level is None:
level = self.current_level
level_descriptions = {
1: "Basic arithmetic operations (addition, subtraction, multiplication, division)",
2: "Basic operations with fractions and pre-algebra",
3: "Basic geometry and more algebra",
4: "More advanced algebra and basic statistics",
5: "Vectors, matrices, and solid geometry",
6: "Advanced topics (calculus, statistics, computer science)",
7: "Most complex topics (complex numbers, advanced operations)",
}
return level_descriptions.get(
level, f"Custom level with IDs: {self.DIFFICULTY_LEVELS.get(level, [])}"
)
def reset(self, level: int = 1) -> None:
"""
Reset the curriculum to a specific level and clear performance history.
Args:
level: The level to reset to (default: 1)
"""
if level not in self.DIFFICULTY_LEVELS:
raise ValueError(
f"Invalid level: {level}. Available levels: {list(self.DIFFICULTY_LEVELS.keys())}"
)
self.current_level = level
self.performance_history = {lvl: [] for lvl in self.DIFFICULTY_LEVELS.keys()}
def get_generator_info(self) -> List[Dict[str, Any]]:
"""
Get information about all available generators.
Returns:
List of dictionaries containing generator information
"""
generators = []
gen_list = mathgenerator.getGenList()
for gen in gen_list:
# Find which level this generator belongs to
level = None
for lvl, generator_ids in self.DIFFICULTY_LEVELS.items():
if gen[0] in generator_ids:
level = lvl
break
generators.append(
{
"id": gen[0],
"name": gen[1],
"function": gen[3],
"subject": gen[4],
"params": gen[5],
"difficulty_level": level,
}
)
return generators
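The advancement rule in `should_advance` / `get_success_rate` above is a sliding window: once at least `min_evaluations` results exist for the current level, the success rate over the most recent `min_evaluations` is compared against `progress_threshold`. A standalone sketch of that logic:

```python
def should_advance(history, threshold=0.8, min_evaluations=5, at_max_level=False):
    """Sliding-window advancement check mirroring MathCurriculum:
    require at least min_evaluations results, then compare the success
    rate over the most recent min_evaluations to the threshold."""
    if at_max_level or len(history) < min_evaluations:
        return False
    recent = history[-min_evaluations:]
    return sum(recent) / len(recent) >= threshold

print(should_advance([True, True, False, True, True]))  # 4/5 = 0.8 meets the threshold
print(should_advance([True, True]))                     # too few evaluations yet
```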


@ -0,0 +1,253 @@
#!/usr/bin/env python3
"""
Infinite Math - A reinforcement learning environment for math practice
using the mathgenerator library with curriculum-based advancement.
"""
import random
from typing import Any, Dict, List, Optional, Tuple
from .curriculum import MathCurriculum
class InfiniteMath:
"""
A reinforcement learning environment for practicing math skills with
curriculum-based advancement.
This class uses the MathCurriculum to generate appropriate math problems
and track performance, advancing difficulty as the learner improves.
"""
def __init__(
self,
starting_level: int = 1,
progress_threshold: float = 0.8,
min_evaluations: int = 5,
max_attempts_per_problem: int = 3,
):
"""
Initialize the InfiniteMath environment.
Args:
starting_level: Initial difficulty level (default: 1)
progress_threshold: Success rate needed to advance levels (default: 0.8)
min_evaluations: Minimum evaluations before considering advancement (default: 5)
max_attempts_per_problem: Maximum attempts allowed per problem (default: 3)
"""
self.curriculum = MathCurriculum(
starting_level=starting_level,
progress_threshold=progress_threshold,
min_evaluations=min_evaluations,
)
self.max_attempts = max_attempts_per_problem
self.current_problem = None
self.current_solution = None
self.current_generator_id = None
self.attempts_remaining = 0
self.total_problems = 0
self.correct_problems = 0
# Generate the first problem
self._generate_problem()
def _generate_problem(self) -> None:
"""Generate a new problem from the curriculum."""
self.current_problem, self.current_solution, self.current_generator_id = (
self.curriculum.get_problem()
)
self.attempts_remaining = self.max_attempts
def get_state(self) -> Dict[str, Any]:
"""
Get the current state of the environment.
Returns:
Dictionary with current state information
"""
return {
"problem": self.current_problem,
"attempts_remaining": self.attempts_remaining,
"current_level": self.curriculum.get_current_level(),
"total_levels": self.curriculum.get_num_levels(),
"level_description": self.curriculum.get_level_description(),
"total_problems": self.total_problems,
"correct_problems": self.correct_problems,
"accuracy": self.correct_problems / max(1, self.total_problems),
}
def submit_answer(self, answer: str) -> Dict[str, Any]:
"""
Submit an answer to the current problem.
Args:
answer: The learner's answer to the current problem
Returns:
Dictionary with the result of the submission and updated state
"""
if self.current_problem is None:
return {"error": "No active problem. Call reset() to start a new session."}
# Clean up the answer for comparison (strip whitespace, convert to lowercase)
cleaned_answer = str(answer).strip().lower()
cleaned_solution = str(self.current_solution).strip().lower()
# Check if the answer is correct
is_correct = cleaned_answer == cleaned_solution
# Update attempts
self.attempts_remaining -= 1
result = {
"is_correct": is_correct,
"correct_answer": (
self.current_solution
if self.attempts_remaining == 0 or is_correct
else None
),
"attempts_remaining": self.attempts_remaining,
}
# If correct or out of attempts, record performance and generate a new problem
if is_correct or self.attempts_remaining == 0:
self.total_problems += 1
if is_correct:
self.correct_problems += 1
# Record performance in the curriculum
self.curriculum.record_performance(self.current_generator_id, is_correct)
# Check if we should advance to the next level
did_advance = self.curriculum.advance_difficulty()
result["did_advance_level"] = did_advance
if did_advance:
result["new_level"] = self.curriculum.get_current_level()
result["level_description"] = self.curriculum.get_level_description()
# Generate a new problem
self._generate_problem()
result["new_problem"] = self.current_problem
return result
def reset(self, level: Optional[int] = None) -> Dict[str, Any]:
"""
Reset the environment, optionally to a specific difficulty level.
Args:
level: The difficulty level to reset to (default: current level)
Returns:
Dictionary with the new state
"""
if level is not None:
self.curriculum.reset(level)
self.total_problems = 0
self.correct_problems = 0
# Generate a new problem
self._generate_problem()
return self.get_state()
def get_difficulty_stats(self) -> Dict[int, Dict[str, Any]]:
"""
Get performance statistics for each difficulty level.
Returns:
Dictionary with statistics for each level
"""
stats = {}
for level in self.curriculum.DIFFICULTY_LEVELS.keys():
success_rate = self.curriculum.get_success_rate(level)
history = self.curriculum.performance_history[level]
stats[level] = {
"description": self.curriculum.get_level_description(level),
"problems_attempted": len(history),
"success_rate": (
success_rate if success_rate is not None else float("nan")
),
"is_current_level": level == self.curriculum.get_current_level(),
}
return stats
def main():
"""Example usage of the InfiniteMath environment."""
# Create the environment
env = InfiniteMath(starting_level=1)
print("Welcome to InfiniteMath!")
print(
f"Starting at level {env.get_state()['current_level']}: {env.get_state()['level_description']}"
)
playing = True
while playing:
# Display current problem
state = env.get_state()
print("\n" + "=" * 50)
print(f"Level {state['current_level']}/{state['total_levels']}")
print(f"Problem: {state['problem']}")
print(f"Attempts remaining: {state['attempts_remaining']}")
# Get user input
answer = input("Your answer (or 'q' to quit, 'r' to reset): ")
if answer.lower() == "q":
playing = False
continue
elif answer.lower() == "r":
level = input("Reset to level (1-7, or press Enter for current level): ")
if level and level.isdigit() and 1 <= int(level) <= 7:
env.reset(int(level))
else:
env.reset()
continue
# Submit the answer
result = env.submit_answer(answer)
# Display the result
if result.get("is_correct", False):
print("Correct!")
else:
print("Incorrect.")
if result.get("correct_answer") is not None:
print(f"The correct answer is: {result['correct_answer']}")
# Check if we advanced to a new level
if result.get("did_advance_level", False):
print(f"\nCongratulations! You've advanced to level {result['new_level']}!")
print(f"New level: {result['level_description']}")
# If we have a new problem, show it
if result.get("new_problem") is not None:
print("\nNext problem is ready.")
# Show final statistics
print("\nFinal Statistics:")
print(f"Total problems attempted: {env.total_problems}")
print(f"Correct answers: {env.correct_problems}")
if env.total_problems > 0:
print(f"Overall accuracy: {env.correct_problems / env.total_problems:.2%}")
level_stats = env.get_difficulty_stats()
print("\nPerformance by level:")
for level, stats in level_stats.items():
if stats["problems_attempted"] > 0:
print(
f"Level {level}: {stats['success_rate']:.2%} success rate ({stats['problems_attempted']} problems)"
)
if __name__ == "__main__":
main()
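`submit_answer` compares answers after only a simple normalization (strip whitespace, lowercase), so matching is an exact string check rather than a numeric one. A standalone sketch of that comparison:

```python
def answers_match(answer, solution):
    """Exact string comparison after normalization, mirroring
    InfiniteMath.submit_answer: strip whitespace and lowercase both sides."""
    return str(answer).strip().lower() == str(solution).strip().lower()

print(answers_match("  4 ", "4"))  # True: surrounding whitespace is ignored
print(answers_match("4.0", "4"))   # False: no numeric tolerance at this layer
```

Numeric tolerance is instead handled by the accuracy reward function in the server environment, which is one reason the two checks can disagree on answers like `4.0`.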


@ -0,0 +1,745 @@
import asyncio
import json
import logging
import random
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from trajectoryhandler.envs.base import (
BaseEnv,
BaseEnvConfig,
OpenaiConfig,
ScoredDataGroup,
)
from trajectoryhandler.envs.reward_fns import registry
from trajectoryhandler.envs.reward_fns.combined_reward import CombinedReward
from trajectoryhandler.utils.tokenize_for_trainer import tokenize_for_trainer
from .curriculum import MathCurriculum
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
system_prompt = """You are an expert mathematician. You need to solve the given math problem step-by-step, showing your reasoning clearly. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.
The problems will be given in a LaTeX format, so be sure to follow the LaTeX syntax when writing your answer (although no $ delimiters are necessary).
Follow these steps:
1. Understand the problem carefully
2. Plan your approach
3. Execute the calculations step-by-step
4. Verify your solution
5. Express the final answer as \\boxed{your answer here}
You may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
Your answer format should be:
<think>
[Your detailed step-by-step reasoning process here]
</think>
\\boxed{your final answer here}
Remember to format your final answer correctly as this is important for evaluation."""
class InfiniteMathEnvConfig(BaseEnvConfig):
"""Configuration for the InfiniteMath environment."""
# Curriculum parameters
starting_level: int = 1
progress_threshold: float = 0.8
min_evaluations: int = 5
# Environment parameters
max_attempts_per_problem: int = 3
correct_reward: float = 1.0
incorrect_reward: float = -1.0
# Length penalty parameters
apply_length_penalty: bool = True
length_threshold_ratio: float = (
0.5 # Percentage of max_token_length before penalties apply
)
# Completion parameters
temperature: float = 0.7
top_p: float = 0.9
# Reward functions
reward_functions: List[Union[str, Dict[str, Any]]] = ["accuracy", "format", "boxed"]
accuracy_reward_weight: float = 1.0 # Weight for the accuracy reward
format_reward_weight: float = (
0.2 # Weight for the format reward relative to correctness
)
boxed_reward_weight: float = (
0.3 # Weight for the boxed answer reward relative to correctness
)
class InfiniteMathEnv(BaseEnv):
"""Environment for procedurally generated math problems with curriculum advancement."""
def __init__(
self,
config: InfiniteMathEnvConfig,
server_configs: Union[List[OpenaiConfig], OpenaiConfig],
slurm=True,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.config = config # Override with our specific config class
# Initialize tracking metrics
self.percent_correct_buffer = []
self.level_correct_buffer = {
i: [] for i in range(1, 8)
} # Track correctness for each level
self.eval_metrics = []
# Curriculum will be initialized in setup()
self.curriculum = None
# Set the system prompt
self.system_prompt = system_prompt
# Initialize reward function
self.reward_function = self._initialize_reward_function()
def _initialize_reward_function(self):
"""Initialize the combined reward function for scoring."""
if hasattr(self.config, "reward_functions") and self.config.reward_functions:
# Configure parameters for specific reward functions
reward_configs = []
for reward_func in self.config.reward_functions:
if isinstance(reward_func, str):
# String name case - handle known rewards with custom params
if reward_func == "accuracy":
# Configure accuracy reward
accuracy_config = {
"type": "accuracy",
"weight": self.config.accuracy_reward_weight,
"params": {
"split_on_think_tag": True, # Only look at what's after </think>
"tolerance": 1e-6, # Tolerance for number comparisons
},
}
logger.info(f"Adding accuracy reward with config: {accuracy_config}")
reward_configs.append(accuracy_config)
elif reward_func == "format":
# Configure format reward with think tags and explicit weight
format_config = {
"type": "format",
"weight": self.config.format_reward_weight,
"params": {
"preferred_tags": ["think"],
},
}
logger.info(f"Adding format reward with config: {format_config}")
reward_configs.append(format_config)
elif reward_func == "boxed":
# Configure boxed reward with proper parameters and explicit weight
boxed_config = {
"type": "boxed",
"weight": self.config.boxed_reward_weight,
"params": {
"require_outside_think": True,
},
}
logger.info(f"Adding boxed reward with config: {boxed_config}")
reward_configs.append(boxed_config)
else:
# Pass through other reward functions as is
logger.info(f"Adding generic reward function: {reward_func}")
reward_configs.append(reward_func)
else:
# Dict case - pass through as is
logger.info(f"Adding reward config: {reward_func}")
reward_configs.append(reward_func)
# Create the reward function(s)
if len(reward_configs) == 1:
logger.info(f"Creating single reward function: {reward_configs[0]}")
return registry.create(reward_configs[0])
else:
logger.info(f"Creating combined reward function with {len(reward_configs)} rewards")
            # Weights were set explicitly on each reward above, so use the
            # raw weighted sum without renormalizing
            return CombinedReward(rewards=reward_configs, normalization="none")
async def setup(self):
"""Initialize the environment and curriculum."""
logger.info("Setting up InfiniteMathEnv")
# Initialize curriculum
self.curriculum = MathCurriculum(
starting_level=self.config.starting_level,
progress_threshold=self.config.progress_threshold,
min_evaluations=self.config.min_evaluations,
)
# Generate some test problems for each level for evaluation
self.eval_problems = {}
for level in range(1, 8):
self.eval_problems[level] = []
temp_curriculum = MathCurriculum(starting_level=level)
# Generate 10 test problems for each level
attempts = 0
max_attempts_per_level = 20 # Try at most 20 problems to get 10 valid ones
while (
len(self.eval_problems[level]) < 10
and attempts < max_attempts_per_level
):
try:
problem, solution, generator_id = temp_curriculum.get_problem()
# Strip LaTeX delimiters
problem = self.strip_latex_delimiters(problem)
solution = self.strip_latex_delimiters(solution)
self.eval_problems[level].append((problem, solution, generator_id))
except Exception as e:
logger.warning(
f"Error generating evaluation problem for level {level}: {e}"
)
attempts += 1
logger.info(
f"Generated {len(self.eval_problems[level])} evaluation problems for level {level}"
)
# If any levels have no problems, add a simple fallback
for level in range(1, 8):
if not self.eval_problems[level]:
logger.warning(
f"No valid evaluation problems for level {level}, adding fallback"
)
if level == 1:
self.eval_problems[level].append(("What is 2 + 3?", "5", 0))
elif level == 2:
self.eval_problems[level].append(
("What is the square root of 16?", "4", 6)
)
elif level == 3:
self.eval_problems[level].append(
(
"What is the area of a triangle with base 6 and height 8?",
"24",
18,
)
)
elif level == 4:
self.eval_problems[level].append(
("What is the solution to x + 5 = 12?", "7", 26)
)
elif level == 5:
self.eval_problems[level].append(
("What is the volume of a cube with side length 3?", "27", 33)
)
elif level == 6:
self.eval_problems[level].append(
("What is 5 factorial?", "120", 31)
)
else:
self.eval_problems[level].append(("What is |3 - 10|?", "7", 71))
def strip_latex_delimiters(self, text: str) -> str:
"""Strip LaTeX delimiters ($...$) from text."""
# Handle both inline expressions $...$ and expressions that make up the entire string
return re.sub(r"\$(.*?)\$", r"\1", text)
def save_checkpoint(self, step, data=None):
"""Save curriculum state in checkpoint."""
if data is None:
data = {}
# Save curriculum state
data["curriculum_level"] = self.curriculum.get_current_level()
data["performance_history"] = {
str(k): v for k, v in self.curriculum.performance_history.items()
}
super().save_checkpoint(step, data)
def load_checkpoint(self):
"""Load curriculum state from checkpoint."""
super().load_checkpoint()
# Check if we have curriculum data in the checkpoint
checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json"
try:
with open(checkpoint_path, "r") as f:
data = json.load(f)
# Restore curriculum state if available
if "curriculum_level" in data:
level = data["curriculum_level"]
self.curriculum.current_level = level
if "performance_history" in data:
# Convert string keys back to integers
self.curriculum.performance_history = {
int(k): v for k, v in data["performance_history"].items()
}
except (FileNotFoundError, json.JSONDecodeError) as e:
logger.warning(f"Failed to load checkpoint: {e}")
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
# Log overall correct percentage
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / max(1, len(self.percent_correct_buffer))
except ZeroDivisionError:
pass
# Log per-level metrics
for level, buffer in self.level_correct_buffer.items():
if buffer:
wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len(
buffer
)
wandb_metrics[f"train/level_{level}_count"] = len(buffer)
# Log current level and curriculum progress
if self.curriculum:
current_level = self.curriculum.get_current_level()
max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys())
wandb_metrics["curriculum/current_level"] = current_level
wandb_metrics["curriculum/max_level"] = max_level
wandb_metrics["curriculum/progress_percent"] = (
current_level / max_level
) * 100
# Log level description
wandb_metrics["curriculum/level_description"] = (
self.curriculum.get_level_description()
)
# Log performance history for current level
if current_level in self.curriculum.performance_history:
history = self.curriculum.performance_history[current_level]
if history:
recent_history = history[
-min(len(history), self.curriculum.min_evaluations) :
]
if recent_history:
success_rate = sum(recent_history) / len(recent_history)
wandb_metrics["curriculum/current_level_success_rate"] = (
success_rate
)
wandb_metrics["curriculum/threshold_to_advance"] = (
self.curriculum.progress_threshold
)
wandb_metrics["curriculum/remaining_to_threshold"] = max(
0, self.curriculum.progress_threshold - success_rate
)
# Log reward function metrics
if hasattr(self, "reward_function") and self.wandb:
if hasattr(self.reward_function, "set_wandb_logger"):
self.reward_function.set_wandb_logger(self.wandb)
# Log the reward configurations
if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
# Log the reward configuration
wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions
if hasattr(self.config, "format_reward_weight"):
wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight
if hasattr(self.config, "boxed_reward_weight"):
wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight
# Add eval metrics
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
# Reset buffers
self.percent_correct_buffer = []
for level in self.level_correct_buffer:
self.level_correct_buffer[level] = []
self.eval_metrics = []
# Call the parent method to handle remaining metrics
await super().wandb_log(wandb_metrics)
async def get_next_item(self):
"""Get the next problem based on current curriculum level."""
problem, solution, generator_id = self.curriculum.get_problem()
# Strip LaTeX delimiters from problem and solution
problem = self.strip_latex_delimiters(problem)
solution = self.strip_latex_delimiters(solution)
# Create a message with the problem
        prompt = (frozenset({"role": "user", "content": problem}.items()),)
# Return the problem with metadata
return (prompt, solution, generator_id)
async def evaluate(self, *args, **kwargs):
"""Evaluate the model on test problems at the current curriculum level."""
current_level = self.curriculum.get_current_level()
logger.info(f"Starting evaluation for curriculum level {current_level}")
# Only evaluate problems at the current level
eval_tasks = []
eval_generator_ids = []
if current_level in self.eval_problems:
for problem, solution, generator_id in self.eval_problems[current_level]:
eval_tasks.append(
self.evaluate_single_problem(problem, solution, current_level)
)
eval_generator_ids.append(generator_id)
if not eval_tasks:
logger.warning(
f"No evaluation problems available for level {current_level}"
)
return []
# Run evaluation tasks
logger.info(f"Evaluating {len(eval_tasks)} problems at level {current_level}")
results = await asyncio.gather(*eval_tasks)
# Calculate accuracy for the current level
correct_count = sum(1 for _, is_correct in results if is_correct)
total_count = len(results)
accuracy = correct_count / total_count if total_count > 0 else 0
logger.info(
f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})"
)
# Record metrics for the current level
self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy))
self.eval_metrics.append(("eval/current_level", current_level))
# Record the actual evaluation results in the curriculum's performance history
for i, (_, is_correct) in enumerate(results):
if i < len(eval_generator_ids):
# Record the actual result
self.curriculum.record_performance(eval_generator_ids[i], is_correct)
else:
# Fallback if somehow the lists are different lengths
sample_generator_id = random.choice(
self.curriculum.DIFFICULTY_LEVELS[current_level]
)
self.curriculum.record_performance(sample_generator_id, is_correct)
# Try to advance to the next level
advanced = self.curriculum.advance_difficulty()
new_level = self.curriculum.get_current_level()
if advanced:
logger.info(f"Advanced from level {current_level} to level {new_level}!")
self.eval_metrics.append(("eval/advanced_level", 1))
else:
logger.info(f"Remaining at level {current_level}")
self.eval_metrics.append(("eval/advanced_level", 0))
return self.eval_metrics
async def evaluate_single_problem(
self, problem: str, solution: str, level: int
) -> Tuple[int, bool]:
"""Evaluate a single problem."""
try:
logger.debug(f"Evaluating level {level} problem: {problem[:30]}...")
# Convert messages to a single prompt using the tokenizer
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": problem},
]
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Add prefilled thinking starter
prefill = "\n<think>\n"
prefilled_prompt = prompt + prefill
# Generate completion using the prompt
logger.debug(f"Requesting completion for problem: {problem[:30]}...")
completion = await self.server.completion(
prompt=prefilled_prompt,
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0, # Use 0 temperature for deterministic results
top_p=1.0,
split="eval",
)
# Extract the completion text and prepend the thinking starter
model_answer = prefill + (
completion.choices[0].text
if hasattr(completion.choices[0], "text")
else completion.choices[0].message.content
)
# Check if the answer is correct
is_correct = self.check_answer(model_answer, solution)
logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")
return level, is_correct
except Exception as e:
logger.error(f"Error evaluating problem: {e}")
# Return a failed result in case of error
return level, False
def check_answer(self, model_answer: str, solution: str) -> bool:
"""Check if the model's answer matches the solution."""
# Extract the part after the thinking block
after_think_part = (
model_answer.split("</think>")[-1].strip()
if "</think>" in model_answer
else model_answer
)
# Extract the boxed answer if present
boxed_answer = self.extract_boxed_answer(after_think_part)
if not boxed_answer:
# Try to find the answer in the last line
lines = after_think_part.strip().split("\n")
if lines:
boxed_answer = lines[-1].strip()
# Clean up answers for comparison (remove spaces, convert to lowercase)
model_clean = self.clean_for_comparison(
boxed_answer if boxed_answer else after_think_part
)
solution_clean = self.clean_for_comparison(solution)
# Check if they match
return model_clean == solution_clean
def extract_boxed_answer(self, text: str) -> Optional[str]:
"""Extract answer from a LaTeX boxed expression."""
# Try to find boxed content
boxed_match = re.search(r"\\boxed{([^}]*)}", text)
if boxed_match:
return boxed_match.group(1)
return None
def clean_for_comparison(self, text: str) -> str:
"""Clean text for comparison."""
# Remove LaTeX commands, spaces, commas, and convert to lowercase
cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
cleaned = re.sub(r"[,\s]", "", cleaned)
cleaned = cleaned.lower()
return cleaned
async def collect_trajectories(self, item) -> Tuple[List, List]:
"""Collect trajectories for the current item."""
# Extract information from the item
problem_prompt, solution, generator_id = item
# Create prompt using tokenizer's chat template
# Add prefilled thinking starter
prefill = "\n<think>\n"
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": dict(problem_prompt[0])["content"]},
{"role": "assistant", "content": prefill},
]
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Generate completions using completion API
completions = await self.server.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=self.config.temperature,
top_p=self.config.top_p,
)
# Prepare data for scoring
to_score = []
# Track level for metrics
level = None
for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items():
if generator_id in generator_ids:
level = lvl
break
# Process each completion
for i, completion in enumerate(completions.choices):
# Get the completion text and prepend the thinking starter
model_answer = prefill + (
completion.text
if hasattr(completion, "text")
else completion.message.content
)
logger.debug(f"Model answer: {model_answer}")
# Build complete message sequence
full_messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": dict(problem_prompt[0])["content"]},
{"role": "assistant", "content": model_answer},
]
# Add to scoring list
to_score.append((full_messages, solution, generator_id, level))
# Performance is recorded in score() once the accuracy reward is known
# No additional items for backlog
backlog = []
return to_score, backlog
async def score(self, rollout_group_data) -> ScoredDataGroup:
"""Score the collected trajectories."""
scored_data = ScoredDataGroup()
scored_data["tokens"] = []
scored_data["masks"] = []
scored_data["scores"] = []
scored_data["messages"] = []
# Format completions for reward function evaluation
format_completions = []
# Process each item in the rollout data
for messages, solution, generator_id, level in rollout_group_data:
# Extract the model's answer
model_answer = messages[-1]["content"]
# Add to format completions list for reward function
format_completions.append([{"role": "assistant", "content": model_answer}])
# Correctness is determined below from the accuracy reward and recorded then
# Apply all reward functions
reward_scores = []
unweighted_scores = []
if hasattr(self, "reward_function") and self.reward_function:
try:
# Apply the reward function (which may be a combined reward)
reward_scores = self.reward_function(format_completions, solution=solution)
logger.info(f"Reward scores: {reward_scores}")
# Debug individual rewards if it's a combined reward
if hasattr(self.reward_function, "rewards"):
logger.info(f"Combined reward with {len(self.reward_function.rewards)} components")
for i, reward in enumerate(self.reward_function.rewards):
if hasattr(reward, "compute"):
# Get raw unweighted scores
raw_scores = reward.compute(format_completions, solution=solution)
if hasattr(reward, "weight"):
logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}, weight={reward.weight}")
else:
logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}")
else:
logger.info(f"Using single reward: {type(self.reward_function).__name__}")
except Exception as e:
logger.error(f"Error applying reward functions: {e}")
logger.exception(e)
reward_scores = [0.0] * len(format_completions)
# Now update curriculum based on accuracy reward results
accuracy_scores = None
if reward_scores and hasattr(self.reward_function, "rewards"):
for reward in self.reward_function.rewards:
if type(reward).__name__ == "AccuracyReward":
# Compute the raw accuracy scores once for the whole group
accuracy_scores = reward.compute(format_completions, solution=solution)
break
for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data):
is_correct = accuracy_scores is not None and accuracy_scores[i] > 0
# Record answer correctness for tracking
self.percent_correct_buffer.append(1 if is_correct else 0)
if level is not None:
self.level_correct_buffer[level].append(1 if is_correct else 0)
# Record performance in curriculum
self.curriculum.record_performance(generator_id, is_correct)
# Combine scores and add to scored data
for i, (messages, _, _, _) in enumerate(rollout_group_data):
# Use the reward score directly (all weights are applied)
combined_score = reward_scores[i] if reward_scores else 0.0
logger.info(f"Final score for item {i}: {combined_score}")
# Tokenize for the trainer
tokens_dict = tokenize_for_trainer(
self.tokenizer,
messages,
None,
)
# Add to scored data
scored_data["tokens"].append(tokens_dict["tokens"])
scored_data["masks"].append(tokens_dict["masks"])
scored_data["scores"].append(combined_score)
scored_data["messages"].append(messages)
# Advance difficulty if criteria met
self.curriculum.advance_difficulty()
return scored_data
if __name__ == "__main__":
import asyncio
async def main():
config = InfiniteMathEnvConfig(
tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
group_size=8,
use_wandb=True,
max_num_workers=64,
rollout_server_url="http://localhost:8000",
total_steps=10000,
batch_size=1024,
steps_per_eval=25,
max_token_length=4096,
inference_weight=1.0,
wandb_name="infinite_math",
data_path_to_save_groups="data/infinite_math_groups.jsonl",
# InfiniteMath specific config
starting_level=1,
progress_threshold=0.8,
min_evaluations=10,
correct_reward=1.0,
incorrect_reward=-0.5,
apply_length_penalty=True,
length_threshold_ratio=0.6,
# Completion parameters
temperature=0.7,
top_p=0.9,
# Reward function configuration - use name directly
reward_functions=["accuracy", "format", "boxed"],
accuracy_reward_weight=1.0,
format_reward_weight=0.2,
boxed_reward_weight=0.3,
)
openai_config = OpenaiConfig(
model_name="NousResearch/Nous-Hermes-2-Yi-34B",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=64,
)
env = InfiniteMathEnv(
config=config,
server_configs=[openai_config],
slurm=False,
)
await env.env_manager()
asyncio.run(main())
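The answer-matching pipeline in `check_answer` (strip the `<think>` block, prefer a `\boxed{}` payload, fall back to the last line, then normalize before comparing) can be exercised standalone. A minimal sketch with illustrative helper names, not part of the environment itself:

```python
import re


def boxed(text):
    # Grab the first \boxed{...} payload, if any (no nested braces).
    m = re.search(r"\\boxed{([^}]*)}", text)
    return m.group(1) if m else None


def normalize(text):
    # Drop LaTeX commands, then commas and whitespace; lowercase the rest.
    return re.sub(r"[,\s]", "", re.sub(r"\\[a-zA-Z]+", "", text)).lower()


def matches(model_answer, solution):
    # Compare only the text after the reasoning block, preferring a boxed answer.
    tail = model_answer.split("</think>")[-1]
    candidate = boxed(tail) or tail.strip().split("\n")[-1]
    return normalize(candidate) == normalize(solution)
```

Note that the same normalization is applied to both sides, so formatting differences such as thousands separators do not break the comparison.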


@ -0,0 +1,242 @@
#!/usr/bin/env python3
import asyncio
import logging
import os
import argparse
from dotenv import load_dotenv
from environments.infinimath.infinimath_env import (
InfiniteMathEnv,
InfiniteMathEnvConfig,
)
from atroposlib.envs.base import OpenaiConfig
from atroposlib.utils.config_handler import ConfigHandler
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_arguments():
parser = argparse.ArgumentParser(description="InfiniteMath environment server")
parser.add_argument(
"--config",
type=str,
default="infinimath",
help="Configuration file name (without .yaml extension or path for configs/envs/ directory, or full path)",
)
return parser.parse_args()
async def main():
logger.info("Starting InfiniteMath environment server")
# Parse command line arguments
args = parse_arguments()
# Initialize config handler and load configuration
config_handler = ConfigHandler()
# Determine config path
if os.path.isabs(args.config) or "/" in args.config or args.config.endswith(".yaml"):
config_path = args.config
else:
# short form that defaults to the envs directory
config_path = os.path.join(
config_handler.config_dir, f"envs/{args.config}.yaml"
)
logger.info(f"Loading configuration from: {config_path}")
try:
import yaml
with open(config_path, "r") as f:
raw_config = yaml.safe_load(f)
logger.info("Loaded configuration successfully")
except Exception as e:
logger.error(f"Error loading config directly: {e}")
logger.info("Falling back to default config handler")
raw_config = config_handler.load_config(args)
# Configure the InfiniteMath environment with values from config
config = InfiniteMathEnvConfig(
# Base environment parameters
tokenizer_name=raw_config.get("tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"),
group_size=raw_config.get("group_size", 1),
use_wandb=raw_config.get("use_wandb", False),
max_num_workers=raw_config.get("max_num_workers", 1),
rollout_server_url=raw_config.get("rollout_server_url", "http://localhost:8000"),
total_steps=raw_config.get("total_steps", 1),
batch_size=raw_config.get("batch_size", 1),
steps_per_eval=raw_config.get("steps_per_eval", 2),
max_token_length=raw_config.get("max_token_length", 4096),
wandb_name=raw_config.get("wandb_name", "infinite_math_test"),
ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False),
# InfiniteMath specific parameters
starting_level=raw_config.get("infinimath", {}).get("starting_level", 1),
progress_threshold=raw_config.get("infinimath", {}).get("progress_threshold", 0.7),
min_evaluations=raw_config.get("infinimath", {}).get("min_evaluations", 3),
correct_reward=raw_config.get("infinimath", {}).get("correct_reward", 1.0),
incorrect_reward=raw_config.get("infinimath", {}).get("incorrect_reward", -0.5),
apply_length_penalty=raw_config.get("infinimath", {}).get("apply_length_penalty", True),
length_threshold_ratio=raw_config.get("infinimath", {}).get("length_threshold_ratio", 0.6),
temperature=raw_config.get("infinimath", {}).get("temperature", 0.7),
top_p=raw_config.get("infinimath", {}).get("top_p", 0.9),
reward_functions=raw_config.get("infinimath", {}).get("reward_functions", ["accuracy", "format", "boxed"]),
accuracy_reward_weight=raw_config.get("infinimath", {}).get("accuracy_reward_weight", 1.0),
format_reward_weight=raw_config.get("infinimath", {}).get("format_reward_weight", 0.2),
boxed_reward_weight=raw_config.get("infinimath", {}).get("boxed_reward_weight", 0.3),
)
# Server configuration from config file or defaults
server_configs = []
if "server_configs" in raw_config:
for server_config in raw_config["server_configs"]:
api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
# Handle environment variable references like ${OPENAI_API_KEY}
if isinstance(api_key, str) and api_key.startswith("${") and api_key.endswith("}"):
env_var = api_key[2:-1]
api_key = os.environ.get(env_var, "")
server_configs.append(
OpenaiConfig(
model_name=server_config.get("model_name", "gpt-4.1-nano"),
base_url=server_config.get("base_url", None),
api_key=api_key,
num_requests_for_eval=server_config.get("num_requests_for_eval", 70),
)
)
else:
# Default configuration if not specified in config file
server_configs.append(
OpenaiConfig(
model_name="gpt-4.1-nano",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=70,
)
)
# Create the environment
env = InfiniteMathEnv(
config=config,
server_configs=server_configs,
slurm=False,
)
# Setup the environment
await env.setup()
logger.info("Environment setup complete")
# Log the number of evaluation problems
total_problems = sum(len(probs) for probs in env.eval_problems.values())
logger.info(
f"Using {total_problems} evaluation problems across {len(env.eval_problems)} difficulty levels"
)
# Get a math problem
item = await env.get_next_item()
problem_prompt, solution, generator_id = item
logger.info(f"Problem: {dict(problem_prompt[0])['content']}")
logger.info(f"Solution: {solution}")
# Collect trajectories
logger.info("Collecting trajectories...")
trajectories_data, backlog = await env.collect_trajectories(item)
# Score the collected trajectories
logger.info("Scoring trajectories...")
scored_data = await env.score(trajectories_data)
input("Press Enter to continue...")
# Print scores
logger.info(f"Scores: {scored_data['scores']}")
# Log the correct/incorrect counts
correct_count = sum(1 for score in scored_data["scores"] if score > 0)
logger.info(f"Correct answers: {correct_count}/{len(scored_data['scores'])}")
# Test evaluation function specifically
logger.info("\n=== Testing Evaluation Function ===")
# Record the current level
initial_level = env.curriculum.get_current_level()
logger.info(f"Current level before evaluation: {initial_level}")
logger.info(f"Level description: {env.curriculum.get_level_description()}")
logger.info(f"Progress threshold: {env.curriculum.progress_threshold}")
logger.info(f"Min evaluations needed: {env.curriculum.min_evaluations}")
# Run the evaluate method
eval_metrics = await env.evaluate()
# Display evaluation results
logger.info("Evaluation metrics:")
for metric_name, metric_value in eval_metrics:
logger.info(f" - {metric_name}: {metric_value}")
# Check if the level advanced
new_level = env.curriculum.get_current_level()
if new_level > initial_level:
logger.info(f"Successfully advanced to level {new_level}!")
logger.info(f"New level description: {env.curriculum.get_level_description()}")
else:
# Show current progress toward advancement
current_level = env.curriculum.get_current_level()
if current_level in env.curriculum.performance_history:
history = env.curriculum.performance_history[current_level]
if len(history) >= env.curriculum.min_evaluations:
recent_history = history[-env.curriculum.min_evaluations :]
success_rate = sum(recent_history) / len(recent_history)
logger.info(
f"Current success rate: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)"
)
else:
logger.info(
f"Need more evaluations: {len(history)}/{env.curriculum.min_evaluations}"
)
# Show all levels and their performance history
logger.info("\nPerformance history by level:")
for level in sorted(env.curriculum.performance_history.keys()):
history = env.curriculum.performance_history[level]
if history:
success_rate = sum(history) / len(history)
logger.info(
f" Level {level}: {success_rate:.2f} ({sum(history)}/{len(history)} correct)"
)
else:
logger.info(f" Level {level}: No data")
# Test curriculum advancement with simulated performance history
logger.info("\n=== Testing Curriculum Advancement ===")
# Simulate good performance at current level
for _ in range(env.config.min_evaluations):
# Get a problem from current level
item = await env.get_next_item()
_, _, generator_id = item
# Record positive performance
env.curriculum.record_performance(generator_id, True)
# Try to advance difficulty
did_advance = env.curriculum.advance_difficulty()
new_level = env.curriculum.get_current_level()
logger.info("Curriculum advancement test:")
logger.info(f" - Starting level: {initial_level}")
logger.info(f" - Recorded {env.config.min_evaluations} correct answers")
logger.info(f" - Did advance: {did_advance}")
logger.info(f" - New level: {new_level}")
logger.info("Test server completed successfully")
if __name__ == "__main__":
asyncio.run(main())
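The `${OPENAI_API_KEY}`-style substitution in the server-config loop above reduces to a small resolver. A sketch of just that logic, under a hypothetical function name:

```python
import os


def resolve_api_key(value, default_env="OPENAI_API_KEY"):
    # Missing value -> fall back to the default environment variable.
    if value is None:
        return os.environ.get(default_env)
    # "${NAME}" -> look NAME up in the environment ("" if unset).
    if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
        return os.environ.get(value[2:-1], "")
    # Anything else is treated as a literal key.
    return value
```

This keeps secrets out of YAML files: the config can reference the variable name while the actual key lives in the process environment or a `.env` file.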



@ -0,0 +1,265 @@
#!/usr/bin/env python3
"""
Test script for the Infinite Math curriculum manager.
This script tests the core functionality of both the MathCurriculum and InfiniteMath classes.
"""
import os
import random
import sys
from typing import Any, Dict, List, Optional, Tuple
# Use relative imports
from .curriculum import MathCurriculum
from .infinimath import InfiniteMath
def test_curriculum_initialization():
"""Test that the curriculum initializes correctly with different levels."""
print("\n=== Testing Curriculum Initialization ===")
# Test default initialization
curriculum = MathCurriculum()
assert curriculum.get_current_level() == 1, "Default level should be 1"
print("✓ Default initialization successful")
# Test initialization with specific level
curriculum = MathCurriculum(starting_level=3)
assert curriculum.get_current_level() == 3, "Starting level should be 3"
print("✓ Custom level initialization successful")
# Test initialization with invalid level
try:
curriculum = MathCurriculum(starting_level=10)
print("✗ Invalid level initialization should fail")
except ValueError:
print("✓ Invalid level initialization correctly raises ValueError")
def test_problem_generation():
"""Test that problems are generated correctly at different difficulty levels."""
print("\n=== Testing Problem Generation ===")
# Test problem generation at different levels
for level in range(1, 8):
curriculum = MathCurriculum(starting_level=level)
problem, solution, generator_id = curriculum.get_problem()
# Verify we got a problem and solution
assert (
isinstance(problem, str) and len(problem) > 0
), f"Problem at level {level} should be a non-empty string"
assert solution is not None, f"Solution at level {level} should not be None"
# Verify the generator ID belongs to the correct level
assert (
generator_id in curriculum.DIFFICULTY_LEVELS[level]
), f"Generator ID {generator_id} should be in level {level}"
print(
f"✓ Level {level} problem generated: {problem[:50]}{'...' if len(problem) > 50 else ''}"
)
print(f" Solution: {solution}")
print(f" Generator ID: {generator_id}")
def test_performance_tracking():
"""Test performance tracking and level advancement."""
print("\n=== Testing Performance Tracking and Advancement ===")
# Create curriculum with test parameters
curriculum = MathCurriculum(
starting_level=1, progress_threshold=0.7, min_evaluations=3
)
# Record some correct answers (not enough to advance)
generator_id = curriculum.DIFFICULTY_LEVELS[1][0] # Get a generator from level 1
curriculum.record_performance(generator_id, True)
curriculum.record_performance(generator_id, True)
# Check if we advance (should not)
did_advance = curriculum.advance_difficulty()
assert not did_advance, "Should not advance with only 2 evaluations"
assert curriculum.get_current_level() == 1, "Level should still be 1"
print("✓ Correctly did not advance with insufficient evaluations")
# Add one more correct answer (now should advance)
curriculum.record_performance(generator_id, True)
did_advance = curriculum.advance_difficulty()
assert did_advance, "Should advance with 3 correct evaluations (100% success rate)"
assert curriculum.get_current_level() == 2, "Level should be 2 after advancement"
print("✓ Correctly advanced to level 2 after sufficient success")
# Test with too low success rate
curriculum = MathCurriculum(
starting_level=1, progress_threshold=0.7, min_evaluations=3
)
generator_id = curriculum.DIFFICULTY_LEVELS[1][0]
curriculum.record_performance(generator_id, True) # 1 correct
curriculum.record_performance(generator_id, False) # 1 incorrect
curriculum.record_performance(generator_id, False) # 1 incorrect
did_advance = curriculum.advance_difficulty()
assert (
not did_advance
), "Should not advance with 33% success rate when threshold is 70%"
print("✓ Correctly did not advance with insufficient success rate")
# Test advancement at the highest level
curriculum = MathCurriculum(
starting_level=7, progress_threshold=0.7, min_evaluations=3
)
generator_id = curriculum.DIFFICULTY_LEVELS[7][0]
curriculum.record_performance(generator_id, True)
curriculum.record_performance(generator_id, True)
curriculum.record_performance(generator_id, True)
did_advance = curriculum.advance_difficulty()
assert not did_advance, "Should not advance beyond the highest level"
print("✓ Correctly did not advance beyond the highest level")
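The advancement rule these tests exercise comes down to: require at least `min_evaluations` recorded results at the current level, then compare the success rate over the most recent `min_evaluations` against the threshold. A standalone sketch of that rule (not the `MathCurriculum` implementation itself):

```python
def should_advance(history, threshold=0.7, min_evaluations=3):
    # history: list of booleans for the current level, oldest first.
    if len(history) < min_evaluations:
        return False
    # Only the most recent window counts toward advancement.
    recent = history[-min_evaluations:]
    return sum(recent) / len(recent) >= threshold
```

Using only the trailing window means early failures at a level do not permanently block advancement once the model starts answering correctly.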
def test_level_descriptions():
"""Test that level descriptions are correct."""
print("\n=== Testing Level Descriptions ===")
curriculum = MathCurriculum()
for level in range(1, 8):
description = curriculum.get_level_description(level)
assert (
isinstance(description, str) and len(description) > 0
), f"Description for level {level} should be a non-empty string"
print(f"✓ Level {level}: {description}")
def test_infinite_math_environment():
"""Test the InfiniteMath environment functionality."""
print("\n=== Testing InfiniteMath Environment ===")
# Initialize the environment
env = InfiniteMath(starting_level=1, progress_threshold=0.7, min_evaluations=3)
# Get the initial state
state = env.get_state()
assert "problem" in state, "State should include a problem"
assert "current_level" in state, "State should include current level"
print(
f"✓ Initial state: Level {state['current_level']}, Problem: {state['problem'][:50]}{'...' if len(state['problem']) > 50 else ''}"
)
# Test answering a problem incorrectly
result = env.submit_answer("wrong answer")
assert "is_correct" in result, "Result should indicate correctness"
assert not result["is_correct"], "Result should be incorrect"
print("✓ Incorrect answer handled correctly")
# Test answering a problem correctly
# Note: We can't predict the correct answer, so we'll get it from the environment
correct_solution = env.current_solution
result = env.submit_answer(correct_solution)
assert result["is_correct"], "Result should be correct"
print("✓ Correct answer handled successfully")
# Test resetting the environment
env.reset(level=3)
state = env.get_state()
assert state["current_level"] == 3, "After reset, level should be 3"
print("✓ Reset to level 3 successful")
# Test getting difficulty stats
stats = env.get_difficulty_stats()
assert len(stats) == 7, "Should have stats for all 7 difficulty levels"
print("✓ Difficulty statistics retrieved successfully")

def simulate_learning():
"""Simulate a learning agent improving performance over time."""
print("\n=== Simulating Learning Process ===")
env = InfiniteMath(starting_level=1, progress_threshold=0.7, min_evaluations=5)
episodes = 30
print(f"Starting simulation with {episodes} episodes")
print(f"Initial level: {env.get_state()['current_level']}")
for i in range(episodes):
state = env.get_state()
current_level = state["current_level"]
# Simulated agent gradually improves - higher chance of correct answer over time
# and higher chance at lower levels
success_probability = min(0.5 + (i / episodes) + (1 / current_level), 0.95)
# Simulate an answer
is_correct = random.random() < success_probability
# If we decide to be correct, use the actual solution
if is_correct:
answer = env.current_solution
else:
# Otherwise provide a wrong answer
answer = "wrong answer"
# Submit the answer
result = env.submit_answer(answer)
# Check for level advancement
if result.get("did_advance_level", False):
new_level = result["new_level"]
print(
f"Episode {i+1}: Advanced to level {new_level}! (success probability: {success_probability:.2f})"
)
elif i % 5 == 0: # Print status occasionally
print(
f"Episode {i+1}: Still at level {current_level} (success probability: {success_probability:.2f})"
)
# Print final stats
final_state = env.get_state()
print(f"\nFinal level: {final_state['current_level']}")
print(
f"Overall accuracy: {final_state['correct_problems'] / final_state['total_problems']:.2%}"
)
# Print level-by-level stats
stats = env.get_difficulty_stats()
print("\nPerformance by level:")
for level, level_stats in stats.items():
if level_stats["problems_attempted"] > 0:
success_rate = level_stats["success_rate"]
if success_rate is not None:
print(
f"Level {level}: {success_rate:.2%} success rate ({level_stats['problems_attempted']} problems)"
)
else:
print(
f"Level {level}: Not enough data ({level_stats['problems_attempted']} problems)"
)
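
# NOTE: standalone sketch mirroring the formula used in simulate_learning above.
# The simulated agent's chance of a correct answer grows with the episode index,
# shrinks at higher levels (the 1 / level term), and is capped at 0.95.
def _simulated_success_probability(episode, episodes, level):
    return min(0.5 + (episode / episodes) + (1 / level), 0.95)
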
def main():
"""Run all tests."""
print("=== Starting InfiniteMath Curriculum and Environment Tests ===")
try:
test_curriculum_initialization()
test_problem_generation()
test_performance_tracking()
test_level_descriptions()
test_infinite_math_environment()
simulate_learning()
print("\n=== All tests completed successfully! ===")
except AssertionError as e:
print(f"\n❌ Test failed: {e}")
return 1
except Exception as e:
print(f"\n❌ Unexpected error: {type(e).__name__}: {e}")
return 1
return 0

if __name__ == "__main__":
sys.exit(main())