diff --git a/environments/infinimath/README.md b/environments/infinimath/README.md
new file mode 100644
index 00000000..102ac0fc
--- /dev/null
+++ b/environments/infinimath/README.md
@@ -0,0 +1,105 @@
+# InfiniteMath Environment
+
+## Environment Overview
+
+This environment provides procedurally generated math problems with curriculum-based advancement. It allows an agent to solve increasingly difficult math problems, with the difficulty level adapting based on performance.
+
+**Demonstrates:**
+- Procedural content generation (math problems).
+- Curriculum learning: The environment automatically adjusts the difficulty (levels 1-7) based on the LLM's success rate.
+- Step-by-step reasoning evaluation: Rewards correctness, the presence of reasoning steps (within `<think>` tags), and the final answer format (`\boxed{}`).
+- Handling LaTeX formatting for problems and answers.
+
+**Training Goal:**
+- To train LLMs to solve mathematical problems accurately.
+- To encourage explicit step-by-step reasoning before providing an answer.
+- To improve the LLM's ability to follow specific formatting instructions (using `<think>` tags and `\boxed{}`).
+- To teach the model to handle progressively more complex problems through the curriculum.
+
+## Features
+
+- Progressive difficulty scaling across 7 levels of math problems
+- Built-in curriculum system that adapts to agent performance
+- Automatic problem generation with solutions
+- Reward functions for accuracy, formatting, and boxed answer checking
+
+## Usage
+
+### Running with Default Configuration
+
+To run the InfiniteMath environment with the default configuration:
+
+```bash
+python environments/infinimath/infinimath_local_server.py
+```
+
+This will use the default configuration from `configs/envs/infinimath.yaml`.
+
+### Custom Configuration
+
+You can specify a custom configuration file:
+
+```bash
+python environments/infinimath/infinimath_local_server.py --config my_custom_config
+```
+
+The `--config` parameter can be:
+
+1. A name (without `.yaml` extension) which will be looked up in `configs/envs/`
+2. A relative or absolute path to a YAML file
+
+For example:
+```bash
+# Using a config in configs/envs/
+python environments/infinimath/infinimath_local_server.py --config infinimath_hard
+
+# Using a config with full path
+python environments/infinimath/infinimath_local_server.py --config /path/to/my/config.yaml
+```
+
+## Configuration Structure
+
+The configuration file follows this structure:
+
+```yaml
+# Base environment parameters
+tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
+group_size: 1
+use_wandb: false
+# ... other base parameters
+
+# InfiniteMath specific configuration
+infinimath:
+  # Curriculum parameters
+  starting_level: 1
+  progress_threshold: 0.7
+  # ...
other InfiniteMath specific parameters + +# Server configuration +server_configs: + - model_name: "gpt-4.1-nano" + api_key: ${OPENAI_API_KEY} + num_requests_for_eval: 70 +``` + +### Important Configuration Parameters + +#### Base Parameters + +- `tokenizer_name`: The tokenizer to use for encoding/decoding text +- `group_size`: Number of responses to collect per prompt +- `max_token_length`: Maximum token length for generation +- `steps_per_eval`: How often to run evaluations + +#### InfiniteMath Specific Parameters + +- `starting_level`: Initial difficulty level (1-7) +- `progress_threshold`: Success rate needed to advance levels +- `min_evaluations`: Minimum number of evaluations before level advancement +- `reward_functions`: List of reward functions to apply + +#### Server Configuration + +- `model_name`: LLM model to use +- `api_key`: API key for the model (can use environment variables with ${VAR_NAME} syntax) +- `num_requests_for_eval`: Number of evaluation requests to allocate diff --git a/environments/infinimath/__init__.py b/environments/infinimath/__init__.py new file mode 100644 index 00000000..b14bda0f --- /dev/null +++ b/environments/infinimath/__init__.py @@ -0,0 +1,8 @@ +""" +Infinite Math - A reinforcement learning environment for procedurally generated math problems. +""" + +from .curriculum import MathCurriculum +from .infinimath import InfiniteMath + +__all__ = ["MathCurriculum", "InfiniteMath"] diff --git a/environments/infinimath/curriculum.py b/environments/infinimath/curriculum.py new file mode 100644 index 00000000..7e9cd4a0 --- /dev/null +++ b/environments/infinimath/curriculum.py @@ -0,0 +1,378 @@ +import random +from typing import Any, Callable, Dict, List, Optional, Tuple + +import mathgenerator + + +class MathCurriculum: + """ + A curriculum manager for the mathgenerator library. 
+ + This class organizes math problems by difficulty and provides methods + to generate problems of appropriate difficulty based on the learner's + performance. + """ + + # Define difficulty levels and map generator IDs to each level + DIFFICULTY_LEVELS = { + # Level 1: Basic arithmetic operations + 1: [ + 0, + 1, + 2, + 3, + 8, + 31, + 71, + 80, + 90, + ], # Addition, Subtraction, Multiplication, Division, Square, Factorial, Absolute difference, Percentage, IsPrime + # Level 2: Basic operations with fractions and pre-algebra + 2: [ + 6, + 11, + 13, + 16, + 28, + 44, + 47, + 53, + 97, + 118, + 119, + 124, + ], # Square Root, Basic Algebra, Fraction to Decimal, Fraction Division, Fraction Multiplication, Compare Fractions, Cube Root, Exponentiation, Power of Powers, Percentage difference/error, Is Composite + # Level 3: Basic geometry and more algebra + 3: [ + 18, + 19, + 22, + 24, + 25, + 49, + 58, + 75, + 96, + 104, + 108, + 112, + 115, + ], # Area of Triangle, Triangle exists check, Third Angle of Triangle, Distance between 2 points, Pythagorean Theorem, Fourth Angle of Quadrilateral, Sum of Angles of Polygon, Area of a Sector, Perimeter of Polygons, Circumference, Arc length, Area of Circle + # Level 4: More advanced algebra and basic statistics + 4: [ + 9, + 10, + 20, + 21, + 23, + 26, + 40, + 41, + 45, + 50, + 76, + 78, + 105, + ], # LCM, GCD, Midpoint, Factoring Quadratic, System of Equations, Linear Equations, Common Factors, Intersection of Two Lines, Simple Interest, Quadratic Equation, Mean and Median, Compound Interest, Combine Like terms + # Level 5: Vectors, matrices, and solid geometry + 5: [ + 17, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 43, + 46, + 60, + 61, + 70, + 72, + 77, + 95, + 113, + 117, + 122, + 123, + ], # Matrix Multiplication, Surface Areas, Volumes, Vector operations, etc. 
+ # Level 6: Advanced topics (calculus, statistics, computer science) + 6: [ + 4, + 5, + 7, + 12, + 14, + 15, + 27, + 30, + 42, + 48, + 52, + 54, + 55, + 56, + 59, + 62, + 64, + 73, + 79, + 84, + 88, + 89, + 91, + 103, + 107, + 110, + ], # Binary operations, Calculus, Combinatorics, Probability, etc. + # Level 7: Most complex topics + 7: [ + 65, + 66, + 67, + 68, + 69, + 74, + 85, + 92, + 93, + 94, + 98, + 99, + 100, + 101, + 106, + 109, + 111, + 121, + ], # Complex numbers, Advanced operations, etc. + } + + def __init__( + self, + starting_level: int = 1, + progress_threshold: float = 0.8, + min_evaluations: int = 5, + ): + """ + Initialize the curriculum manager. + + Args: + starting_level: The difficulty level to start with (default: 1) + progress_threshold: The success rate required to advance to the next level (default: 0.8) + min_evaluations: Minimum number of evaluations needed before considering level advancement (default: 5) + """ + self.current_level = starting_level + self.progress_threshold = progress_threshold + self.min_evaluations = min_evaluations + + # Performance tracking + self.performance_history = { + level: [] for level in self.DIFFICULTY_LEVELS.keys() + } + + # Ensure starting level is valid + if starting_level not in self.DIFFICULTY_LEVELS: + raise ValueError( + f"Invalid starting level: {starting_level}. Available levels: {list(self.DIFFICULTY_LEVELS.keys())}" + ) + + def get_problem(self) -> Tuple[str, str, int]: + """ + Generate a math problem at the current difficulty level. 
+ + Returns: + Tuple containing (problem_text, solution_text, generator_id) + """ + # Get the available generator IDs for the current level + available_generators = self.DIFFICULTY_LEVELS[self.current_level] + + # Try generators until one works + max_attempts = 5 # Limit the number of attempts to avoid infinite loops + attempts = 0 + + while attempts < max_attempts: + # Get a random generator ID from the current level + generator_id = random.choice(available_generators) + + try: + # Generate the problem + problem, solution = mathgenerator.genById(generator_id) + return problem, solution, generator_id + except Exception as e: + # Log the error and try another generator + print(f"Error with generator {generator_id}: {str(e)}") + attempts += 1 + + # Remove the problematic generator from the available list for this session + if generator_id in available_generators: + available_generators.remove(generator_id) + + # If we've exhausted all generators in this level, move to an adjacent level + if not available_generators: + fallback_level = max( + 1, min(7, self.current_level + random.choice([-1, 1])) + ) + available_generators = self.DIFFICULTY_LEVELS[fallback_level].copy() + + # If all attempts fail, return a simple addition problem as fallback + return "What is $2 + 2$?", "4", 0 + + def record_performance(self, generator_id: int, is_correct: bool) -> None: + """ + Record the performance on a specific problem. + + Args: + generator_id: The ID of the generator used + is_correct: Whether the answer was correct + """ + # Find which level this generator belongs to + level = None + for lvl, generator_ids in self.DIFFICULTY_LEVELS.items(): + if generator_id in generator_ids: + level = lvl + break + + if level is not None: + # Add the result to the performance history + self.performance_history[level].append(is_correct) + + def get_success_rate(self, level: int) -> Optional[float]: + """ + Calculate the success rate for a specific level. 
+ + Args: + level: The difficulty level + + Returns: + Success rate as a float between 0 and 1, or None if not enough data + """ + history = self.performance_history[level] + + if len(history) < self.min_evaluations: + return None + + # Calculate success rate from recent evaluations + recent_history = history[-self.min_evaluations :] + return sum(recent_history) / len(recent_history) + + def should_advance(self) -> bool: + """ + Determine if the learner should advance to the next level. + + Returns: + Boolean indicating whether to advance + """ + success_rate = self.get_success_rate(self.current_level) + + # If not enough data or below threshold, don't advance + if success_rate is None or success_rate < self.progress_threshold: + return False + + # Check if there's a next level to advance to + return self.current_level < max(self.DIFFICULTY_LEVELS.keys()) + + def advance_difficulty(self) -> bool: + """ + Advance to the next difficulty level if appropriate. + + Returns: + Boolean indicating whether advancement occurred + """ + if self.should_advance(): + self.current_level += 1 + return True + return False + + def get_current_level(self) -> int: + """ + Get the current difficulty level. + + Returns: + Current level as an integer + """ + return self.current_level + + def get_num_levels(self) -> int: + """ + Get the total number of difficulty levels. + + Returns: + Total number of levels + """ + return len(self.DIFFICULTY_LEVELS) + + def get_level_description(self, level: Optional[int] = None) -> str: + """ + Get a description of the specified difficulty level. 
+ + Args: + level: The level to describe (default: current level) + + Returns: + String description of the level + """ + if level is None: + level = self.current_level + + level_descriptions = { + 1: "Basic arithmetic operations (addition, subtraction, multiplication, division)", + 2: "Basic operations with fractions and pre-algebra", + 3: "Basic geometry and more algebra", + 4: "More advanced algebra and basic statistics", + 5: "Vectors, matrices, and solid geometry", + 6: "Advanced topics (calculus, statistics, computer science)", + 7: "Most complex topics (complex numbers, advanced operations)", + } + + return level_descriptions.get( + level, f"Custom level with IDs: {self.DIFFICULTY_LEVELS.get(level, [])}" + ) + + def reset(self, level: int = 1) -> None: + """ + Reset the curriculum to a specific level and clear performance history. + + Args: + level: The level to reset to (default: 1) + """ + if level not in self.DIFFICULTY_LEVELS: + raise ValueError( + f"Invalid level: {level}. Available levels: {list(self.DIFFICULTY_LEVELS.keys())}" + ) + + self.current_level = level + self.performance_history = {lvl: [] for lvl in self.DIFFICULTY_LEVELS.keys()} + + def get_generator_info(self) -> List[Dict[str, Any]]: + """ + Get information about all available generators. 
+ + Returns: + List of dictionaries containing generator information + """ + generators = [] + gen_list = mathgenerator.getGenList() + + for gen in gen_list: + # Find which level this generator belongs to + level = None + for lvl, generator_ids in self.DIFFICULTY_LEVELS.items(): + if gen[0] in generator_ids: + level = lvl + break + + generators.append( + { + "id": gen[0], + "name": gen[1], + "function": gen[3], + "subject": gen[4], + "params": gen[5], + "difficulty_level": level, + } + ) + + return generators diff --git a/environments/infinimath/infinimath.py b/environments/infinimath/infinimath.py new file mode 100644 index 00000000..3135c137 --- /dev/null +++ b/environments/infinimath/infinimath.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +""" +Infinite Math - A reinforcement learning environment for math practice +using the mathgenerator library with curriculum-based advancement. +""" + +import random +from typing import Any, Dict, List, Optional, Tuple + +from .curriculum import MathCurriculum + + +class InfiniteMath: + """ + A reinforcement learning environment for practicing math skills with + curriculum-based advancement. + + This class uses the MathCurriculum to generate appropriate math problems + and track performance, advancing difficulty as the learner improves. + """ + + def __init__( + self, + starting_level: int = 1, + progress_threshold: float = 0.8, + min_evaluations: int = 5, + max_attempts_per_problem: int = 3, + ): + """ + Initialize the InfiniteMath environment. 
+ + Args: + starting_level: Initial difficulty level (default: 1) + progress_threshold: Success rate needed to advance levels (default: 0.8) + min_evaluations: Minimum evaluations before considering advancement (default: 5) + max_attempts_per_problem: Maximum attempts allowed per problem (default: 3) + """ + self.curriculum = MathCurriculum( + starting_level=starting_level, + progress_threshold=progress_threshold, + min_evaluations=min_evaluations, + ) + + self.max_attempts = max_attempts_per_problem + self.current_problem = None + self.current_solution = None + self.current_generator_id = None + self.attempts_remaining = 0 + self.total_problems = 0 + self.correct_problems = 0 + + # Generate the first problem + self._generate_problem() + + def _generate_problem(self) -> None: + """Generate a new problem from the curriculum.""" + self.current_problem, self.current_solution, self.current_generator_id = ( + self.curriculum.get_problem() + ) + self.attempts_remaining = self.max_attempts + + def get_state(self) -> Dict[str, Any]: + """ + Get the current state of the environment. + + Returns: + Dictionary with current state information + """ + return { + "problem": self.current_problem, + "attempts_remaining": self.attempts_remaining, + "current_level": self.curriculum.get_current_level(), + "total_levels": self.curriculum.get_num_levels(), + "level_description": self.curriculum.get_level_description(), + "total_problems": self.total_problems, + "correct_problems": self.correct_problems, + "accuracy": self.correct_problems / max(1, self.total_problems), + } + + def submit_answer(self, answer: str) -> Dict[str, Any]: + """ + Submit an answer to the current problem. + + Args: + answer: The learner's answer to the current problem + + Returns: + Dictionary with the result of the submission and updated state + """ + if self.current_problem is None: + return {"error": "No active problem. 
Call reset() to start a new session."} + + # Clean up the answer for comparison (strip whitespace, convert to lowercase) + cleaned_answer = str(answer).strip().lower() + cleaned_solution = str(self.current_solution).strip().lower() + + # Check if the answer is correct + is_correct = cleaned_answer == cleaned_solution + + # Update attempts + self.attempts_remaining -= 1 + + result = { + "is_correct": is_correct, + "correct_answer": ( + self.current_solution + if self.attempts_remaining == 0 or is_correct + else None + ), + "attempts_remaining": self.attempts_remaining, + } + + # If correct or out of attempts, record performance and generate a new problem + if is_correct or self.attempts_remaining == 0: + self.total_problems += 1 + if is_correct: + self.correct_problems += 1 + + # Record performance in the curriculum + self.curriculum.record_performance(self.current_generator_id, is_correct) + + # Check if we should advance to the next level + did_advance = self.curriculum.advance_difficulty() + result["did_advance_level"] = did_advance + + if did_advance: + result["new_level"] = self.curriculum.get_current_level() + result["level_description"] = self.curriculum.get_level_description() + + # Generate a new problem + self._generate_problem() + result["new_problem"] = self.current_problem + + return result + + def reset(self, level: Optional[int] = None) -> Dict[str, Any]: + """ + Reset the environment, optionally to a specific difficulty level. + + Args: + level: The difficulty level to reset to (default: current level) + + Returns: + Dictionary with the new state + """ + if level is not None: + self.curriculum.reset(level) + + self.total_problems = 0 + self.correct_problems = 0 + + # Generate a new problem + self._generate_problem() + + return self.get_state() + + def get_difficulty_stats(self) -> Dict[int, Dict[str, Any]]: + """ + Get performance statistics for each difficulty level. 
+ + Returns: + Dictionary with statistics for each level + """ + stats = {} + + for level in self.curriculum.DIFFICULTY_LEVELS.keys(): + success_rate = self.curriculum.get_success_rate(level) + history = self.curriculum.performance_history[level] + + stats[level] = { + "description": self.curriculum.get_level_description(level), + "problems_attempted": len(history), + "success_rate": ( + success_rate if success_rate is not None else float("nan") + ), + "is_current_level": level == self.curriculum.get_current_level(), + } + + return stats + + +def main(): + """Example usage of the InfiniteMath environment.""" + # Create the environment + env = InfiniteMath(starting_level=1) + + print("Welcome to InfiniteMath!") + print( + f"Starting at level {env.get_state()['current_level']}: {env.get_state()['level_description']}" + ) + + playing = True + while playing: + # Display current problem + state = env.get_state() + print("\n" + "=" * 50) + print(f"Level {state['current_level']}/{state['total_levels']}") + print(f"Problem: {state['problem']}") + print(f"Attempts remaining: {state['attempts_remaining']}") + + # Get user input + answer = input("Your answer (or 'q' to quit, 'r' to reset): ") + + if answer.lower() == "q": + playing = False + continue + elif answer.lower() == "r": + level = input("Reset to level (1-7, or press Enter for current level): ") + if level and level.isdigit() and 1 <= int(level) <= 7: + env.reset(int(level)) + else: + env.reset() + continue + + # Submit the answer + result = env.submit_answer(answer) + + # Display the result + if result.get("is_correct", False): + print("Correct!") + else: + print("Incorrect.") + + if result.get("correct_answer") is not None: + print(f"The correct answer is: {result['correct_answer']}") + + # Check if we advanced to a new level + if result.get("did_advance_level", False): + print(f"\nCongratulations! 
You've advanced to level {result['new_level']}!") + print(f"New level: {result['level_description']}") + + # If we have a new problem, show it + if result.get("new_problem") is not None: + print("\nNext problem is ready.") + + # Show final statistics + print("\nFinal Statistics:") + print(f"Total problems attempted: {env.total_problems}") + print(f"Correct answers: {env.correct_problems}") + if env.total_problems > 0: + print(f"Overall accuracy: {env.correct_problems / env.total_problems:.2%}") + + level_stats = env.get_difficulty_stats() + print("\nPerformance by level:") + for level, stats in level_stats.items(): + if stats["problems_attempted"] > 0: + print( + f"Level {level}: {stats['success_rate']:.2%} success rate ({stats['problems_attempted']} problems)" + ) + + +if __name__ == "__main__": + main() diff --git a/environments/infinimath/infinimath_env.py b/environments/infinimath/infinimath_env.py new file mode 100644 index 00000000..e946e7e7 --- /dev/null +++ b/environments/infinimath/infinimath_env.py @@ -0,0 +1,745 @@ +import asyncio +import json +import logging +import random +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +from trajectoryhandler.envs.base import ( + BaseEnv, + BaseEnvConfig, + OpenaiConfig, + ScoredDataGroup, +) +from trajectoryhandler.envs.reward_fns import registry +from trajectoryhandler.envs.reward_fns.combined_reward import CombinedReward +from trajectoryhandler.utils.tokenize_for_trainer import tokenize_for_trainer + +from .curriculum import MathCurriculum + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +system_prompt = """You are an expert mathematician. You need to solve the given math problem step-by-step, showing your reasoning clearly. You should enclose your thoughts and internal monologue inside tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}. 
+ +The problems will be given in a LaTeX format, so be sure to follow the LaTeX syntax when writing your answer (although no $ delimiters are necessary). + +Follow these steps: +1. Understand the problem carefully +2. Plan your approach +3. Execute the calculations step-by-step +4. Verify your solution +5. Express the final answer as \\boxed{your answer here} + +You may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. + +Your answer format should be: + +[Your detailed step-by-step reasoning process here] + + +\\boxed{your final answer here} + +Remember to format your final answer correctly as this is important for evaluation.""" + + +class InfiniteMathEnvConfig(BaseEnvConfig): + """Configuration for the InfiniteMath environment.""" + + # Curriculum parameters + starting_level: int = 1 + progress_threshold: float = 0.8 + min_evaluations: int = 5 + + # Environment parameters + max_attempts_per_problem: int = 3 + correct_reward: float = 1.0 + incorrect_reward: float = -1.0 + + # Length penalty parameters + apply_length_penalty: bool = True + length_threshold_ratio: float = ( + 0.5 # Percentage of max_token_length before penalties apply + ) + + # Completion parameters + temperature: float = 0.7 + top_p: float = 0.9 + + # Reward functions + reward_functions: List[Union[str, Dict[str, Any]]] = ["accuracy", "format", "boxed"] + accuracy_reward_weight: float = 1.0 # Weight for the accuracy reward + format_reward_weight: float = ( + 0.2 # Weight for the format reward relative to correctness + ) + boxed_reward_weight: float = ( + 0.3 # Weight for the boxed answer reward relative to correctness + ) + + +class InfiniteMathEnv(BaseEnv): + """Environment for procedurally generated math problems with curriculum advancement.""" + + def __init__( + self, + config: InfiniteMathEnvConfig, + server_configs: Union[List[OpenaiConfig], OpenaiConfig], + 
slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.config = config # Override with our specific config class + + # Initialize tracking metrics + self.percent_correct_buffer = [] + self.level_correct_buffer = { + i: [] for i in range(1, 8) + } # Track correctness for each level + self.eval_metrics = [] + + # Curriculum will be initialized in setup() + self.curriculum = None + + # Set the system prompt + self.system_prompt = system_prompt + + # Initialize reward function + self.reward_function = self._initialize_reward_function() + + def _initialize_reward_function(self): + """Initialize the combined reward function for scoring.""" + if hasattr(self.config, "reward_functions") and self.config.reward_functions: + # Configure parameters for specific reward functions + reward_configs = [] + + for reward_func in self.config.reward_functions: + if isinstance(reward_func, str): + # String name case - handle known rewards with custom params + if reward_func == "accuracy": + # Configure accuracy reward + accuracy_config = { + "type": "accuracy", + "weight": self.config.accuracy_reward_weight, + "params": { + "split_on_think_tag": True, # Only look at what's after + "tolerance": 1e-6, # Tolerance for number comparisons + }, + } + logger.info(f"Adding accuracy reward with config: {accuracy_config}") + reward_configs.append(accuracy_config) + elif reward_func == "format": + # Configure format reward with think tags and explicit weight + format_config = { + "type": "format", + "weight": self.config.format_reward_weight, + "params": { + "preferred_tags": ["think"], + }, + } + logger.info(f"Adding format reward with config: {format_config}") + reward_configs.append(format_config) + elif reward_func == "boxed": + # Configure boxed reward with proper parameters and explicit weight + boxed_config = { + "type": "boxed", + "weight": self.config.boxed_reward_weight, + "params": { + "require_outside_think": True, + }, + } + 
logger.info(f"Adding boxed reward with config: {boxed_config}") + reward_configs.append(boxed_config) + else: + # Pass through other reward functions as is + logger.info(f"Adding generic reward function: {reward_func}") + reward_configs.append(reward_func) + else: + # Dict case - pass through as is + logger.info(f"Adding reward config: {reward_func}") + reward_configs.append(reward_func) + + # Create the reward function(s) + if len(reward_configs) == 1: + logger.info(f"Creating single reward function: {reward_configs[0]}") + return registry.create(reward_configs[0]) + else: + logger.info(f"Creating combined reward function with {len(reward_configs)} rewards") + # Add explicit normalization to sum to 1.0 + return CombinedReward(rewards=reward_configs, normalization="none") + + async def setup(self): + """Initialize the environment and curriculum.""" + logger.info("Setting up InfiniteMathEnv") + + # Initialize curriculum + self.curriculum = MathCurriculum( + starting_level=self.config.starting_level, + progress_threshold=self.config.progress_threshold, + min_evaluations=self.config.min_evaluations, + ) + + # Generate some test problems for each level for evaluation + self.eval_problems = {} + for level in range(1, 8): + self.eval_problems[level] = [] + temp_curriculum = MathCurriculum(starting_level=level) + # Generate 10 test problems for each level + attempts = 0 + max_attempts_per_level = 20 # Try at most 20 problems to get 10 valid ones + + while ( + len(self.eval_problems[level]) < 10 + and attempts < max_attempts_per_level + ): + try: + problem, solution, generator_id = temp_curriculum.get_problem() + # Strip LaTeX delimiters + problem = self.strip_latex_delimiters(problem) + solution = self.strip_latex_delimiters(solution) + self.eval_problems[level].append((problem, solution, generator_id)) + except Exception as e: + logger.warning( + f"Error generating evaluation problem for level {level}: {e}" + ) + attempts += 1 + + logger.info( + f"Generated 
{len(self.eval_problems[level])} evaluation problems for level {level}" + ) + + # If any levels have no problems, add a simple fallback + for level in range(1, 8): + if not self.eval_problems[level]: + logger.warning( + f"No valid evaluation problems for level {level}, adding fallback" + ) + if level == 1: + self.eval_problems[level].append(("What is 2 + 3?", "5", 0)) + elif level == 2: + self.eval_problems[level].append( + ("What is the square root of 16?", "4", 6) + ) + elif level == 3: + self.eval_problems[level].append( + ( + "What is the area of a triangle with base 6 and height 8?", + "24", + 18, + ) + ) + elif level == 4: + self.eval_problems[level].append( + ("What is the solution to x + 5 = 12?", "7", 26) + ) + elif level == 5: + self.eval_problems[level].append( + ("What is the volume of a cube with side length 3?", "27", 33) + ) + elif level == 6: + self.eval_problems[level].append( + ("What is 5 factorial?", "120", 31) + ) + else: + self.eval_problems[level].append(("What is |3 - 10|?", "7", 71)) + + def strip_latex_delimiters(self, text: str) -> str: + """Strip LaTeX delimiters ($...$) from text.""" + # Handle both inline expressions $...$ and expressions that make up the entire string + return re.sub(r"\$(.*?)\$", r"\1", text) + + def save_checkpoint(self, step, data=None): + """Save curriculum state in checkpoint.""" + if data is None: + data = {} + + # Save curriculum state + data["curriculum_level"] = self.curriculum.get_current_level() + data["performance_history"] = { + str(k): v for k, v in self.curriculum.performance_history.items() + } + + super().save_checkpoint(step, data) + + def load_checkpoint(self): + """Load curriculum state from checkpoint.""" + super().load_checkpoint() + + # Check if we have curriculum data in the checkpoint + checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json" + try: + with open(checkpoint_path, "r") as f: + data = json.load(f) + + # Restore curriculum state if 
available + if "curriculum_level" in data: + level = data["curriculum_level"] + self.curriculum.current_level = level + + if "performance_history" in data: + # Convert string keys back to integers + self.curriculum.performance_history = { + int(k): v for k, v in data["performance_history"].items() + } + except (FileNotFoundError, json.JSONDecodeError) as e: + logger.warning(f"Failed to load checkpoint: {e}") + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """Log metrics to wandb.""" + if wandb_metrics is None: + wandb_metrics = {} + + # Log overall correct percentage + try: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / max(1, len(self.percent_correct_buffer)) + except ZeroDivisionError: + pass + + # Log per-level metrics + for level, buffer in self.level_correct_buffer.items(): + if buffer: + wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len( + buffer + ) + wandb_metrics[f"train/level_{level}_count"] = len(buffer) + + # Log current level and curriculum progress + if self.curriculum: + current_level = self.curriculum.get_current_level() + max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys()) + + wandb_metrics["curriculum/current_level"] = current_level + wandb_metrics["curriculum/max_level"] = max_level + wandb_metrics["curriculum/progress_percent"] = ( + current_level / max_level + ) * 100 + + # Log level description + wandb_metrics["curriculum/level_description"] = ( + self.curriculum.get_level_description() + ) + + # Log performance history for current level + if current_level in self.curriculum.performance_history: + history = self.curriculum.performance_history[current_level] + if history: + recent_history = history[ + -min(len(history), self.curriculum.min_evaluations) : + ] + if recent_history: + success_rate = sum(recent_history) / len(recent_history) + wandb_metrics["curriculum/current_level_success_rate"] = ( + success_rate + ) + 
wandb_metrics["curriculum/threshold_to_advance"] = ( + self.curriculum.progress_threshold + ) + wandb_metrics["curriculum/remaining_to_threshold"] = max( + 0, self.curriculum.progress_threshold - success_rate + ) + + # Log reward function metrics + if hasattr(self, "reward_function") and self.wandb: + if hasattr(self.reward_function, "set_wandb_logger"): + self.reward_function.set_wandb_logger(self.wandb) + + # Log the reward configurations + if isinstance(self.config.reward_functions, list) and self.config.reward_functions: + # Log the reward configuration + wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions + wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions + + if hasattr(self.config, "format_reward_weight"): + wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight + + if hasattr(self.config, "boxed_reward_weight"): + wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight + + # Add eval metrics + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + + # Reset buffers + self.percent_correct_buffer = [] + for level in self.level_correct_buffer: + self.level_correct_buffer[level] = [] + self.eval_metrics = [] + + # Call the parent method to handle remaining metrics + await super().wandb_log(wandb_metrics) + + async def get_next_item(self): + """Get the next problem based on current curriculum level.""" + problem, solution, generator_id = self.curriculum.get_problem() + + # Strip LaTeX delimiters from problem and solution + problem = self.strip_latex_delimiters(problem) + solution = self.strip_latex_delimiters(solution) + + # Create a message with the problem + prompt = tuple([frozenset({"role": "user", "content": problem}.items())]) + + # Return the problem with metadata + return (prompt, solution, generator_id) + + async def evaluate(self, *args, **kwargs): + """Evaluate the model on test problems at the current curriculum 
level.""" + current_level = self.curriculum.get_current_level() + logger.info(f"Starting evaluation for curriculum level {current_level}") + + # Only evaluate problems at the current level + eval_tasks = [] + eval_generator_ids = [] + if current_level in self.eval_problems: + for problem, solution, generator_id in self.eval_problems[current_level]: + eval_tasks.append( + self.evaluate_single_problem(problem, solution, current_level) + ) + eval_generator_ids.append(generator_id) + + if not eval_tasks: + logger.warning( + f"No evaluation problems available for level {current_level}" + ) + return [] + + # Run evaluation tasks + logger.info(f"Evaluating {len(eval_tasks)} problems at level {current_level}") + results = await asyncio.gather(*eval_tasks) + + # Calculate accuracy for the current level + correct_count = sum(1 for _, is_correct in results if is_correct) + total_count = len(results) + accuracy = correct_count / total_count if total_count > 0 else 0 + + logger.info( + f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})" + ) + + # Record metrics for the current level + self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy)) + self.eval_metrics.append(("eval/current_level", current_level)) + + # Record the actual evaluation results in the curriculum's performance history + for i, (_, is_correct) in enumerate(results): + if i < len(eval_generator_ids): + # Record the actual result + self.curriculum.record_performance(eval_generator_ids[i], is_correct) + else: + # Fallback if somehow the lists are different lengths + sample_generator_id = random.choice( + self.curriculum.DIFFICULTY_LEVELS[current_level] + ) + self.curriculum.record_performance(sample_generator_id, is_correct) + + # Try to advance to the next level + advanced = self.curriculum.advance_difficulty() + new_level = self.curriculum.get_current_level() + + if advanced: + logger.info(f"Advanced from level {current_level} to level {new_level}!") + 
self.eval_metrics.append(("eval/advanced_level", 1))
+ else:
+ logger.info(f"Remaining at level {current_level}")
+ self.eval_metrics.append(("eval/advanced_level", 0))
+
+ return self.eval_metrics
+
+ async def evaluate_single_problem(
+ self, problem: str, solution: str, level: int
+ ) -> Tuple[int, bool]:
+ """Evaluate a single problem."""
+ try:
+ logger.debug(f"Evaluating level {level} problem: {problem[:30]}...")
+
+ # Convert messages to a single prompt using the tokenizer
+ messages = [
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": problem},
+ ]
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
+
+ # Add prefilled thinking starter
+ prefill = "<think>\n\n"
+ prefilled_prompt = prompt + prefill
+
+ # Generate completion using the prompt
+ logger.debug(f"Requesting completion for problem: {problem[:30]}...")
+ completion = await self.server.completion(
+ prompt=prefilled_prompt,
+ n=1,
+ max_tokens=self.config.max_token_length,
+ temperature=0.0, # Use 0 temperature for deterministic results
+ top_p=1.0,
+ split="eval",
+ )
+
+ # Extract the completion text and prepend the thinking starter
+ model_answer = prefill + (
+ completion.choices[0].text
+ if hasattr(completion.choices[0], "text")
+ else completion.choices[0].message.content
+ )
+
+ # Check if the answer is correct
+ is_correct = self.check_answer(model_answer, solution)
+ logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")
+
+ return level, is_correct
+ except Exception as e:
+ logger.error(f"Error evaluating problem: {e}")
+ # Return a failed result in case of error
+ return level, False
+
+ def check_answer(self, model_answer: str, solution: str) -> bool:
+ """Check if the model's answer matches the solution."""
+ # Extract the part after the thinking block
+ after_think_part = (
+ model_answer.split("</think>")[-1].strip()
+ if "</think>" in model_answer
+ else model_answer
+ )
+
+ # Extract the boxed answer if present
+ boxed_answer = 
self.extract_boxed_answer(after_think_part)
+ if not boxed_answer:
+ # Try to find the answer in the last line
+ lines = after_think_part.strip().split("\n")
+ if lines:
+ boxed_answer = lines[-1].strip()
+
+ # Clean up answers for comparison (remove spaces, convert to lowercase)
+ model_clean = self.clean_for_comparison(
+ boxed_answer if boxed_answer else after_think_part
+ )
+ solution_clean = self.clean_for_comparison(solution)
+
+ # Check if they match
+ return model_clean == solution_clean
+
+ def extract_boxed_answer(self, text: str) -> Optional[str]:
+ """Extract answer from a LaTeX boxed expression."""
+ # Try to find boxed content
+ boxed_match = re.search(r"\\boxed{([^}]*)}", text)
+ if boxed_match:
+ return boxed_match.group(1)
+ return None
+
+ def clean_for_comparison(self, text: str) -> str:
+ """Clean text for comparison."""
+ # Remove LaTeX commands, spaces, commas, and convert to lowercase
+ cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
+ cleaned = re.sub(r"[,\s]", "", cleaned)
+ cleaned = cleaned.lower()
+ return cleaned
+
+ async def collect_trajectories(self, item) -> Tuple[List, List]:
+ """Collect trajectories for the current item."""
+ # Extract information from the item
+ problem_prompt, solution, generator_id = item
+
+ # Create prompt using tokenizer's chat template
+ # Add prefilled thinking starter
+ prefill = "<think>\n\n"
+ messages = [
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": dict(problem_prompt[0])["content"]},
+ {"role": "assistant", "content": prefill},
+ ]
+
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
+ # Generate completions using completion API
+ completions = await self.server.completion(
+ prompt=prompt,
+ n=self.config.group_size,
+ max_tokens=self.config.max_token_length,
+ temperature=self.config.temperature,
+ top_p=self.config.top_p,
+ )
+
+ # Prepare data for scoring
+ to_score = []
+
+ # Track level for metrics
+ level = None
+ for lvl, generator_ids in 
self.curriculum.DIFFICULTY_LEVELS.items(): + if generator_id in generator_ids: + level = lvl + break + + # Process each completion + for i, completion in enumerate(completions.choices): + # Get the completion text and prepend the thinking starter + model_answer = prefill + ( + completion.text + if hasattr(completion, "text") + else completion.message.content + ) + print("model_answer", model_answer) + + # Build complete message sequence + full_messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": dict(problem_prompt[0])["content"]}, + {"role": "assistant", "content": model_answer}, + ] + + # Add to scoring list + to_score.append((full_messages, solution, generator_id, level)) + + # Record performance in curriculum for each item we're scoring + # This will be called again after scoring, but that's fine + + # No additional items for backlog + backlog = [] + + return to_score, backlog + + async def score(self, rollout_group_data) -> ScoredDataGroup: + """Score the collected trajectories.""" + scored_data = ScoredDataGroup() + scored_data["tokens"] = [] + scored_data["masks"] = [] + scored_data["scores"] = [] + scored_data["messages"] = [] + + # Format completions for reward function evaluation + format_completions = [] + + # Process each item in the rollout data + for messages, solution, generator_id, level in rollout_group_data: + # Extract the model's answer + model_answer = messages[-1]["content"] + + # Add to format completions list for reward function + format_completions.append([{"role": "assistant", "content": model_answer}]) + + # Record performance in curriculum based on the answer and solution + # This will be updated after the reward functions are applied + + # Apply all reward functions + reward_scores = [] + unweighted_scores = [] + if hasattr(self, "reward_function") and self.reward_function: + try: + # Apply the reward function (which may be a combined reward) + reward_scores = 
self.reward_function(format_completions, solution=solution) + logger.info(f"Reward scores: {reward_scores}") + + # Debug individual rewards if it's a combined reward + if hasattr(self.reward_function, "rewards"): + logger.info(f"Combined reward with {len(self.reward_function.rewards)} components") + for i, reward in enumerate(self.reward_function.rewards): + if hasattr(reward, "compute"): + # Get raw unweighted scores + raw_scores = reward.compute(format_completions, solution=solution) + if hasattr(reward, "weight"): + logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}, weight={reward.weight}") + else: + logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}") + else: + logger.info(f"Using single reward: {type(self.reward_function).__name__}") + + except Exception as e: + logger.error(f"Error applying reward functions: {e}") + logger.exception(e) + reward_scores = [0.0] * len(format_completions) + + # Now update curriculum based on accuracy reward results + for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data): + # Extract accuracy from the combined reward if available + is_correct = False + if reward_scores and hasattr(self.reward_function, "rewards"): + for reward in self.reward_function.rewards: + if type(reward).__name__ == "AccuracyReward": + # Get raw scores from accuracy reward + accuracy_scores = reward.compute(format_completions, solution=solution) + is_correct = accuracy_scores[i] > 0 + break + + # Record answer correctness for tracking + self.percent_correct_buffer.append(1 if is_correct else 0) + if level is not None: + self.level_correct_buffer[level].append(1 if is_correct else 0) + + # Record performance in curriculum + self.curriculum.record_performance(generator_id, is_correct) + + # Combine scores and add to scored data + for i, (messages, _, _, _) in enumerate(rollout_group_data): + # Use the reward score directly (all weights are applied) + combined_score = reward_scores[i] if 
reward_scores else 0.0 + + logger.info(f"Final score for item {i}: {combined_score}") + + # Tokenize for the trainer + tokens_dict = tokenize_for_trainer( + self.tokenizer, + messages, + None, + ) + + # Add to scored data + scored_data["tokens"].append(tokens_dict["tokens"]) + scored_data["masks"].append(tokens_dict["masks"]) + scored_data["scores"].append(combined_score) + scored_data["messages"].append(messages) + + # Advance difficulty if criteria met + self.curriculum.advance_difficulty() + + return scored_data + + +if __name__ == "__main__": + import asyncio + + async def main(): + config = InfiniteMathEnvConfig( + tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B", + group_size=8, + use_wandb=True, + max_num_workers=64, + rollout_server_url="http://localhost:8000", + total_steps=10000, + batch_size=1024, + steps_per_eval=25, + max_token_length=4096, + inference_weight=1.0, + wandb_name="infinite_math", + data_path_to_save_groups="data/infinite_math_groups.jsonl", + # InfiniteMath specific config + starting_level=1, + progress_threshold=0.8, + min_evaluations=10, + correct_reward=1.0, + incorrect_reward=-0.5, + apply_length_penalty=True, + length_threshold_ratio=0.6, + # Completion parameters + temperature=0.7, + top_p=0.9, + # Reward function configuration - use name directly + reward_functions=["accuracy", "format", "boxed"], + accuracy_reward_weight=1.0, + format_reward_weight=0.2, + boxed_reward_weight=0.3, + ) + + openai_config = OpenaiConfig( + model_name="NousResearch/Nous-Hermes-2-Yi-34B", + base_url="http://localhost:9004/v1", + api_key="x", + num_requests_for_eval=64, + ) + + env = InfiniteMathEnv( + config=config, + server_configs=[openai_config], + slurm=False, + ) + + await env.env_manager() + + asyncio.run(main()) diff --git a/environments/infinimath/infinimath_local_server.py b/environments/infinimath/infinimath_local_server.py new file mode 100644 index 00000000..08cd0efa --- /dev/null +++ b/environments/infinimath/infinimath_local_server.py @@ 
-0,0 +1,242 @@ +#!/usr/bin/env python3 +import asyncio +import logging +import os +import argparse + +from dotenv import load_dotenv +from openai import OpenAI + +from environments.infinimath.infinimath_env import ( + InfiniteMathEnv, + InfiniteMathEnvConfig, +) +from atroposlib.envs.base import OpenaiConfig +from atroposlib.utils.config_handler import ConfigHandler + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="InfiniteMath environment server") + parser.add_argument( + "--config", + type=str, + default="infinimath", + help="Configuration file name (without .yaml extension or path for configs/envs/ directory, or full path)", + ) + return parser.parse_args() + + +async def main(): + logger.info("Starting InfiniteMath environment server") + + # Parse command line arguments + args = parse_arguments() + + # Initialize config handler and load configuration + config_handler = ConfigHandler() + + # Determine config path + if os.path.isabs(args.config) or "/" in args.config or args.config.endswith(".yaml"): + config_path = args.config + else: + # short form that defaults to the envs directory + config_path = os.path.join( + config_handler.config_dir, f"envs/{args.config}.yaml" + ) + + logger.info(f"Loading configuration from: {config_path}") + + try: + with open(config_path, "r") as f: + import yaml + raw_config = yaml.safe_load(f) + logger.info(f"Loaded configuration successfully") + except Exception as e: + logger.error(f"Error loading config directly: {e}") + logger.info("Falling back to default config handler") + raw_config = config_handler.load_config(args) + + # Configure the InfiniteMath environment with values from config + config = InfiniteMathEnvConfig( + # Base environment parameters + tokenizer_name=raw_config.get("tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"), + group_size=raw_config.get("group_size", 1), + 
use_wandb=raw_config.get("use_wandb", False), + max_num_workers=raw_config.get("max_num_workers", 1), + rollout_server_url=raw_config.get("rollout_server_url", "http://localhost:8000"), + total_steps=raw_config.get("total_steps", 1), + batch_size=raw_config.get("batch_size", 1), + steps_per_eval=raw_config.get("steps_per_eval", 2), + max_token_length=raw_config.get("max_token_length", 4096), + wandb_name=raw_config.get("wandb_name", "infinite_math_test"), + ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False), + + # InfiniteMath specific parameters + starting_level=raw_config.get("infinimath", {}).get("starting_level", 1), + progress_threshold=raw_config.get("infinimath", {}).get("progress_threshold", 0.7), + min_evaluations=raw_config.get("infinimath", {}).get("min_evaluations", 3), + correct_reward=raw_config.get("infinimath", {}).get("correct_reward", 1.0), + incorrect_reward=raw_config.get("infinimath", {}).get("incorrect_reward", -0.5), + apply_length_penalty=raw_config.get("infinimath", {}).get("apply_length_penalty", True), + length_threshold_ratio=raw_config.get("infinimath", {}).get("length_threshold_ratio", 0.6), + temperature=raw_config.get("infinimath", {}).get("temperature", 0.7), + top_p=raw_config.get("infinimath", {}).get("top_p", 0.9), + reward_functions=raw_config.get("infinimath", {}).get("reward_functions", ["accuracy", "format", "boxed"]), + accuracy_reward_weight=raw_config.get("infinimath", {}).get("accuracy_reward_weight", 1.0), + format_reward_weight=raw_config.get("infinimath", {}).get("format_reward_weight", 0.2), + boxed_reward_weight=raw_config.get("infinimath", {}).get("boxed_reward_weight", 0.3), + ) + + # Server configuration from config file or defaults + server_configs = [] + + if "server_configs" in raw_config: + for server_config in raw_config["server_configs"]: + api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY")) + # Handle environment variable references like ${OPENAI_API_KEY} + 
if isinstance(api_key, str) and api_key.startswith("${") and api_key.endswith("}"): + env_var = api_key[2:-1] + api_key = os.environ.get(env_var, "") + + server_configs.append( + OpenaiConfig( + model_name=server_config.get("model_name", "gpt-4.1-nano"), + base_url=server_config.get("base_url", None), + api_key=api_key, + num_requests_for_eval=server_config.get("num_requests_for_eval", 70), + ) + ) + else: + # Default configuration if not specified in config file + server_configs.append( + OpenaiConfig( + model_name="gpt-4.1-nano", + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + num_requests_for_eval=70, + ) + ) + + # Create the environment + env = InfiniteMathEnv( + config=config, + server_configs=server_configs, + slurm=False, + ) + + # Setup the environment + await env.setup() + logger.info("Environment setup complete") + + # Log the number of evaluation problems + total_problems = sum(len(probs) for probs in env.eval_problems.values()) + logger.info( + f"Using {total_problems} evaluation problems across {len(env.eval_problems)} difficulty levels" + ) + + # Get a math problem + item = await env.get_next_item() + problem_prompt, solution, generator_id = item + + logger.info(f"Problem: {dict(problem_prompt[0])['content']}") + logger.info(f"Solution: {solution}") + + # Collect trajectories + logger.info("Collecting trajectories...") + trajectories_data, backlog = await env.collect_trajectories(item) + + # Score the collected trajectories + logger.info("Scoring trajectories...") + scored_data = await env.score(trajectories_data) + + input("Press Enter to continue...") + # Print scores + logger.info(f"Scores: {scored_data['scores']}") + + # Log the correct/incorrect counts + correct_count = sum(1 for score in scored_data["scores"] if score > 0) + logger.info(f"Correct answers: {correct_count}/{len(scored_data['scores'])}") + + # Test evaluation function specifically + logger.info("\n=== Testing Evaluation Function ===") + + # Record the current level + 
initial_level = env.curriculum.get_current_level() + logger.info(f"Current level before evaluation: {initial_level}") + logger.info(f"Level description: {env.curriculum.get_level_description()}") + logger.info(f"Progress threshold: {env.curriculum.progress_threshold}") + logger.info(f"Min evaluations needed: {env.curriculum.min_evaluations}") + + # Run the evaluate method + eval_metrics = await env.evaluate() + + # Display evaluation results + logger.info("Evaluation metrics:") + for metric_name, metric_value in eval_metrics: + logger.info(f" - {metric_name}: {metric_value}") + + # Check if the level advanced + new_level = env.curriculum.get_current_level() + if new_level > initial_level: + logger.info(f"Successfully advanced to level {new_level}!") + logger.info(f"New level description: {env.curriculum.get_level_description()}") + else: + # Show current progress toward advancement + current_level = env.curriculum.get_current_level() + if current_level in env.curriculum.performance_history: + history = env.curriculum.performance_history[current_level] + if len(history) >= env.curriculum.min_evaluations: + recent_history = history[-env.curriculum.min_evaluations :] + success_rate = sum(recent_history) / len(recent_history) + logger.info( + f"Current success rate: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)" + ) + else: + logger.info( + f"Need more evaluations: {len(history)}/{env.curriculum.min_evaluations}" + ) + + # Show all levels and their performance history + logger.info("\nPerformance history by level:") + for level in sorted(env.curriculum.performance_history.keys()): + history = env.curriculum.performance_history[level] + if history: + success_rate = sum(history) / len(history) + logger.info( + f" Level {level}: {success_rate:.2f} ({sum(history)}/{len(history)} correct)" + ) + else: + logger.info(f" Level {level}: No data") + + # Test curriculum advancement with simulated performance history + logger.info("\n=== Testing 
Curriculum Advancement ===") + + # Simulate good performance at current level + for _ in range(env.config.min_evaluations): + # Get a problem from current level + item = await env.get_next_item() + _, _, generator_id = item + + # Record positive performance + env.curriculum.record_performance(generator_id, True) + + # Try to advance difficulty + did_advance = env.curriculum.advance_difficulty() + new_level = env.curriculum.get_current_level() + + logger.info(f"Curriculum advancement test:") + logger.info(f" - Starting level: {initial_level}") + logger.info(f" - Recorded {env.config.min_evaluations} correct answers") + logger.info(f" - Did advance: {did_advance}") + logger.info(f" - New level: {new_level}") + + logger.info("Test server completed successfully") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/environments/infinimath/infinimath_server.py b/environments/infinimath/infinimath_server.py new file mode 100644 index 00000000..08cd0efa --- /dev/null +++ b/environments/infinimath/infinimath_server.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +import asyncio +import logging +import os +import argparse + +from dotenv import load_dotenv +from openai import OpenAI + +from environments.infinimath.infinimath_env import ( + InfiniteMathEnv, + InfiniteMathEnvConfig, +) +from atroposlib.envs.base import OpenaiConfig +from atroposlib.utils.config_handler import ConfigHandler + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="InfiniteMath environment server") + parser.add_argument( + "--config", + type=str, + default="infinimath", + help="Configuration file name (without .yaml extension or path for configs/envs/ directory, or full path)", + ) + return parser.parse_args() + + +async def main(): + logger.info("Starting InfiniteMath environment server") + + # Parse command line arguments + args = parse_arguments() + + # Initialize 
config handler and load configuration + config_handler = ConfigHandler() + + # Determine config path + if os.path.isabs(args.config) or "/" in args.config or args.config.endswith(".yaml"): + config_path = args.config + else: + # short form that defaults to the envs directory + config_path = os.path.join( + config_handler.config_dir, f"envs/{args.config}.yaml" + ) + + logger.info(f"Loading configuration from: {config_path}") + + try: + with open(config_path, "r") as f: + import yaml + raw_config = yaml.safe_load(f) + logger.info(f"Loaded configuration successfully") + except Exception as e: + logger.error(f"Error loading config directly: {e}") + logger.info("Falling back to default config handler") + raw_config = config_handler.load_config(args) + + # Configure the InfiniteMath environment with values from config + config = InfiniteMathEnvConfig( + # Base environment parameters + tokenizer_name=raw_config.get("tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"), + group_size=raw_config.get("group_size", 1), + use_wandb=raw_config.get("use_wandb", False), + max_num_workers=raw_config.get("max_num_workers", 1), + rollout_server_url=raw_config.get("rollout_server_url", "http://localhost:8000"), + total_steps=raw_config.get("total_steps", 1), + batch_size=raw_config.get("batch_size", 1), + steps_per_eval=raw_config.get("steps_per_eval", 2), + max_token_length=raw_config.get("max_token_length", 4096), + wandb_name=raw_config.get("wandb_name", "infinite_math_test"), + ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False), + + # InfiniteMath specific parameters + starting_level=raw_config.get("infinimath", {}).get("starting_level", 1), + progress_threshold=raw_config.get("infinimath", {}).get("progress_threshold", 0.7), + min_evaluations=raw_config.get("infinimath", {}).get("min_evaluations", 3), + correct_reward=raw_config.get("infinimath", {}).get("correct_reward", 1.0), + incorrect_reward=raw_config.get("infinimath", 
{}).get("incorrect_reward", -0.5), + apply_length_penalty=raw_config.get("infinimath", {}).get("apply_length_penalty", True), + length_threshold_ratio=raw_config.get("infinimath", {}).get("length_threshold_ratio", 0.6), + temperature=raw_config.get("infinimath", {}).get("temperature", 0.7), + top_p=raw_config.get("infinimath", {}).get("top_p", 0.9), + reward_functions=raw_config.get("infinimath", {}).get("reward_functions", ["accuracy", "format", "boxed"]), + accuracy_reward_weight=raw_config.get("infinimath", {}).get("accuracy_reward_weight", 1.0), + format_reward_weight=raw_config.get("infinimath", {}).get("format_reward_weight", 0.2), + boxed_reward_weight=raw_config.get("infinimath", {}).get("boxed_reward_weight", 0.3), + ) + + # Server configuration from config file or defaults + server_configs = [] + + if "server_configs" in raw_config: + for server_config in raw_config["server_configs"]: + api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY")) + # Handle environment variable references like ${OPENAI_API_KEY} + if isinstance(api_key, str) and api_key.startswith("${") and api_key.endswith("}"): + env_var = api_key[2:-1] + api_key = os.environ.get(env_var, "") + + server_configs.append( + OpenaiConfig( + model_name=server_config.get("model_name", "gpt-4.1-nano"), + base_url=server_config.get("base_url", None), + api_key=api_key, + num_requests_for_eval=server_config.get("num_requests_for_eval", 70), + ) + ) + else: + # Default configuration if not specified in config file + server_configs.append( + OpenaiConfig( + model_name="gpt-4.1-nano", + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + num_requests_for_eval=70, + ) + ) + + # Create the environment + env = InfiniteMathEnv( + config=config, + server_configs=server_configs, + slurm=False, + ) + + # Setup the environment + await env.setup() + logger.info("Environment setup complete") + + # Log the number of evaluation problems + total_problems = sum(len(probs) for probs in 
env.eval_problems.values()) + logger.info( + f"Using {total_problems} evaluation problems across {len(env.eval_problems)} difficulty levels" + ) + + # Get a math problem + item = await env.get_next_item() + problem_prompt, solution, generator_id = item + + logger.info(f"Problem: {dict(problem_prompt[0])['content']}") + logger.info(f"Solution: {solution}") + + # Collect trajectories + logger.info("Collecting trajectories...") + trajectories_data, backlog = await env.collect_trajectories(item) + + # Score the collected trajectories + logger.info("Scoring trajectories...") + scored_data = await env.score(trajectories_data) + + input("Press Enter to continue...") + # Print scores + logger.info(f"Scores: {scored_data['scores']}") + + # Log the correct/incorrect counts + correct_count = sum(1 for score in scored_data["scores"] if score > 0) + logger.info(f"Correct answers: {correct_count}/{len(scored_data['scores'])}") + + # Test evaluation function specifically + logger.info("\n=== Testing Evaluation Function ===") + + # Record the current level + initial_level = env.curriculum.get_current_level() + logger.info(f"Current level before evaluation: {initial_level}") + logger.info(f"Level description: {env.curriculum.get_level_description()}") + logger.info(f"Progress threshold: {env.curriculum.progress_threshold}") + logger.info(f"Min evaluations needed: {env.curriculum.min_evaluations}") + + # Run the evaluate method + eval_metrics = await env.evaluate() + + # Display evaluation results + logger.info("Evaluation metrics:") + for metric_name, metric_value in eval_metrics: + logger.info(f" - {metric_name}: {metric_value}") + + # Check if the level advanced + new_level = env.curriculum.get_current_level() + if new_level > initial_level: + logger.info(f"Successfully advanced to level {new_level}!") + logger.info(f"New level description: {env.curriculum.get_level_description()}") + else: + # Show current progress toward advancement + current_level = 
env.curriculum.get_current_level() + if current_level in env.curriculum.performance_history: + history = env.curriculum.performance_history[current_level] + if len(history) >= env.curriculum.min_evaluations: + recent_history = history[-env.curriculum.min_evaluations :] + success_rate = sum(recent_history) / len(recent_history) + logger.info( + f"Current success rate: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)" + ) + else: + logger.info( + f"Need more evaluations: {len(history)}/{env.curriculum.min_evaluations}" + ) + + # Show all levels and their performance history + logger.info("\nPerformance history by level:") + for level in sorted(env.curriculum.performance_history.keys()): + history = env.curriculum.performance_history[level] + if history: + success_rate = sum(history) / len(history) + logger.info( + f" Level {level}: {success_rate:.2f} ({sum(history)}/{len(history)} correct)" + ) + else: + logger.info(f" Level {level}: No data") + + # Test curriculum advancement with simulated performance history + logger.info("\n=== Testing Curriculum Advancement ===") + + # Simulate good performance at current level + for _ in range(env.config.min_evaluations): + # Get a problem from current level + item = await env.get_next_item() + _, _, generator_id = item + + # Record positive performance + env.curriculum.record_performance(generator_id, True) + + # Try to advance difficulty + did_advance = env.curriculum.advance_difficulty() + new_level = env.curriculum.get_current_level() + + logger.info(f"Curriculum advancement test:") + logger.info(f" - Starting level: {initial_level}") + logger.info(f" - Recorded {env.config.min_evaluations} correct answers") + logger.info(f" - Did advance: {did_advance}") + logger.info(f" - New level: {new_level}") + + logger.info("Test server completed successfully") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/environments/infinimath/test_curriculum.py 
#!/usr/bin/env python3
"""
Test script for the Infinite Math curriculum manager.

Exercises the core functionality of both the MathCurriculum and
InfiniteMath classes: initialization, problem generation, performance
tracking / level advancement, level descriptions, the environment API,
and a small end-to-end learning simulation.

Run directly (exit code 0 on success, 1 on any failure) or import the
individual test_* functions.
"""

