diff --git a/environments/infinimath/README.md b/environments/infinimath/README.md
new file mode 100644
index 00000000..102ac0fc
--- /dev/null
+++ b/environments/infinimath/README.md
@@ -0,0 +1,105 @@
+# InfiniteMath Environment
+
+## Environment Overview
+
+This environment provides procedurally generated math problems with curriculum-based advancement. It allows an agent to solve increasingly difficult math problems, with the difficulty level adapting based on performance.
+
+**Demonstrates:**
+- Procedural content generation (math problems).
+- Curriculum learning: The environment automatically adjusts the difficulty (levels 1-7) based on the LLM's success rate.
+- Step-by-step reasoning evaluation: Rewards correctness, the presence of reasoning steps (within `<think>` tags), and the final answer format (`\boxed{}`).
+- Handling LaTeX formatting for problems and answers.
+
+**Training Goal:**
+- To train LLMs to solve mathematical problems accurately.
+- To encourage explicit step-by-step reasoning before providing an answer.
+- To improve the LLM's ability to follow specific formatting instructions (using `<think>` tags and `\boxed{}`).
+- To teach the model to handle progressively more complex problems through the curriculum.
+
+## Features
+
+- Progressive difficulty scaling across 7 levels of math problems
+- Built-in curriculum system that adapts to agent performance
+- Automatic problem generation with solutions
+- Reward functions for accuracy, formatting, and boxed answer checking
+
+## Usage
+
+### Running with Default Configuration
+
+To run the InfiniteMath environment with the default configuration:
+
+```bash
+python environments/infinimath/infinimath_local_server.py
+```
+
+This will use the default configuration from `configs/envs/infinimath.yaml`.
+
+### Custom Configuration
+
+You can specify a custom configuration file:
+
+```bash
+python environments/infinimath/infinimath_local_server.py --config my_custom_config
+```
+
+The `--config` parameter can be:
+
+1. A name (without `.yaml` extension) which will be looked up in `configs/envs/`
+2. A relative or absolute path to a YAML file
+
+For example:
+```bash
+# Using a config in configs/envs/
+python environments/infinimath/infinimath_local_server.py --config infinimath_hard
+
+# Using a config with full path
+python environments/infinimath/infinimath_local_server.py --config /path/to/my/config.yaml
+```
+
+## Configuration Structure
+
+The configuration file follows this structure:
+
+```yaml
+# Base environment parameters
+tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
+group_size: 1
+use_wandb: false
+# ... other base parameters
+
+# InfiniteMath specific configuration
+infinimath:
+ # Curriculum parameters
+ starting_level: 1
+ progress_threshold: 0.7
+ # ... other InfiniteMath specific parameters
+
+# Server configuration
+server_configs:
+ - model_name: "gpt-4.1-nano"
+ api_key: ${OPENAI_API_KEY}
+ num_requests_for_eval: 70
+```
+
+### Important Configuration Parameters
+
+#### Base Parameters
+
+- `tokenizer_name`: The tokenizer to use for encoding/decoding text
+- `group_size`: Number of responses to collect per prompt
+- `max_token_length`: Maximum token length for generation
+- `steps_per_eval`: How often to run evaluations
+
+#### InfiniteMath Specific Parameters
+
+- `starting_level`: Initial difficulty level (1-7)
+- `progress_threshold`: Success rate needed to advance levels
+- `min_evaluations`: Minimum number of evaluations before level advancement
+- `reward_functions`: List of reward functions to apply
+
+#### Server Configuration
+
+- `model_name`: LLM model to use
+- `api_key`: API key for the model (can use environment variables with ${VAR_NAME} syntax)
+- `num_requests_for_eval`: Number of evaluation requests to allocate
diff --git a/environments/infinimath/__init__.py b/environments/infinimath/__init__.py
new file mode 100644
index 00000000..b14bda0f
--- /dev/null
+++ b/environments/infinimath/__init__.py
@@ -0,0 +1,8 @@
+"""
+Infinite Math - A reinforcement learning environment for procedurally generated math problems.
+"""
+
+from .curriculum import MathCurriculum
+from .infinimath import InfiniteMath
+
+__all__ = ["MathCurriculum", "InfiniteMath"]
diff --git a/environments/infinimath/curriculum.py b/environments/infinimath/curriculum.py
new file mode 100644
index 00000000..7e9cd4a0
--- /dev/null
+++ b/environments/infinimath/curriculum.py
@@ -0,0 +1,378 @@
+import random
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import mathgenerator
+
+
+class MathCurriculum:
+ """
+ A curriculum manager for the mathgenerator library.
+
+ This class organizes math problems by difficulty and provides methods
+ to generate problems of appropriate difficulty based on the learner's
+ performance.
+ """
+
+ # Define difficulty levels and map generator IDs to each level
+ DIFFICULTY_LEVELS = {
+ # Level 1: Basic arithmetic operations
+ 1: [
+ 0,
+ 1,
+ 2,
+ 3,
+ 8,
+ 31,
+ 71,
+ 80,
+ 90,
+ ], # Addition, Subtraction, Multiplication, Division, Square, Factorial, Absolute difference, Percentage, IsPrime
+ # Level 2: Basic operations with fractions and pre-algebra
+ 2: [
+ 6,
+ 11,
+ 13,
+ 16,
+ 28,
+ 44,
+ 47,
+ 53,
+ 97,
+ 118,
+ 119,
+ 124,
+ ], # Square Root, Basic Algebra, Fraction to Decimal, Fraction Division, Fraction Multiplication, Compare Fractions, Cube Root, Exponentiation, Power of Powers, Percentage difference/error, Is Composite
+ # Level 3: Basic geometry and more algebra
+ 3: [
+ 18,
+ 19,
+ 22,
+ 24,
+ 25,
+ 49,
+ 58,
+ 75,
+ 96,
+ 104,
+ 108,
+ 112,
+ 115,
+ ], # Area of Triangle, Triangle exists check, Third Angle of Triangle, Distance between 2 points, Pythagorean Theorem, Fourth Angle of Quadrilateral, Sum of Angles of Polygon, Area of a Sector, Perimeter of Polygons, Circumference, Arc length, Area of Circle
+ # Level 4: More advanced algebra and basic statistics
+ 4: [
+ 9,
+ 10,
+ 20,
+ 21,
+ 23,
+ 26,
+ 40,
+ 41,
+ 45,
+ 50,
+ 76,
+ 78,
+ 105,
+ ], # LCM, GCD, Midpoint, Factoring Quadratic, System of Equations, Linear Equations, Common Factors, Intersection of Two Lines, Simple Interest, Quadratic Equation, Mean and Median, Compound Interest, Combine Like terms
+ # Level 5: Vectors, matrices, and solid geometry
+ 5: [
+ 17,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 43,
+ 46,
+ 60,
+ 61,
+ 70,
+ 72,
+ 77,
+ 95,
+ 113,
+ 117,
+ 122,
+ 123,
+ ], # Matrix Multiplication, Surface Areas, Volumes, Vector operations, etc.
+ # Level 6: Advanced topics (calculus, statistics, computer science)
+ 6: [
+ 4,
+ 5,
+ 7,
+ 12,
+ 14,
+ 15,
+ 27,
+ 30,
+ 42,
+ 48,
+ 52,
+ 54,
+ 55,
+ 56,
+ 59,
+ 62,
+ 64,
+ 73,
+ 79,
+ 84,
+ 88,
+ 89,
+ 91,
+ 103,
+ 107,
+ 110,
+ ], # Binary operations, Calculus, Combinatorics, Probability, etc.
+ # Level 7: Most complex topics
+ 7: [
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 74,
+ 85,
+ 92,
+ 93,
+ 94,
+ 98,
+ 99,
+ 100,
+ 101,
+ 106,
+ 109,
+ 111,
+ 121,
+ ], # Complex numbers, Advanced operations, etc.
+ }
+
+ def __init__(
+ self,
+ starting_level: int = 1,
+ progress_threshold: float = 0.8,
+ min_evaluations: int = 5,
+ ):
+ """
+ Initialize the curriculum manager.
+
+ Args:
+ starting_level: The difficulty level to start with (default: 1)
+ progress_threshold: The success rate required to advance to the next level (default: 0.8)
+ min_evaluations: Minimum number of evaluations needed before considering level advancement (default: 5)
+ """
+ self.current_level = starting_level
+ self.progress_threshold = progress_threshold
+ self.min_evaluations = min_evaluations
+
+ # Performance tracking
+ self.performance_history = {
+ level: [] for level in self.DIFFICULTY_LEVELS.keys()
+ }
+
+ # Ensure starting level is valid
+ if starting_level not in self.DIFFICULTY_LEVELS:
+ raise ValueError(
+ f"Invalid starting level: {starting_level}. Available levels: {list(self.DIFFICULTY_LEVELS.keys())}"
+ )
+
+    def get_problem(self) -> Tuple[str, str, int]:
+        """
+        Generate a math problem at the current difficulty level.
+
+        Returns:
+            Tuple containing (problem_text, solution_text, generator_id)
+        """
+        # Copy the IDs for this level so removals below don't mutate the shared class-level table
+        available_generators = list(self.DIFFICULTY_LEVELS[self.current_level])
+
+        # Try generators until one works
+        max_attempts = 5  # Limit the number of attempts to avoid infinite loops
+        attempts = 0
+
+        while attempts < max_attempts:
+            # Get a random generator ID from the current level
+            generator_id = random.choice(available_generators)
+
+            try:
+                # Generate the problem
+                problem, solution = mathgenerator.genById(generator_id)
+                return problem, solution, generator_id
+            except Exception as e:
+                # Log the error and try another generator
+                print(f"Error with generator {generator_id}: {str(e)}")
+                attempts += 1
+
+                # Drop the problematic generator from this call's local candidate list
+                if generator_id in available_generators:
+                    available_generators.remove(generator_id)
+
+                # If we've exhausted all generators in this level, move to an adjacent level
+                if not available_generators:
+                    fallback_level = max(
+                        1, min(7, self.current_level + random.choice([-1, 1]))
+                    )
+                    available_generators = self.DIFFICULTY_LEVELS[fallback_level].copy()
+
+        # If all attempts fail, return a simple addition problem as fallback
+        return "What is $2 + 2$?", "4", 0
+
+ def record_performance(self, generator_id: int, is_correct: bool) -> None:
+ """
+ Record the performance on a specific problem.
+
+ Args:
+ generator_id: The ID of the generator used
+ is_correct: Whether the answer was correct
+ """
+ # Find which level this generator belongs to
+ level = None
+ for lvl, generator_ids in self.DIFFICULTY_LEVELS.items():
+ if generator_id in generator_ids:
+ level = lvl
+ break
+
+ if level is not None:
+ # Add the result to the performance history
+ self.performance_history[level].append(is_correct)
+
+ def get_success_rate(self, level: int) -> Optional[float]:
+ """
+ Calculate the success rate for a specific level.
+
+ Args:
+ level: The difficulty level
+
+ Returns:
+ Success rate as a float between 0 and 1, or None if not enough data
+ """
+ history = self.performance_history[level]
+
+ if len(history) < self.min_evaluations:
+ return None
+
+ # Calculate success rate from recent evaluations
+ recent_history = history[-self.min_evaluations :]
+ return sum(recent_history) / len(recent_history)
+
+ def should_advance(self) -> bool:
+ """
+ Determine if the learner should advance to the next level.
+
+ Returns:
+ Boolean indicating whether to advance
+ """
+ success_rate = self.get_success_rate(self.current_level)
+
+ # If not enough data or below threshold, don't advance
+ if success_rate is None or success_rate < self.progress_threshold:
+ return False
+
+ # Check if there's a next level to advance to
+ return self.current_level < max(self.DIFFICULTY_LEVELS.keys())
+
+ def advance_difficulty(self) -> bool:
+ """
+ Advance to the next difficulty level if appropriate.
+
+ Returns:
+ Boolean indicating whether advancement occurred
+ """
+ if self.should_advance():
+ self.current_level += 1
+ return True
+ return False
+
+ def get_current_level(self) -> int:
+ """
+ Get the current difficulty level.
+
+ Returns:
+ Current level as an integer
+ """
+ return self.current_level
+
+ def get_num_levels(self) -> int:
+ """
+ Get the total number of difficulty levels.
+
+ Returns:
+ Total number of levels
+ """
+ return len(self.DIFFICULTY_LEVELS)
+
+ def get_level_description(self, level: Optional[int] = None) -> str:
+ """
+ Get a description of the specified difficulty level.
+
+ Args:
+ level: The level to describe (default: current level)
+
+ Returns:
+ String description of the level
+ """
+ if level is None:
+ level = self.current_level
+
+ level_descriptions = {
+ 1: "Basic arithmetic operations (addition, subtraction, multiplication, division)",
+ 2: "Basic operations with fractions and pre-algebra",
+ 3: "Basic geometry and more algebra",
+ 4: "More advanced algebra and basic statistics",
+ 5: "Vectors, matrices, and solid geometry",
+ 6: "Advanced topics (calculus, statistics, computer science)",
+ 7: "Most complex topics (complex numbers, advanced operations)",
+ }
+
+ return level_descriptions.get(
+ level, f"Custom level with IDs: {self.DIFFICULTY_LEVELS.get(level, [])}"
+ )
+
+ def reset(self, level: int = 1) -> None:
+ """
+ Reset the curriculum to a specific level and clear performance history.
+
+ Args:
+ level: The level to reset to (default: 1)
+ """
+ if level not in self.DIFFICULTY_LEVELS:
+ raise ValueError(
+ f"Invalid level: {level}. Available levels: {list(self.DIFFICULTY_LEVELS.keys())}"
+ )
+
+ self.current_level = level
+ self.performance_history = {lvl: [] for lvl in self.DIFFICULTY_LEVELS.keys()}
+
+ def get_generator_info(self) -> List[Dict[str, Any]]:
+ """
+ Get information about all available generators.
+
+ Returns:
+ List of dictionaries containing generator information
+ """
+ generators = []
+ gen_list = mathgenerator.getGenList()
+
+ for gen in gen_list:
+ # Find which level this generator belongs to
+ level = None
+ for lvl, generator_ids in self.DIFFICULTY_LEVELS.items():
+ if gen[0] in generator_ids:
+ level = lvl
+ break
+
+ generators.append(
+ {
+ "id": gen[0],
+ "name": gen[1],
+ "function": gen[3],
+ "subject": gen[4],
+ "params": gen[5],
+ "difficulty_level": level,
+ }
+ )
+
+ return generators
diff --git a/environments/infinimath/infinimath.py b/environments/infinimath/infinimath.py
new file mode 100644
index 00000000..3135c137
--- /dev/null
+++ b/environments/infinimath/infinimath.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python3
+"""
+Infinite Math - A reinforcement learning environment for math practice
+using the mathgenerator library with curriculum-based advancement.
+"""
+
+import random
+from typing import Any, Dict, List, Optional, Tuple
+
+from .curriculum import MathCurriculum
+
+
+class InfiniteMath:
+ """
+ A reinforcement learning environment for practicing math skills with
+ curriculum-based advancement.
+
+ This class uses the MathCurriculum to generate appropriate math problems
+ and track performance, advancing difficulty as the learner improves.
+ """
+
+ def __init__(
+ self,
+ starting_level: int = 1,
+ progress_threshold: float = 0.8,
+ min_evaluations: int = 5,
+ max_attempts_per_problem: int = 3,
+ ):
+ """
+ Initialize the InfiniteMath environment.
+
+ Args:
+ starting_level: Initial difficulty level (default: 1)
+ progress_threshold: Success rate needed to advance levels (default: 0.8)
+ min_evaluations: Minimum evaluations before considering advancement (default: 5)
+ max_attempts_per_problem: Maximum attempts allowed per problem (default: 3)
+ """
+ self.curriculum = MathCurriculum(
+ starting_level=starting_level,
+ progress_threshold=progress_threshold,
+ min_evaluations=min_evaluations,
+ )
+
+ self.max_attempts = max_attempts_per_problem
+ self.current_problem = None
+ self.current_solution = None
+ self.current_generator_id = None
+ self.attempts_remaining = 0
+ self.total_problems = 0
+ self.correct_problems = 0
+
+ # Generate the first problem
+ self._generate_problem()
+
+ def _generate_problem(self) -> None:
+ """Generate a new problem from the curriculum."""
+ self.current_problem, self.current_solution, self.current_generator_id = (
+ self.curriculum.get_problem()
+ )
+ self.attempts_remaining = self.max_attempts
+
+ def get_state(self) -> Dict[str, Any]:
+ """
+ Get the current state of the environment.
+
+ Returns:
+ Dictionary with current state information
+ """
+ return {
+ "problem": self.current_problem,
+ "attempts_remaining": self.attempts_remaining,
+ "current_level": self.curriculum.get_current_level(),
+ "total_levels": self.curriculum.get_num_levels(),
+ "level_description": self.curriculum.get_level_description(),
+ "total_problems": self.total_problems,
+ "correct_problems": self.correct_problems,
+ "accuracy": self.correct_problems / max(1, self.total_problems),
+ }
+
+ def submit_answer(self, answer: str) -> Dict[str, Any]:
+ """
+ Submit an answer to the current problem.
+
+ Args:
+ answer: The learner's answer to the current problem
+
+ Returns:
+ Dictionary with the result of the submission and updated state
+ """
+ if self.current_problem is None:
+ return {"error": "No active problem. Call reset() to start a new session."}
+
+ # Clean up the answer for comparison (strip whitespace, convert to lowercase)
+ cleaned_answer = str(answer).strip().lower()
+ cleaned_solution = str(self.current_solution).strip().lower()
+
+ # Check if the answer is correct
+ is_correct = cleaned_answer == cleaned_solution
+
+ # Update attempts
+ self.attempts_remaining -= 1
+
+ result = {
+ "is_correct": is_correct,
+ "correct_answer": (
+ self.current_solution
+ if self.attempts_remaining == 0 or is_correct
+ else None
+ ),
+ "attempts_remaining": self.attempts_remaining,
+ }
+
+ # If correct or out of attempts, record performance and generate a new problem
+ if is_correct or self.attempts_remaining == 0:
+ self.total_problems += 1
+ if is_correct:
+ self.correct_problems += 1
+
+ # Record performance in the curriculum
+ self.curriculum.record_performance(self.current_generator_id, is_correct)
+
+ # Check if we should advance to the next level
+ did_advance = self.curriculum.advance_difficulty()
+ result["did_advance_level"] = did_advance
+
+ if did_advance:
+ result["new_level"] = self.curriculum.get_current_level()
+ result["level_description"] = self.curriculum.get_level_description()
+
+ # Generate a new problem
+ self._generate_problem()
+ result["new_problem"] = self.current_problem
+
+ return result
+
+ def reset(self, level: Optional[int] = None) -> Dict[str, Any]:
+ """
+ Reset the environment, optionally to a specific difficulty level.
+
+ Args:
+ level: The difficulty level to reset to (default: current level)
+
+ Returns:
+ Dictionary with the new state
+ """
+ if level is not None:
+ self.curriculum.reset(level)
+
+ self.total_problems = 0
+ self.correct_problems = 0
+
+ # Generate a new problem
+ self._generate_problem()
+
+ return self.get_state()
+
+ def get_difficulty_stats(self) -> Dict[int, Dict[str, Any]]:
+ """
+ Get performance statistics for each difficulty level.
+
+ Returns:
+ Dictionary with statistics for each level
+ """
+ stats = {}
+
+ for level in self.curriculum.DIFFICULTY_LEVELS.keys():
+ success_rate = self.curriculum.get_success_rate(level)
+ history = self.curriculum.performance_history[level]
+
+ stats[level] = {
+ "description": self.curriculum.get_level_description(level),
+ "problems_attempted": len(history),
+ "success_rate": (
+ success_rate if success_rate is not None else float("nan")
+ ),
+ "is_current_level": level == self.curriculum.get_current_level(),
+ }
+
+ return stats
+
+
+def main():
+ """Example usage of the InfiniteMath environment."""
+ # Create the environment
+ env = InfiniteMath(starting_level=1)
+
+ print("Welcome to InfiniteMath!")
+ print(
+ f"Starting at level {env.get_state()['current_level']}: {env.get_state()['level_description']}"
+ )
+
+ playing = True
+ while playing:
+ # Display current problem
+ state = env.get_state()
+ print("\n" + "=" * 50)
+ print(f"Level {state['current_level']}/{state['total_levels']}")
+ print(f"Problem: {state['problem']}")
+ print(f"Attempts remaining: {state['attempts_remaining']}")
+
+ # Get user input
+ answer = input("Your answer (or 'q' to quit, 'r' to reset): ")
+
+ if answer.lower() == "q":
+ playing = False
+ continue
+ elif answer.lower() == "r":
+ level = input("Reset to level (1-7, or press Enter for current level): ")
+ if level and level.isdigit() and 1 <= int(level) <= 7:
+ env.reset(int(level))
+ else:
+ env.reset()
+ continue
+
+ # Submit the answer
+ result = env.submit_answer(answer)
+
+ # Display the result
+ if result.get("is_correct", False):
+ print("Correct!")
+ else:
+ print("Incorrect.")
+
+ if result.get("correct_answer") is not None:
+ print(f"The correct answer is: {result['correct_answer']}")
+
+ # Check if we advanced to a new level
+ if result.get("did_advance_level", False):
+ print(f"\nCongratulations! You've advanced to level {result['new_level']}!")
+ print(f"New level: {result['level_description']}")
+
+ # If we have a new problem, show it
+ if result.get("new_problem") is not None:
+ print("\nNext problem is ready.")
+
+ # Show final statistics
+ print("\nFinal Statistics:")
+ print(f"Total problems attempted: {env.total_problems}")
+ print(f"Correct answers: {env.correct_problems}")
+ if env.total_problems > 0:
+ print(f"Overall accuracy: {env.correct_problems / env.total_problems:.2%}")
+
+ level_stats = env.get_difficulty_stats()
+ print("\nPerformance by level:")
+ for level, stats in level_stats.items():
+ if stats["problems_attempted"] > 0:
+ print(
+ f"Level {level}: {stats['success_rate']:.2%} success rate ({stats['problems_attempted']} problems)"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/environments/infinimath/infinimath_env.py b/environments/infinimath/infinimath_env.py
new file mode 100644
index 00000000..e946e7e7
--- /dev/null
+++ b/environments/infinimath/infinimath_env.py
@@ -0,0 +1,745 @@
+import asyncio
+import json
+import logging
+import random
+import re
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from trajectoryhandler.envs.base import (
+ BaseEnv,
+ BaseEnvConfig,
+ OpenaiConfig,
+ ScoredDataGroup,
+)
+from trajectoryhandler.envs.reward_fns import registry
+from trajectoryhandler.envs.reward_fns.combined_reward import CombinedReward
+from trajectoryhandler.utils.tokenize_for_trainer import tokenize_for_trainer
+
+from .curriculum import MathCurriculum
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+system_prompt = """You are an expert mathematician. You need to solve the given math problem step-by-step, showing your reasoning clearly. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.
+
+The problems will be given in a LaTeX format, so be sure to follow the LaTeX syntax when writing your answer (although no $ delimiters are necessary).
+
+Follow these steps:
+1. Understand the problem carefully
+2. Plan your approach
+3. Execute the calculations step-by-step
+4. Verify your solution
+5. Express the final answer as \\boxed{your answer here}
+
+You may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
+
+Your answer format should be:
+<think>
+[Your detailed step-by-step reasoning process here]
+</think>
+
+\\boxed{your final answer here}
+
+Remember to format your final answer correctly as this is important for evaluation."""
+
+
+class InfiniteMathEnvConfig(BaseEnvConfig):
+ """Configuration for the InfiniteMath environment."""
+
+ # Curriculum parameters
+ starting_level: int = 1
+ progress_threshold: float = 0.8
+ min_evaluations: int = 5
+
+ # Environment parameters
+ max_attempts_per_problem: int = 3
+ correct_reward: float = 1.0
+ incorrect_reward: float = -1.0
+
+ # Length penalty parameters
+ apply_length_penalty: bool = True
+ length_threshold_ratio: float = (
+ 0.5 # Percentage of max_token_length before penalties apply
+ )
+
+ # Completion parameters
+ temperature: float = 0.7
+ top_p: float = 0.9
+
+ # Reward functions
+ reward_functions: List[Union[str, Dict[str, Any]]] = ["accuracy", "format", "boxed"]
+ accuracy_reward_weight: float = 1.0 # Weight for the accuracy reward
+ format_reward_weight: float = (
+ 0.2 # Weight for the format reward relative to correctness
+ )
+ boxed_reward_weight: float = (
+ 0.3 # Weight for the boxed answer reward relative to correctness
+ )
+
+
+class InfiniteMathEnv(BaseEnv):
+ """Environment for procedurally generated math problems with curriculum advancement."""
+
+ def __init__(
+ self,
+ config: InfiniteMathEnvConfig,
+ server_configs: Union[List[OpenaiConfig], OpenaiConfig],
+ slurm=True,
+ testing=False,
+ ):
+ super().__init__(config, server_configs, slurm, testing)
+ self.config = config # Override with our specific config class
+
+ # Initialize tracking metrics
+ self.percent_correct_buffer = []
+ self.level_correct_buffer = {
+ i: [] for i in range(1, 8)
+ } # Track correctness for each level
+ self.eval_metrics = []
+
+ # Curriculum will be initialized in setup()
+ self.curriculum = None
+
+ # Set the system prompt
+ self.system_prompt = system_prompt
+
+ # Initialize reward function
+ self.reward_function = self._initialize_reward_function()
+
+ def _initialize_reward_function(self):
+ """Initialize the combined reward function for scoring."""
+ if hasattr(self.config, "reward_functions") and self.config.reward_functions:
+ # Configure parameters for specific reward functions
+ reward_configs = []
+
+ for reward_func in self.config.reward_functions:
+ if isinstance(reward_func, str):
+ # String name case - handle known rewards with custom params
+ if reward_func == "accuracy":
+ # Configure accuracy reward
+ accuracy_config = {
+ "type": "accuracy",
+ "weight": self.config.accuracy_reward_weight,
+ "params": {
+ "split_on_think_tag": True, # Only look at what's after
+ "tolerance": 1e-6, # Tolerance for number comparisons
+ },
+ }
+ logger.info(f"Adding accuracy reward with config: {accuracy_config}")
+ reward_configs.append(accuracy_config)
+ elif reward_func == "format":
+ # Configure format reward with think tags and explicit weight
+ format_config = {
+ "type": "format",
+ "weight": self.config.format_reward_weight,
+ "params": {
+ "preferred_tags": ["think"],
+ },
+ }
+ logger.info(f"Adding format reward with config: {format_config}")
+ reward_configs.append(format_config)
+ elif reward_func == "boxed":
+ # Configure boxed reward with proper parameters and explicit weight
+ boxed_config = {
+ "type": "boxed",
+ "weight": self.config.boxed_reward_weight,
+ "params": {
+ "require_outside_think": True,
+ },
+ }
+ logger.info(f"Adding boxed reward with config: {boxed_config}")
+ reward_configs.append(boxed_config)
+ else:
+ # Pass through other reward functions as is
+ logger.info(f"Adding generic reward function: {reward_func}")
+ reward_configs.append(reward_func)
+ else:
+ # Dict case - pass through as is
+ logger.info(f"Adding reward config: {reward_func}")
+ reward_configs.append(reward_func)
+
+ # Create the reward function(s)
+ if len(reward_configs) == 1:
+ logger.info(f"Creating single reward function: {reward_configs[0]}")
+ return registry.create(reward_configs[0])
+ else:
+ logger.info(f"Creating combined reward function with {len(reward_configs)} rewards")
+            # Normalization is disabled; configured weights are applied as-is
+ return CombinedReward(rewards=reward_configs, normalization="none")
+
+ async def setup(self):
+ """Initialize the environment and curriculum."""
+ logger.info("Setting up InfiniteMathEnv")
+
+ # Initialize curriculum
+ self.curriculum = MathCurriculum(
+ starting_level=self.config.starting_level,
+ progress_threshold=self.config.progress_threshold,
+ min_evaluations=self.config.min_evaluations,
+ )
+
+ # Generate some test problems for each level for evaluation
+ self.eval_problems = {}
+ for level in range(1, 8):
+ self.eval_problems[level] = []
+ temp_curriculum = MathCurriculum(starting_level=level)
+ # Generate 10 test problems for each level
+ attempts = 0
+ max_attempts_per_level = 20 # Try at most 20 problems to get 10 valid ones
+
+ while (
+ len(self.eval_problems[level]) < 10
+ and attempts < max_attempts_per_level
+ ):
+ try:
+ problem, solution, generator_id = temp_curriculum.get_problem()
+ # Strip LaTeX delimiters
+ problem = self.strip_latex_delimiters(problem)
+ solution = self.strip_latex_delimiters(solution)
+ self.eval_problems[level].append((problem, solution, generator_id))
+ except Exception as e:
+ logger.warning(
+ f"Error generating evaluation problem for level {level}: {e}"
+ )
+ attempts += 1
+
+ logger.info(
+ f"Generated {len(self.eval_problems[level])} evaluation problems for level {level}"
+ )
+
+ # If any levels have no problems, add a simple fallback
+ for level in range(1, 8):
+ if not self.eval_problems[level]:
+ logger.warning(
+ f"No valid evaluation problems for level {level}, adding fallback"
+ )
+ if level == 1:
+ self.eval_problems[level].append(("What is 2 + 3?", "5", 0))
+ elif level == 2:
+ self.eval_problems[level].append(
+ ("What is the square root of 16?", "4", 6)
+ )
+ elif level == 3:
+ self.eval_problems[level].append(
+ (
+ "What is the area of a triangle with base 6 and height 8?",
+ "24",
+ 18,
+ )
+ )
+ elif level == 4:
+ self.eval_problems[level].append(
+ ("What is the solution to x + 5 = 12?", "7", 26)
+ )
+ elif level == 5:
+ self.eval_problems[level].append(
+ ("What is the volume of a cube with side length 3?", "27", 33)
+ )
+ elif level == 6:
+ self.eval_problems[level].append(
+ ("What is 5 factorial?", "120", 31)
+ )
+ else:
+ self.eval_problems[level].append(("What is |3 - 10|?", "7", 71))
+
+ def strip_latex_delimiters(self, text: str) -> str:
+ """Strip LaTeX delimiters ($...$) from text."""
+ # Handle both inline expressions $...$ and expressions that make up the entire string
+ return re.sub(r"\$(.*?)\$", r"\1", text)
+
+ def save_checkpoint(self, step, data=None):
+ """Save curriculum state in checkpoint."""
+ if data is None:
+ data = {}
+
+ # Save curriculum state
+ data["curriculum_level"] = self.curriculum.get_current_level()
+ data["performance_history"] = {
+ str(k): v for k, v in self.curriculum.performance_history.items()
+ }
+
+ super().save_checkpoint(step, data)
+
+ def load_checkpoint(self):
+ """Load curriculum state from checkpoint."""
+ super().load_checkpoint()
+
+ # Check if we have curriculum data in the checkpoint
+ checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json"
+ try:
+ with open(checkpoint_path, "r") as f:
+ data = json.load(f)
+
+ # Restore curriculum state if available
+ if "curriculum_level" in data:
+ level = data["curriculum_level"]
+ self.curriculum.current_level = level
+
+ if "performance_history" in data:
+ # Convert string keys back to integers
+ self.curriculum.performance_history = {
+ int(k): v for k, v in data["performance_history"].items()
+ }
+ except (FileNotFoundError, json.JSONDecodeError) as e:
+ logger.warning(f"Failed to load checkpoint: {e}")
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        """Log training/curriculum metrics to wandb and reset local buffers.
+
+        Args:
+            wandb_metrics: Optional metrics dict to extend; a fresh dict is
+                created when None. Passed on to the base class at the end.
+        """
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Overall fraction of correct rollouts since the last log.
+        try:
+            wandb_metrics["train/percent_correct"] = sum(
+                self.percent_correct_buffer
+            ) / max(1, len(self.percent_correct_buffer))
+        except ZeroDivisionError:
+            # Unreachable given max(1, ...); kept defensively.
+            pass
+
+        # Per-difficulty-level accuracy and sample counts.
+        for level, buffer in self.level_correct_buffer.items():
+            if buffer:
+                wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len(
+                    buffer
+                )
+                wandb_metrics[f"train/level_{level}_count"] = len(buffer)
+
+        # Curriculum position relative to the hardest available level.
+        if self.curriculum:
+            current_level = self.curriculum.get_current_level()
+            max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys())
+
+            wandb_metrics["curriculum/current_level"] = current_level
+            wandb_metrics["curriculum/max_level"] = max_level
+            wandb_metrics["curriculum/progress_percent"] = (
+                current_level / max_level
+            ) * 100
+
+            wandb_metrics["curriculum/level_description"] = (
+                self.curriculum.get_level_description()
+            )
+
+            # Success rate over the most recent window used for advancement.
+            if current_level in self.curriculum.performance_history:
+                history = self.curriculum.performance_history[current_level]
+                if history:
+                    recent_history = history[
+                        -min(len(history), self.curriculum.min_evaluations) :
+                    ]
+                    if recent_history:
+                        success_rate = sum(recent_history) / len(recent_history)
+                        wandb_metrics["curriculum/current_level_success_rate"] = (
+                            success_rate
+                        )
+                        wandb_metrics["curriculum/threshold_to_advance"] = (
+                            self.curriculum.progress_threshold
+                        )
+                        wandb_metrics["curriculum/remaining_to_threshold"] = max(
+                            0, self.curriculum.progress_threshold - success_rate
+                        )
+
+        # Hand the wandb logger to the reward function when it supports it.
+        if hasattr(self, "reward_function") and self.wandb:
+            if hasattr(self.reward_function, "set_wandb_logger"):
+                self.reward_function.set_wandb_logger(self.wandb)
+
+        # Record which reward components are enabled and their weights.
+        if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
+            wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
+            wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions
+
+            if hasattr(self.config, "format_reward_weight"):
+                wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight
+
+            if hasattr(self.config, "boxed_reward_weight"):
+                wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight
+
+        # Fold in evaluation metrics accumulated since the last log.
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+
+        # Reset accumulators so each log reflects only the latest interval.
+        self.percent_correct_buffer = []
+        for level in self.level_correct_buffer:
+            self.level_correct_buffer[level] = []
+        self.eval_metrics = []
+
+        # Delegate to the base environment for any remaining metrics.
+        await super().wandb_log(wandb_metrics)
+
+ async def get_next_item(self):
+ """Get the next problem based on current curriculum level."""
+ problem, solution, generator_id = self.curriculum.get_problem()
+
+ # Strip LaTeX delimiters from problem and solution
+ problem = self.strip_latex_delimiters(problem)
+ solution = self.strip_latex_delimiters(solution)
+
+ # Create a message with the problem
+ prompt = tuple([frozenset({"role": "user", "content": problem}.items())])
+
+ # Return the problem with metadata
+ return (prompt, solution, generator_id)
+
+    async def evaluate(self, *args, **kwargs):
+        """Evaluate the model on held-out problems at the current level.
+
+        Records per-level accuracy, feeds each result into the curriculum's
+        performance history, and attempts to advance the difficulty.
+
+        Returns:
+            The accumulated eval metrics as a list of (name, value) tuples;
+            an empty list when the level has no evaluation problems.
+        """
+        current_level = self.curriculum.get_current_level()
+        logger.info(f"Starting evaluation for curriculum level {current_level}")
+
+        # Only problems at the current level are evaluated.
+        eval_tasks = []
+        eval_generator_ids = []
+        if current_level in self.eval_problems:
+            for problem, solution, generator_id in self.eval_problems[current_level]:
+                eval_tasks.append(
+                    self.evaluate_single_problem(problem, solution, current_level)
+                )
+                eval_generator_ids.append(generator_id)
+
+        if not eval_tasks:
+            logger.warning(
+                f"No evaluation problems available for level {current_level}"
+            )
+            return []
+
+        # Run all evaluations concurrently.
+        logger.info(f"Evaluating {len(eval_tasks)} problems at level {current_level}")
+        results = await asyncio.gather(*eval_tasks)
+
+        # Accuracy for the current level.
+        correct_count = sum(1 for _, is_correct in results if is_correct)
+        total_count = len(results)
+        accuracy = correct_count / total_count if total_count > 0 else 0
+
+        logger.info(
+            f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})"
+        )
+
+        self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy))
+        self.eval_metrics.append(("eval/current_level", current_level))
+
+        # Feed each result into the curriculum's performance history.
+        for i, (_, is_correct) in enumerate(results):
+            if i < len(eval_generator_ids):
+                self.curriculum.record_performance(eval_generator_ids[i], is_correct)
+            else:
+                # Defensive fallback if the two lists ever diverge in length.
+                sample_generator_id = random.choice(
+                    self.curriculum.DIFFICULTY_LEVELS[current_level]
+                )
+                self.curriculum.record_performance(sample_generator_id, is_correct)
+
+        # Advance the curriculum if the success criteria are met.
+        advanced = self.curriculum.advance_difficulty()
+        new_level = self.curriculum.get_current_level()
+
+        if advanced:
+            logger.info(f"Advanced from level {current_level} to level {new_level}!")
+            self.eval_metrics.append(("eval/advanced_level", 1))
+        else:
+            logger.info(f"Remaining at level {current_level}")
+            self.eval_metrics.append(("eval/advanced_level", 0))
+
+        return self.eval_metrics
+
+ async def evaluate_single_problem(
+ self, problem: str, solution: str, level: int
+ ) -> Tuple[int, bool]:
+ """Evaluate a single problem."""
+ try:
+ logger.debug(f"Evaluating level {level} problem: {problem[:30]}...")
+
+ # Convert messages to a single prompt using the tokenizer
+ messages = [
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": problem},
+ ]
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
+
+ # Add prefilled thinking starter
+ prefill = "\n\n"
+ prefilled_prompt = prompt + prefill
+
+ # Generate completion using the prompt
+ logger.debug(f"Requesting completion for problem: {problem[:30]}...")
+ completion = await self.server.completion(
+ prompt=prefilled_prompt,
+ n=1,
+ max_tokens=self.config.max_token_length,
+ temperature=0.0, # Use 0 temperature for deterministic results
+ top_p=1.0,
+ split="eval",
+ )
+
+ # Extract the completion text and prepend the thinking starter
+ model_answer = prefill + (
+ completion.choices[0].text
+ if hasattr(completion.choices[0], "text")
+ else completion.choices[0].message.content
+ )
+
+ # Check if the answer is correct
+ is_correct = self.check_answer(model_answer, solution)
+ logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")
+
+ return level, is_correct
+ except Exception as e:
+ logger.error(f"Error evaluating problem: {e}")
+ # Return a failed result in case of error
+ return level, False
+
+ def check_answer(self, model_answer: str, solution: str) -> bool:
+ """Check if the model's answer matches the solution."""
+ # Extract the part after the thinking block
+ after_think_part = (
+ model_answer.split("")[-1].strip()
+ if "" in model_answer
+ else model_answer
+ )
+
+ # Extract the boxed answer if present
+ boxed_answer = self.extract_boxed_answer(after_think_part)
+ if not boxed_answer:
+ # Try to find the answer in the last line
+ lines = after_think_part.strip().split("\n")
+ if lines:
+ boxed_answer = lines[-1].strip()
+
+ # Clean up answers for comparison (remove spaces, convert to lowercase)
+ model_clean = self.clean_for_comparison(
+ boxed_answer if boxed_answer else after_think_part
+ )
+ solution_clean = self.clean_for_comparison(solution)
+
+ # Check if they match
+ return model_clean == solution_clean
+
+ def extract_boxed_answer(self, text: str) -> Optional[str]:
+ """Extract answer from a LaTeX boxed expression."""
+ # Try to find boxed content
+ boxed_match = re.search(r"\\boxed{([^}]*)}", text)
+ if boxed_match:
+ return boxed_match.group(1)
+ return None
+
+ def clean_for_comparison(self, text: str) -> str:
+ """Clean text for comparison."""
+ # Remove LaTeX commands, spaces, commas, and convert to lowercase
+ cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
+ cleaned = re.sub(r"[,\s]", "", cleaned)
+ cleaned = cleaned.lower()
+ return cleaned
+
+ async def collect_trajectories(self, item) -> Tuple[List, List]:
+ """Collect trajectories for the current item."""
+ # Extract information from the item
+ problem_prompt, solution, generator_id = item
+
+ # Create prompt using tokenizer's chat template
+ # Add prefilled thinking starter
+ prefill = "\n\n"
+ messages = [
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": dict(problem_prompt[0])["content"]},
+ {"role": "assistant", "content": prefill},
+ ]
+
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
+ # Generate completions using completion API
+ completions = await self.server.completion(
+ prompt=prompt,
+ n=self.config.group_size,
+ max_tokens=self.config.max_token_length,
+ temperature=self.config.temperature,
+ top_p=self.config.top_p,
+ )
+
+ # Prepare data for scoring
+ to_score = []
+
+ # Track level for metrics
+ level = None
+ for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items():
+ if generator_id in generator_ids:
+ level = lvl
+ break
+
+ # Process each completion
+ for i, completion in enumerate(completions.choices):
+ # Get the completion text and prepend the thinking starter
+ model_answer = prefill + (
+ completion.text
+ if hasattr(completion, "text")
+ else completion.message.content
+ )
+ print("model_answer", model_answer)
+
+ # Build complete message sequence
+ full_messages = [
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": dict(problem_prompt[0])["content"]},
+ {"role": "assistant", "content": model_answer},
+ ]
+
+ # Add to scoring list
+ to_score.append((full_messages, solution, generator_id, level))
+
+ # Record performance in curriculum for each item we're scoring
+ # This will be called again after scoring, but that's fine
+
+ # No additional items for backlog
+ backlog = []
+
+ return to_score, backlog
+
+ async def score(self, rollout_group_data) -> ScoredDataGroup:
+ """Score the collected trajectories."""
+ scored_data = ScoredDataGroup()
+ scored_data["tokens"] = []
+ scored_data["masks"] = []
+ scored_data["scores"] = []
+ scored_data["messages"] = []
+
+ # Format completions for reward function evaluation
+ format_completions = []
+
+ # Process each item in the rollout data
+ for messages, solution, generator_id, level in rollout_group_data:
+ # Extract the model's answer
+ model_answer = messages[-1]["content"]
+
+ # Add to format completions list for reward function
+ format_completions.append([{"role": "assistant", "content": model_answer}])
+
+ # Record performance in curriculum based on the answer and solution
+ # This will be updated after the reward functions are applied
+
+ # Apply all reward functions
+ reward_scores = []
+ unweighted_scores = []
+ if hasattr(self, "reward_function") and self.reward_function:
+ try:
+ # Apply the reward function (which may be a combined reward)
+ reward_scores = self.reward_function(format_completions, solution=solution)
+ logger.info(f"Reward scores: {reward_scores}")
+
+ # Debug individual rewards if it's a combined reward
+ if hasattr(self.reward_function, "rewards"):
+ logger.info(f"Combined reward with {len(self.reward_function.rewards)} components")
+ for i, reward in enumerate(self.reward_function.rewards):
+ if hasattr(reward, "compute"):
+ # Get raw unweighted scores
+ raw_scores = reward.compute(format_completions, solution=solution)
+ if hasattr(reward, "weight"):
+ logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}, weight={reward.weight}")
+ else:
+ logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}")
+ else:
+ logger.info(f"Using single reward: {type(self.reward_function).__name__}")
+
+ except Exception as e:
+ logger.error(f"Error applying reward functions: {e}")
+ logger.exception(e)
+ reward_scores = [0.0] * len(format_completions)
+
+ # Now update curriculum based on accuracy reward results
+ for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data):
+ # Extract accuracy from the combined reward if available
+ is_correct = False
+ if reward_scores and hasattr(self.reward_function, "rewards"):
+ for reward in self.reward_function.rewards:
+ if type(reward).__name__ == "AccuracyReward":
+ # Get raw scores from accuracy reward
+ accuracy_scores = reward.compute(format_completions, solution=solution)
+ is_correct = accuracy_scores[i] > 0
+ break
+
+ # Record answer correctness for tracking
+ self.percent_correct_buffer.append(1 if is_correct else 0)
+ if level is not None:
+ self.level_correct_buffer[level].append(1 if is_correct else 0)
+
+ # Record performance in curriculum
+ self.curriculum.record_performance(generator_id, is_correct)
+
+ # Combine scores and add to scored data
+ for i, (messages, _, _, _) in enumerate(rollout_group_data):
+ # Use the reward score directly (all weights are applied)
+ combined_score = reward_scores[i] if reward_scores else 0.0
+
+ logger.info(f"Final score for item {i}: {combined_score}")
+
+ # Tokenize for the trainer
+ tokens_dict = tokenize_for_trainer(
+ self.tokenizer,
+ messages,
+ None,
+ )
+
+ # Add to scored data
+ scored_data["tokens"].append(tokens_dict["tokens"])
+ scored_data["masks"].append(tokens_dict["masks"])
+ scored_data["scores"].append(combined_score)
+ scored_data["messages"].append(messages)
+
+ # Advance difficulty if criteria met
+ self.curriculum.advance_difficulty()
+
+ return scored_data
+
+
+if __name__ == "__main__":
+ import asyncio
+
+ async def main():
+ config = InfiniteMathEnvConfig(
+ tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
+ group_size=8,
+ use_wandb=True,
+ max_num_workers=64,
+ rollout_server_url="http://localhost:8000",
+ total_steps=10000,
+ batch_size=1024,
+ steps_per_eval=25,
+ max_token_length=4096,
+ inference_weight=1.0,
+ wandb_name="infinite_math",
+ data_path_to_save_groups="data/infinite_math_groups.jsonl",
+ # InfiniteMath specific config
+ starting_level=1,
+ progress_threshold=0.8,
+ min_evaluations=10,
+ correct_reward=1.0,
+ incorrect_reward=-0.5,
+ apply_length_penalty=True,
+ length_threshold_ratio=0.6,
+ # Completion parameters
+ temperature=0.7,
+ top_p=0.9,
+ # Reward function configuration - use name directly
+ reward_functions=["accuracy", "format", "boxed"],
+ accuracy_reward_weight=1.0,
+ format_reward_weight=0.2,
+ boxed_reward_weight=0.3,
+ )
+
+ openai_config = OpenaiConfig(
+ model_name="NousResearch/Nous-Hermes-2-Yi-34B",
+ base_url="http://localhost:9004/v1",
+ api_key="x",
+ num_requests_for_eval=64,
+ )
+
+ env = InfiniteMathEnv(
+ config=config,
+ server_configs=[openai_config],
+ slurm=False,
+ )
+
+ await env.env_manager()
+
+ asyncio.run(main())
diff --git a/environments/infinimath/infinimath_local_server.py b/environments/infinimath/infinimath_local_server.py
new file mode 100644
index 00000000..08cd0efa
--- /dev/null
+++ b/environments/infinimath/infinimath_local_server.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+"""Local test driver for the InfiniteMath environment.
+
+Loads a YAML config, builds the environment, and exercises rollout
+collection, scoring, evaluation, and curriculum advancement.
+"""
+import asyncio
+import logging
+import os
+import argparse
+
+from dotenv import load_dotenv
+# NOTE(review): OpenAI appears unused in this module — confirm before removing.
+from openai import OpenAI
+
+from environments.infinimath.infinimath_env import (
+    InfiniteMathEnv,
+    InfiniteMathEnvConfig,
+)
+from atroposlib.envs.base import OpenaiConfig
+from atroposlib.utils.config_handler import ConfigHandler
+
+# Load environment variables (e.g. OPENAI_API_KEY) from a .env file.
+load_dotenv()
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(description="InfiniteMath environment server")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="infinimath",
+ help="Configuration file name (without .yaml extension or path for configs/envs/ directory, or full path)",
+ )
+ return parser.parse_args()
+
+
+async def main():
+    """Drive an end-to-end smoke test of the InfiniteMath environment.
+
+    Loads configuration, builds the environment and inference servers, runs a
+    rollout + scoring pass, an evaluation pass, and a simulated curriculum
+    advancement, logging the results along the way.
+    """
+    logger.info("Starting InfiniteMath environment server")
+
+    # Parse command line arguments
+    args = parse_arguments()
+
+    # Initialize config handler and load configuration
+    config_handler = ConfigHandler()
+
+    # Resolve --config: absolute paths, explicit paths, or .yaml names are
+    # used as-is; a bare name is looked up under configs/envs/.
+    if os.path.isabs(args.config) or "/" in args.config or args.config.endswith(".yaml"):
+        config_path = args.config
+    else:
+        # short form that defaults to the envs directory
+        config_path = os.path.join(
+            config_handler.config_dir, f"envs/{args.config}.yaml"
+        )
+
+    logger.info(f"Loading configuration from: {config_path}")
+
+    try:
+        with open(config_path, "r") as f:
+            import yaml
+            raw_config = yaml.safe_load(f)
+        logger.info(f"Loaded configuration successfully")
+    except Exception as e:
+        # On any load/parse failure, fall back to the shared config handler.
+        logger.error(f"Error loading config directly: {e}")
+        logger.info("Falling back to default config handler")
+        raw_config = config_handler.load_config(args)
+
+    # Configure the InfiniteMath environment, preferring YAML values.
+    config = InfiniteMathEnvConfig(
+        # Base environment parameters
+        tokenizer_name=raw_config.get("tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"),
+        group_size=raw_config.get("group_size", 1),
+        use_wandb=raw_config.get("use_wandb", False),
+        max_num_workers=raw_config.get("max_num_workers", 1),
+        rollout_server_url=raw_config.get("rollout_server_url", "http://localhost:8000"),
+        total_steps=raw_config.get("total_steps", 1),
+        batch_size=raw_config.get("batch_size", 1),
+        steps_per_eval=raw_config.get("steps_per_eval", 2),
+        max_token_length=raw_config.get("max_token_length", 4096),
+        wandb_name=raw_config.get("wandb_name", "infinite_math_test"),
+        ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False),
+
+        # InfiniteMath specific parameters (nested under the "infinimath" key)
+        starting_level=raw_config.get("infinimath", {}).get("starting_level", 1),
+        progress_threshold=raw_config.get("infinimath", {}).get("progress_threshold", 0.7),
+        min_evaluations=raw_config.get("infinimath", {}).get("min_evaluations", 3),
+        correct_reward=raw_config.get("infinimath", {}).get("correct_reward", 1.0),
+        incorrect_reward=raw_config.get("infinimath", {}).get("incorrect_reward", -0.5),
+        apply_length_penalty=raw_config.get("infinimath", {}).get("apply_length_penalty", True),
+        length_threshold_ratio=raw_config.get("infinimath", {}).get("length_threshold_ratio", 0.6),
+        temperature=raw_config.get("infinimath", {}).get("temperature", 0.7),
+        top_p=raw_config.get("infinimath", {}).get("top_p", 0.9),
+        reward_functions=raw_config.get("infinimath", {}).get("reward_functions", ["accuracy", "format", "boxed"]),
+        accuracy_reward_weight=raw_config.get("infinimath", {}).get("accuracy_reward_weight", 1.0),
+        format_reward_weight=raw_config.get("infinimath", {}).get("format_reward_weight", 0.2),
+        boxed_reward_weight=raw_config.get("infinimath", {}).get("boxed_reward_weight", 0.3),
+    )
+
+    # Inference server configuration from the config file, or defaults.
+    server_configs = []
+
+    if "server_configs" in raw_config:
+        for server_config in raw_config["server_configs"]:
+            api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
+            # Resolve ${VAR}-style environment variable references.
+            if isinstance(api_key, str) and api_key.startswith("${") and api_key.endswith("}"):
+                env_var = api_key[2:-1]
+                api_key = os.environ.get(env_var, "")
+
+            server_configs.append(
+                OpenaiConfig(
+                    model_name=server_config.get("model_name", "gpt-4.1-nano"),
+                    base_url=server_config.get("base_url", None),
+                    api_key=api_key,
+                    num_requests_for_eval=server_config.get("num_requests_for_eval", 70),
+                )
+            )
+    else:
+        # Default configuration if not specified in config file
+        server_configs.append(
+            OpenaiConfig(
+                model_name="gpt-4.1-nano",
+                base_url=None,
+                api_key=os.environ.get("OPENAI_API_KEY"),
+                num_requests_for_eval=70,
+            )
+        )
+
+    # Create the environment
+    env = InfiniteMathEnv(
+        config=config,
+        server_configs=server_configs,
+        slurm=False,
+    )
+
+    # Setup the environment
+    await env.setup()
+    logger.info("Environment setup complete")
+
+    # Log the number of evaluation problems
+    total_problems = sum(len(probs) for probs in env.eval_problems.values())
+    logger.info(
+        f"Using {total_problems} evaluation problems across {len(env.eval_problems)} difficulty levels"
+    )
+
+    # Fetch one problem and run a rollout + scoring pass.
+    item = await env.get_next_item()
+    problem_prompt, solution, generator_id = item
+
+    logger.info(f"Problem: {dict(problem_prompt[0])['content']}")
+    logger.info(f"Solution: {solution}")
+
+    # Collect trajectories
+    logger.info("Collecting trajectories...")
+    trajectories_data, backlog = await env.collect_trajectories(item)
+
+    # Score the collected trajectories
+    logger.info("Scoring trajectories...")
+    scored_data = await env.score(trajectories_data)
+
+    # NOTE(review): blocking input() stalls the event loop — acceptable only
+    # in this interactive test driver.
+    input("Press Enter to continue...")
+    # Print scores
+    logger.info(f"Scores: {scored_data['scores']}")
+
+    # Log the correct/incorrect counts
+    correct_count = sum(1 for score in scored_data["scores"] if score > 0)
+    logger.info(f"Correct answers: {correct_count}/{len(scored_data['scores'])}")
+
+    # Test evaluation function specifically
+    logger.info("\n=== Testing Evaluation Function ===")
+
+    # Record the current level
+    initial_level = env.curriculum.get_current_level()
+    logger.info(f"Current level before evaluation: {initial_level}")
+    logger.info(f"Level description: {env.curriculum.get_level_description()}")
+    logger.info(f"Progress threshold: {env.curriculum.progress_threshold}")
+    logger.info(f"Min evaluations needed: {env.curriculum.min_evaluations}")
+
+    # Run the evaluate method
+    eval_metrics = await env.evaluate()
+
+    # Display evaluation results
+    logger.info("Evaluation metrics:")
+    for metric_name, metric_value in eval_metrics:
+        logger.info(f" - {metric_name}: {metric_value}")
+
+    # Report whether evaluation advanced the curriculum level.
+    new_level = env.curriculum.get_current_level()
+    if new_level > initial_level:
+        logger.info(f"Successfully advanced to level {new_level}!")
+        logger.info(f"New level description: {env.curriculum.get_level_description()}")
+    else:
+        # Show current progress toward advancement
+        current_level = env.curriculum.get_current_level()
+        if current_level in env.curriculum.performance_history:
+            history = env.curriculum.performance_history[current_level]
+            if len(history) >= env.curriculum.min_evaluations:
+                recent_history = history[-env.curriculum.min_evaluations :]
+                success_rate = sum(recent_history) / len(recent_history)
+                logger.info(
+                    f"Current success rate: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)"
+                )
+            else:
+                logger.info(
+                    f"Need more evaluations: {len(history)}/{env.curriculum.min_evaluations}"
+                )
+
+    # Show all levels and their performance history
+    logger.info("\nPerformance history by level:")
+    for level in sorted(env.curriculum.performance_history.keys()):
+        history = env.curriculum.performance_history[level]
+        if history:
+            success_rate = sum(history) / len(history)
+            logger.info(
+                f" Level {level}: {success_rate:.2f} ({sum(history)}/{len(history)} correct)"
+            )
+        else:
+            logger.info(f" Level {level}: No data")
+
+    # Test curriculum advancement with simulated performance history
+    logger.info("\n=== Testing Curriculum Advancement ===")
+
+    # Simulate consistently correct answers at the current level.
+    for _ in range(env.config.min_evaluations):
+        # Get a problem from current level
+        item = await env.get_next_item()
+        _, _, generator_id = item
+
+        # Record positive performance
+        env.curriculum.record_performance(generator_id, True)
+
+    # Try to advance difficulty
+    did_advance = env.curriculum.advance_difficulty()
+    new_level = env.curriculum.get_current_level()
+
+    logger.info(f"Curriculum advancement test:")
+    logger.info(f" - Starting level: {initial_level}")
+    logger.info(f" - Recorded {env.config.min_evaluations} correct answers")
+    logger.info(f" - Did advance: {did_advance}")
+    logger.info(f" - New level: {new_level}")
+
+    logger.info("Test server completed successfully")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/environments/infinimath/infinimath_server.py b/environments/infinimath/infinimath_server.py
new file mode 100644
index 00000000..08cd0efa
--- /dev/null
+++ b/environments/infinimath/infinimath_server.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+"""Test driver server for the InfiniteMath environment.
+
+NOTE(review): this file is byte-identical to infinimath_local_server.py —
+consider deduplicating (e.g. one file importing the other's main()).
+"""
+import asyncio
+import logging
+import os
+import argparse
+
+from dotenv import load_dotenv
+# NOTE(review): OpenAI appears unused in this module — confirm before removing.
+from openai import OpenAI
+
+from environments.infinimath.infinimath_env import (
+    InfiniteMathEnv,
+    InfiniteMathEnvConfig,
+)
+from atroposlib.envs.base import OpenaiConfig
+from atroposlib.utils.config_handler import ConfigHandler
+
+# Load environment variables (e.g. OPENAI_API_KEY) from a .env file.
+load_dotenv()
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(description="InfiniteMath environment server")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="infinimath",
+ help="Configuration file name (without .yaml extension or path for configs/envs/ directory, or full path)",
+ )
+ return parser.parse_args()
+
+
+async def main():
+    """Drive an end-to-end smoke test of the InfiniteMath environment.
+
+    Mirrors infinimath_local_server.py: loads configuration, builds the
+    environment and inference servers, then runs rollout + scoring,
+    evaluation, and a simulated curriculum advancement.
+    """
+    logger.info("Starting InfiniteMath environment server")
+
+    # Parse command line arguments
+    args = parse_arguments()
+
+    # Initialize config handler and load configuration
+    config_handler = ConfigHandler()
+
+    # Resolve --config: absolute paths, explicit paths, or .yaml names are
+    # used as-is; a bare name is looked up under configs/envs/.
+    if os.path.isabs(args.config) or "/" in args.config or args.config.endswith(".yaml"):
+        config_path = args.config
+    else:
+        # short form that defaults to the envs directory
+        config_path = os.path.join(
+            config_handler.config_dir, f"envs/{args.config}.yaml"
+        )
+
+    logger.info(f"Loading configuration from: {config_path}")
+
+    try:
+        with open(config_path, "r") as f:
+            import yaml
+            raw_config = yaml.safe_load(f)
+        logger.info(f"Loaded configuration successfully")
+    except Exception as e:
+        # On any load/parse failure, fall back to the shared config handler.
+        logger.error(f"Error loading config directly: {e}")
+        logger.info("Falling back to default config handler")
+        raw_config = config_handler.load_config(args)
+
+    # Configure the InfiniteMath environment, preferring YAML values.
+    config = InfiniteMathEnvConfig(
+        # Base environment parameters
+        tokenizer_name=raw_config.get("tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"),
+        group_size=raw_config.get("group_size", 1),
+        use_wandb=raw_config.get("use_wandb", False),
+        max_num_workers=raw_config.get("max_num_workers", 1),
+        rollout_server_url=raw_config.get("rollout_server_url", "http://localhost:8000"),
+        total_steps=raw_config.get("total_steps", 1),
+        batch_size=raw_config.get("batch_size", 1),
+        steps_per_eval=raw_config.get("steps_per_eval", 2),
+        max_token_length=raw_config.get("max_token_length", 4096),
+        wandb_name=raw_config.get("wandb_name", "infinite_math_test"),
+        ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False),
+
+        # InfiniteMath specific parameters (nested under the "infinimath" key)
+        starting_level=raw_config.get("infinimath", {}).get("starting_level", 1),
+        progress_threshold=raw_config.get("infinimath", {}).get("progress_threshold", 0.7),
+        min_evaluations=raw_config.get("infinimath", {}).get("min_evaluations", 3),
+        correct_reward=raw_config.get("infinimath", {}).get("correct_reward", 1.0),
+        incorrect_reward=raw_config.get("infinimath", {}).get("incorrect_reward", -0.5),
+        apply_length_penalty=raw_config.get("infinimath", {}).get("apply_length_penalty", True),
+        length_threshold_ratio=raw_config.get("infinimath", {}).get("length_threshold_ratio", 0.6),
+        temperature=raw_config.get("infinimath", {}).get("temperature", 0.7),
+        top_p=raw_config.get("infinimath", {}).get("top_p", 0.9),
+        reward_functions=raw_config.get("infinimath", {}).get("reward_functions", ["accuracy", "format", "boxed"]),
+        accuracy_reward_weight=raw_config.get("infinimath", {}).get("accuracy_reward_weight", 1.0),
+        format_reward_weight=raw_config.get("infinimath", {}).get("format_reward_weight", 0.2),
+        boxed_reward_weight=raw_config.get("infinimath", {}).get("boxed_reward_weight", 0.3),
+    )
+
+    # Inference server configuration from the config file, or defaults.
+    server_configs = []
+
+    if "server_configs" in raw_config:
+        for server_config in raw_config["server_configs"]:
+            api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
+            # Resolve ${VAR}-style environment variable references.
+            if isinstance(api_key, str) and api_key.startswith("${") and api_key.endswith("}"):
+                env_var = api_key[2:-1]
+                api_key = os.environ.get(env_var, "")
+
+            server_configs.append(
+                OpenaiConfig(
+                    model_name=server_config.get("model_name", "gpt-4.1-nano"),
+                    base_url=server_config.get("base_url", None),
+                    api_key=api_key,
+                    num_requests_for_eval=server_config.get("num_requests_for_eval", 70),
+                )
+            )
+    else:
+        # Default configuration if not specified in config file
+        server_configs.append(
+            OpenaiConfig(
+                model_name="gpt-4.1-nano",
+                base_url=None,
+                api_key=os.environ.get("OPENAI_API_KEY"),
+                num_requests_for_eval=70,
+            )
+        )
+
+    # Create the environment
+    env = InfiniteMathEnv(
+        config=config,
+        server_configs=server_configs,
+        slurm=False,
+    )
+
+    # Setup the environment
+    await env.setup()
+    logger.info("Environment setup complete")
+
+    # Log the number of evaluation problems
+    total_problems = sum(len(probs) for probs in env.eval_problems.values())
+    logger.info(
+        f"Using {total_problems} evaluation problems across {len(env.eval_problems)} difficulty levels"
+    )
+
+    # Fetch one problem and run a rollout + scoring pass.
+    item = await env.get_next_item()
+    problem_prompt, solution, generator_id = item
+
+    logger.info(f"Problem: {dict(problem_prompt[0])['content']}")
+    logger.info(f"Solution: {solution}")
+
+    # Collect trajectories
+    logger.info("Collecting trajectories...")
+    trajectories_data, backlog = await env.collect_trajectories(item)
+
+    # Score the collected trajectories
+    logger.info("Scoring trajectories...")
+    scored_data = await env.score(trajectories_data)
+
+    # NOTE(review): blocking input() stalls the event loop — acceptable only
+    # in this interactive test driver.
+    input("Press Enter to continue...")
+    # Print scores
+    logger.info(f"Scores: {scored_data['scores']}")
+
+    # Log the correct/incorrect counts
+    correct_count = sum(1 for score in scored_data["scores"] if score > 0)
+    logger.info(f"Correct answers: {correct_count}/{len(scored_data['scores'])}")
+
+    # Test evaluation function specifically
+    logger.info("\n=== Testing Evaluation Function ===")
+
+    # Record the current level
+    initial_level = env.curriculum.get_current_level()
+    logger.info(f"Current level before evaluation: {initial_level}")
+    logger.info(f"Level description: {env.curriculum.get_level_description()}")
+    logger.info(f"Progress threshold: {env.curriculum.progress_threshold}")
+    logger.info(f"Min evaluations needed: {env.curriculum.min_evaluations}")
+
+    # Run the evaluate method
+    eval_metrics = await env.evaluate()
+
+    # Display evaluation results
+    logger.info("Evaluation metrics:")
+    for metric_name, metric_value in eval_metrics:
+        logger.info(f" - {metric_name}: {metric_value}")
+
+    # Report whether evaluation advanced the curriculum level.
+    new_level = env.curriculum.get_current_level()
+    if new_level > initial_level:
+        logger.info(f"Successfully advanced to level {new_level}!")
+        logger.info(f"New level description: {env.curriculum.get_level_description()}")
+    else:
+        # Show current progress toward advancement
+        current_level = env.curriculum.get_current_level()
+        if current_level in env.curriculum.performance_history:
+            history = env.curriculum.performance_history[current_level]
+            if len(history) >= env.curriculum.min_evaluations:
+                recent_history = history[-env.curriculum.min_evaluations :]
+                success_rate = sum(recent_history) / len(recent_history)
+                logger.info(
+                    f"Current success rate: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)"
+                )
+            else:
+                logger.info(
+                    f"Need more evaluations: {len(history)}/{env.curriculum.min_evaluations}"
+                )
+
+    # Show all levels and their performance history
+    logger.info("\nPerformance history by level:")
+    for level in sorted(env.curriculum.performance_history.keys()):
+        history = env.curriculum.performance_history[level]
+        if history:
+            success_rate = sum(history) / len(history)
+            logger.info(
+                f" Level {level}: {success_rate:.2f} ({sum(history)}/{len(history)} correct)"
+            )
+        else:
+            logger.info(f" Level {level}: No data")
+
+    # Test curriculum advancement with simulated performance history
+    logger.info("\n=== Testing Curriculum Advancement ===")
+
+    # Simulate consistently correct answers at the current level.
+    for _ in range(env.config.min_evaluations):
+        # Get a problem from current level
+        item = await env.get_next_item()
+        _, _, generator_id = item
+
+        # Record positive performance
+        env.curriculum.record_performance(generator_id, True)
+
+    # Try to advance difficulty
+    did_advance = env.curriculum.advance_difficulty()
+    new_level = env.curriculum.get_current_level()
+
+    logger.info(f"Curriculum advancement test:")
+    logger.info(f" - Starting level: {initial_level}")
+    logger.info(f" - Recorded {env.config.min_evaluations} correct answers")
+    logger.info(f" - Did advance: {did_advance}")
+    logger.info(f" - New level: {new_level}")
+
+    logger.info("Test server completed successfully")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/environments/infinimath/test_curriculum.py b/environments/infinimath/test_curriculum.py
new file mode 100644
index 00000000..d6df31ee
--- /dev/null
+++ b/environments/infinimath/test_curriculum.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python3
+"""
+Test script for the Infinite Math curriculum manager.
+This script tests the core functionality of both the MathCurriculum and InfiniteMath classes.
+"""
+
+import os
+import random
+import sys
+from typing import Any, Dict, List, Optional, Tuple
+
+# Use relative imports
+from .curriculum import MathCurriculum
+from .infinimath import InfiniteMath
+
+
+def test_curriculum_initialization():
+ """Test that the curriculum initializes correctly with different levels."""
+ print("\n=== Testing Curriculum Initialization ===")
+
+ # Test default initialization
+ curriculum = MathCurriculum()
+ assert curriculum.get_current_level() == 1, "Default level should be 1"
+ print("✓ Default initialization successful")
+
+ # Test initialization with specific level
+ curriculum = MathCurriculum(starting_level=3)
+ assert curriculum.get_current_level() == 3, "Starting level should be 3"
+ print("✓ Custom level initialization successful")
+
+ # Test initialization with invalid level
+ try:
+ curriculum = MathCurriculum(starting_level=10)
+ print("✗ Invalid level initialization should fail")
+ except ValueError:
+ print("✓ Invalid level initialization correctly raises ValueError")
+
+
+def test_problem_generation():
+ """Test that problems are generated correctly at different difficulty levels."""
+ print("\n=== Testing Problem Generation ===")
+
+ # Test problem generation at different levels
+ for level in range(1, 8):
+ curriculum = MathCurriculum(starting_level=level)
+ problem, solution, generator_id = curriculum.get_problem()
+
+ # Verify we got a problem and solution
+ assert (
+ isinstance(problem, str) and len(problem) > 0
+ ), f"Problem at level {level} should be a non-empty string"
+ assert solution is not None, f"Solution at level {level} should not be None"
+
+ # Verify the generator ID belongs to the correct level
+ assert (
+ generator_id in curriculum.DIFFICULTY_LEVELS[level]
+ ), f"Generator ID {generator_id} should be in level {level}"
+
+ print(
+ f"✓ Level {level} problem generated: {problem[:50]}{'...' if len(problem) > 50 else ''}"
+ )
+ print(f" Solution: {solution}")
+ print(f" Generator ID: {generator_id}")
+
+
+def test_performance_tracking():
+ """Test performance tracking and level advancement."""
+ print("\n=== Testing Performance Tracking and Advancement ===")
+
+ # Create curriculum with test parameters
+ curriculum = MathCurriculum(
+ starting_level=1, progress_threshold=0.7, min_evaluations=3
+ )
+
+ # Record some correct answers (not enough to advance)
+ generator_id = curriculum.DIFFICULTY_LEVELS[1][0] # Get a generator from level 1
+ curriculum.record_performance(generator_id, True)
+ curriculum.record_performance(generator_id, True)
+
+ # Check if we advance (should not)
+ did_advance = curriculum.advance_difficulty()
+ assert not did_advance, "Should not advance with only 2 evaluations"
+ assert curriculum.get_current_level() == 1, "Level should still be 1"
+ print("✓ Correctly did not advance with insufficient evaluations")
+
+ # Add one more correct answer (now should advance)
+ curriculum.record_performance(generator_id, True)
+ did_advance = curriculum.advance_difficulty()
+ assert did_advance, "Should advance with 3 correct evaluations (100% success rate)"
+ assert curriculum.get_current_level() == 2, "Level should be 2 after advancement"
+ print("✓ Correctly advanced to level 2 after sufficient success")
+
+ # Test with too low success rate
+ curriculum = MathCurriculum(
+ starting_level=1, progress_threshold=0.7, min_evaluations=3
+ )
+ generator_id = curriculum.DIFFICULTY_LEVELS[1][0]
+ curriculum.record_performance(generator_id, True) # 1 correct
+ curriculum.record_performance(generator_id, False) # 1 incorrect
+ curriculum.record_performance(generator_id, False) # 1 incorrect
+
+ did_advance = curriculum.advance_difficulty()
+ assert (
+ not did_advance
+ ), "Should not advance with 33% success rate when threshold is 70%"
+ print("✓ Correctly did not advance with insufficient success rate")
+
+ # Test advancement at the highest level
+ curriculum = MathCurriculum(
+ starting_level=7, progress_threshold=0.7, min_evaluations=3
+ )
+ generator_id = curriculum.DIFFICULTY_LEVELS[7][0]
+ curriculum.record_performance(generator_id, True)
+ curriculum.record_performance(generator_id, True)
+ curriculum.record_performance(generator_id, True)
+
+ did_advance = curriculum.advance_difficulty()
+ assert not did_advance, "Should not advance beyond the highest level"
+ print("✓ Correctly did not advance beyond the highest level")
+
+
+def test_level_descriptions():
+ """Test that level descriptions are correct."""
+ print("\n=== Testing Level Descriptions ===")
+
+ curriculum = MathCurriculum()
+
+ for level in range(1, 8):
+ description = curriculum.get_level_description(level)
+ assert (
+ isinstance(description, str) and len(description) > 0
+ ), f"Description for level {level} should be a non-empty string"
+ print(f"✓ Level {level}: {description}")
+
+
+def test_infinite_math_environment():
+ """Test the InfiniteMath environment functionality."""
+ print("\n=== Testing InfiniteMath Environment ===")
+
+ # Initialize the environment
+ env = InfiniteMath(starting_level=1, progress_threshold=0.7, min_evaluations=3)
+
+ # Get the initial state
+ state = env.get_state()
+ assert "problem" in state, "State should include a problem"
+ assert "current_level" in state, "State should include current level"
+ print(
+ f"✓ Initial state: Level {state['current_level']}, Problem: {state['problem'][:50]}{'...' if len(state['problem']) > 50 else ''}"
+ )
+
+ # Test answering a problem incorrectly
+ result = env.submit_answer("wrong answer")
+ assert "is_correct" in result, "Result should indicate correctness"
+ assert not result["is_correct"], "Result should be incorrect"
+ print("✓ Incorrect answer handled correctly")
+
+ # Test answering a problem correctly
+ # Note: We can't predict the correct answer, so we'll get it from the environment
+ correct_solution = env.current_solution
+ result = env.submit_answer(correct_solution)
+ assert result["is_correct"], "Result should be correct"
+ print("✓ Correct answer handled successfully")
+
+ # Test resetting the environment
+ env.reset(level=3)
+ state = env.get_state()
+ assert state["current_level"] == 3, "After reset, level should be 3"
+ print(f"✓ Reset to level 3 successful")
+
+ # Test getting difficulty stats
+ stats = env.get_difficulty_stats()
+ assert len(stats) == 7, "Should have stats for all 7 difficulty levels"
+ print("✓ Difficulty statistics retrieved successfully")
+
+
+def simulate_learning():
+ """Simulate a learning agent improving performance over time."""
+ print("\n=== Simulating Learning Process ===")
+
+ env = InfiniteMath(starting_level=1, progress_threshold=0.7, min_evaluations=5)
+ episodes = 30
+
+ print(f"Starting simulation with {episodes} episodes")
+ print(f"Initial level: {env.get_state()['current_level']}")
+
+ for i in range(episodes):
+ state = env.get_state()
+ current_level = state["current_level"]
+
+ # Simulated agent gradually improves - higher chance of correct answer over time
+ # and higher chance at lower levels
+ success_probability = min(0.5 + (i / episodes) + (1 / current_level), 0.95)
+
+ # Simulate an answer
+ is_correct = random.random() < success_probability
+
+ # If we decide to be correct, use the actual solution
+ if is_correct:
+ answer = env.current_solution
+ else:
+ # Otherwise provide a wrong answer
+ answer = "wrong answer"
+
+ # Submit the answer
+ result = env.submit_answer(answer)
+
+ # Check for level advancement
+ if result.get("did_advance_level", False):
+ new_level = result["new_level"]
+ print(
+ f"Episode {i+1}: Advanced to level {new_level}! (success probability: {success_probability:.2f})"
+ )
+ elif i % 5 == 0: # Print status occasionally
+ print(
+ f"Episode {i+1}: Still at level {current_level} (success probability: {success_probability:.2f})"
+ )
+
+ # Print final stats
+ final_state = env.get_state()
+ print(f"\nFinal level: {final_state['current_level']}")
+ print(
+ f"Overall accuracy: {final_state['correct_problems'] / final_state['total_problems']:.2%}"
+ )
+
+ # Print level-by-level stats
+ stats = env.get_difficulty_stats()
+ print("\nPerformance by level:")
+ for level, level_stats in stats.items():
+ if level_stats["problems_attempted"] > 0:
+ success_rate = level_stats["success_rate"]
+ if success_rate is not None:
+ print(
+ f"Level {level}: {success_rate:.2%} success rate ({level_stats['problems_attempted']} problems)"
+ )
+ else:
+ print(
+ f"Level {level}: Not enough data ({level_stats['problems_attempted']} problems)"
+ )
+
+
+def main():
+ """Run all tests."""
+ print("=== Starting Curriculum Manager Tests ===")
+
+ try:
+ test_curriculum_initialization()
+ test_problem_generation()
+ test_performance_tracking()
+ test_level_descriptions()
+ test_infinite_math_environment()
+ simulate_learning()
+
+ print("\n=== All tests completed successfully! ===")
+ except AssertionError as e:
+ print(f"\n❌ Test failed: {e}")
+ return 1
+ except Exception as e:
+ print(f"\n❌ Unexpected error: {e}")
+ return 1
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())