import os
import random
import sys
from typing import Any, Dict, List, Optional, Tuple

# Use relative imports
from .curriculum import MathCurriculum
from .infinimath import InfiniteMath


def test_curriculum_initialization():
    """Test that the curriculum initializes correctly with different levels."""
    print("\n=== Testing Curriculum Initialization ===")

    # Test default initialization
    curriculum = MathCurriculum()
    assert curriculum.get_current_level() == 1, "Default level should be 1"
    print("✓ Default initialization successful")

    # Test initialization with specific level
    curriculum = MathCurriculum(starting_level=3)
    assert curriculum.get_current_level() == 3, "Starting level should be 3"
    print("✓ Custom level initialization successful")

    # Test initialization with invalid level. The else-branch raises so that
    # a non-raising constructor actually fails the test run instead of just
    # printing a warning and letting the suite report success.
    try:
        MathCurriculum(starting_level=10)
    except ValueError:
        print("✓ Invalid level initialization correctly raises ValueError")
    else:
        raise AssertionError("Invalid level initialization should raise ValueError")


def test_problem_generation():
    """Test that problems are generated correctly at different difficulty levels."""
    print("\n=== Testing Problem Generation ===")

    # Test problem generation at each of the 7 difficulty levels
    for level in range(1, 8):
        curriculum = MathCurriculum(starting_level=level)
        problem, solution, generator_id = curriculum.get_problem()

        # Verify we got a problem and solution
        assert (
            isinstance(problem, str) and len(problem) > 0
        ), f"Problem at level {level} should be a non-empty string"
        assert solution is not None, f"Solution at level {level} should not be None"

        # Verify the generator ID belongs to the correct level
        assert (
            generator_id in curriculum.DIFFICULTY_LEVELS[level]
        ), f"Generator ID {generator_id} should be in level {level}"

        print(
            f"✓ Level {level} problem generated: {problem[:50]}{'...' if len(problem) > 50 else ''}"
        )
        print(f"  Solution: {solution}")
        print(f"  Generator ID: {generator_id}")


def test_performance_tracking():
    """Test performance tracking and level advancement."""
    print("\n=== Testing Performance Tracking and Advancement ===")

    # Create curriculum with test parameters
    curriculum = MathCurriculum(
        starting_level=1, progress_threshold=0.7, min_evaluations=3
    )

    # Record some correct answers (not enough to advance)
    generator_id = curriculum.DIFFICULTY_LEVELS[1][0]  # Get a generator from level 1
    curriculum.record_performance(generator_id, True)
    curriculum.record_performance(generator_id, True)

    # Check if we advance (should not — only 2 of the required 3 evaluations)
    did_advance = curriculum.advance_difficulty()
    assert not did_advance, "Should not advance with only 2 evaluations"
    assert curriculum.get_current_level() == 1, "Level should still be 1"
    print("✓ Correctly did not advance with insufficient evaluations")

    # Add one more correct answer (now should advance)
    curriculum.record_performance(generator_id, True)
    did_advance = curriculum.advance_difficulty()
    assert did_advance, "Should advance with 3 correct evaluations (100% success rate)"
    assert curriculum.get_current_level() == 2, "Level should be 2 after advancement"
    print("✓ Correctly advanced to level 2 after sufficient success")

    # Test with too low success rate (1/3 correct < 0.7 threshold)
    curriculum = MathCurriculum(
        starting_level=1, progress_threshold=0.7, min_evaluations=3
    )
    generator_id = curriculum.DIFFICULTY_LEVELS[1][0]
    curriculum.record_performance(generator_id, True)  # 1 correct
    curriculum.record_performance(generator_id, False)  # 1 incorrect
    curriculum.record_performance(generator_id, False)  # 1 incorrect

    did_advance = curriculum.advance_difficulty()
    assert (
        not did_advance
    ), "Should not advance with 33% success rate when threshold is 70%"
    print("✓ Correctly did not advance with insufficient success rate")

    # Test advancement at the highest level: even with perfect performance,
    # level 7 is the cap and advance_difficulty must return False.
    curriculum = MathCurriculum(
        starting_level=7, progress_threshold=0.7, min_evaluations=3
    )
    generator_id = curriculum.DIFFICULTY_LEVELS[7][0]
    curriculum.record_performance(generator_id, True)
    curriculum.record_performance(generator_id, True)
    curriculum.record_performance(generator_id, True)

    did_advance = curriculum.advance_difficulty()
    assert not did_advance, "Should not advance beyond the highest level"
    print("✓ Correctly did not advance beyond the highest level")


def test_level_descriptions():
    """Test that level descriptions are correct."""
    print("\n=== Testing Level Descriptions ===")

    curriculum = MathCurriculum()

    for level in range(1, 8):
        description = curriculum.get_level_description(level)
        assert (
            isinstance(description, str) and len(description) > 0
        ), f"Description for level {level} should be a non-empty string"
        print(f"✓ Level {level}: {description}")


def test_infinite_math_environment():
    """Test the InfiniteMath environment functionality."""
    print("\n=== Testing InfiniteMath Environment ===")

    # Initialize the environment
    env = InfiniteMath(starting_level=1, progress_threshold=0.7, min_evaluations=3)

    # Get the initial state
    state = env.get_state()
    assert "problem" in state, "State should include a problem"
    assert "current_level" in state, "State should include current level"
    print(
        f"✓ Initial state: Level {state['current_level']}, Problem: {state['problem'][:50]}{'...' if len(state['problem']) > 50 else ''}"
    )

    # Test answering a problem incorrectly
    result = env.submit_answer("wrong answer")
    assert "is_correct" in result, "Result should indicate correctness"
    assert not result["is_correct"], "Result should be incorrect"
    print("✓ Incorrect answer handled correctly")

    # Test answering a problem correctly
    # Note: We can't predict the correct answer, so we'll get it from the environment
    correct_solution = env.current_solution
    result = env.submit_answer(correct_solution)
    assert result["is_correct"], "Result should be correct"
    print("✓ Correct answer handled successfully")

    # Test resetting the environment
    env.reset(level=3)
    state = env.get_state()
    assert state["current_level"] == 3, "After reset, level should be 3"
    print("✓ Reset to level 3 successful")

    # Test getting difficulty stats
    stats = env.get_difficulty_stats()
    assert len(stats) == 7, "Should have stats for all 7 difficulty levels"
    print("✓ Difficulty statistics retrieved successfully")


def simulate_learning():
    """Simulate a learning agent improving performance over time."""
    print("\n=== Simulating Learning Process ===")

    env = InfiniteMath(starting_level=1, progress_threshold=0.7, min_evaluations=5)
    episodes = 30

    print(f"Starting simulation with {episodes} episodes")
    print(f"Initial level: {env.get_state()['current_level']}")

    for i in range(episodes):
        state = env.get_state()
        current_level = state["current_level"]

        # Simulated agent gradually improves - higher chance of correct answer over time
        # and higher chance at lower levels
        success_probability = min(0.5 + (i / episodes) + (1 / current_level), 0.95)

        # Simulate an answer
        is_correct = random.random() < success_probability

        # If we decide to be correct, use the actual solution
        if is_correct:
            answer = env.current_solution
        else:
            # Otherwise provide a wrong answer
            answer = "wrong answer"

        # Submit the answer
        result = env.submit_answer(answer)

        # Check for level advancement
        if result.get("did_advance_level", False):
            new_level = result["new_level"]
            print(
                f"Episode {i+1}: Advanced to level {new_level}! (success probability: {success_probability:.2f})"
            )
        elif i % 5 == 0:  # Print status occasionally
            print(
                f"Episode {i+1}: Still at level {current_level} (success probability: {success_probability:.2f})"
            )

    # Print final stats (guard against division by zero if no problems were counted)
    final_state = env.get_state()
    total_problems = final_state["total_problems"]
    accuracy = (
        final_state["correct_problems"] / total_problems if total_problems else 0.0
    )
    print(f"\nFinal level: {final_state['current_level']}")
    print(f"Overall accuracy: {accuracy:.2%}")

    # Print level-by-level stats
    stats = env.get_difficulty_stats()
    print("\nPerformance by level:")
    for level, level_stats in stats.items():
        if level_stats["problems_attempted"] > 0:
            success_rate = level_stats["success_rate"]
            if success_rate is not None:
                print(
                    f"Level {level}: {success_rate:.2%} success rate ({level_stats['problems_attempted']} problems)"
                )
            else:
                print(
                    f"Level {level}: Not enough data ({level_stats['problems_attempted']} problems)"
                )


def main():
    """Run all tests.

    Returns:
        0 if every test passed, 1 on any assertion failure or unexpected error.
    """
    print("=== Starting Curriculum Manager Tests ===")

    try:
        test_curriculum_initialization()
        test_problem_generation()
        test_performance_tracking()
        test_level_descriptions()
        test_infinite_math_environment()
        simulate_learning()

        print("\n=== All tests completed successfully! ===")
    except AssertionError as e:
        print(f"\n❌ Test failed: {e}")
        return 1
    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())