mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
copied from trajectory handler branch
This commit is contained in:
parent
101cbdd803
commit
4e7fcd3c9a
8 changed files with 2238 additions and 0 deletions
105
environments/infinimath/README.md
Normal file
105
environments/infinimath/README.md
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
# InfiniteMath Environment
|
||||
|
||||
## Environment Overview
|
||||
|
||||
This environment provides procedurally generated math problems with curriculum-based advancement. It allows an agent to solve increasingly difficult math problems, with the difficulty level adapting based on performance.
|
||||
|
||||
**Demonstrates:**
|
||||
- Procedural content generation (math problems).
|
||||
- Curriculum learning: The environment automatically adjusts the difficulty (levels 1-7) based on the LLM's success rate.
|
||||
- Step-by-step reasoning evaluation: Rewards correctness, the presence of reasoning steps (within `<think>` tags), and the final answer format (`\boxed{}`).
|
||||
- Handling LaTeX formatting for problems and answers.
|
||||
|
||||
**Training Goal:**
|
||||
- To train LLMs to solve mathematical problems accurately.
|
||||
- To encourage explicit step-by-step reasoning before providing an answer.
|
||||
- To improve the LLM's ability to follow specific formatting instructions (using `<think>` tags and `\boxed{}`).
|
||||
- To teach the model to handle progressively more complex problems through the curriculum.
|
||||
|
||||
## Features
|
||||
|
||||
- Progressive difficulty scaling across 7 levels of math problems
|
||||
- Built-in curriculum system that adapts to agent performance
|
||||
- Automatic problem generation with solutions
|
||||
- Reward functions for accuracy, formatting, and boxed answer checking
|
||||
|
||||
## Usage
|
||||
|
||||
### Running with Default Configuration
|
||||
|
||||
To run the InfiniteMath environment with the default configuration:
|
||||
|
||||
```bash
|
||||
python environments/infinimath/infinimath_local_server.py
|
||||
```
|
||||
|
||||
This will use the default configuration from `configs/envs/infinimath.yaml`.
|
||||
|
||||
### Custom Configuration
|
||||
|
||||
You can specify a custom configuration file:
|
||||
|
||||
```bash
|
||||
python environments/infinimath/infinimath_local_server.py --config my_custom_config
|
||||
```
|
||||
|
||||
The `--config` parameter can be:
|
||||
|
||||
1. A name (without `.yaml` extension) which will be looked up in `configs/envs/`
|
||||
2. A relative or absolute path to a YAML file
|
||||
|
||||
For example:
|
||||
```bash
|
||||
# Using a config in configs/envs/
|
||||
python environments/infinimath/infinimath_local_server.py --config infinimath_hard
|
||||
|
||||
# Using a config with full path
|
||||
python environments/infinimath/infinimath_local_server.py --config /path/to/my/config.yaml
|
||||
```
|
||||
|
||||
## Configuration Structure
|
||||
|
||||
The configuration file follows this structure:
|
||||
|
||||
```yaml
|
||||
# Base environment parameters
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
group_size: 1
|
||||
use_wandb: false
|
||||
# ... other base parameters
|
||||
|
||||
# InfiniteMath specific configuration
|
||||
infinimath:
|
||||
# Curriculum parameters
|
||||
starting_level: 1
|
||||
progress_threshold: 0.7
|
||||
# ... other InfiniteMath specific parameters
|
||||
|
||||
# Server configuration
|
||||
server_configs:
|
||||
- model_name: "gpt-4.1-nano"
|
||||
api_key: ${OPENAI_API_KEY}
|
||||
num_requests_for_eval: 70
|
||||
```
|
||||
|
||||
### Important Configuration Parameters
|
||||
|
||||
#### Base Parameters
|
||||
|
||||
- `tokenizer_name`: The tokenizer to use for encoding/decoding text
|
||||
- `group_size`: Number of responses to collect per prompt
|
||||
- `max_token_length`: Maximum token length for generation
|
||||
- `steps_per_eval`: How often to run evaluations
|
||||
|
||||
#### InfiniteMath Specific Parameters
|
||||
|
||||
- `starting_level`: Initial difficulty level (1-7)
|
||||
- `progress_threshold`: Success rate needed to advance levels
|
||||
- `min_evaluations`: Minimum number of evaluations before level advancement
|
||||
- `reward_functions`: List of reward functions to apply
|
||||
|
||||
#### Server Configuration
|
||||
|
||||
- `model_name`: LLM model to use
|
||||
- `api_key`: API key for the model (can use environment variables with ${VAR_NAME} syntax)
|
||||
- `num_requests_for_eval`: Number of evaluation requests to allocate
|
||||
8
environments/infinimath/__init__.py
Normal file
8
environments/infinimath/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""
|
||||
Infinite Math - A reinforcement learning environment for procedurally generated math problems.
|
||||
"""
|
||||
|
||||
from .curriculum import MathCurriculum
|
||||
from .infinimath import InfiniteMath
|
||||
|
||||
__all__ = ["MathCurriculum", "InfiniteMath"]
|
||||
378
environments/infinimath/curriculum.py
Normal file
378
environments/infinimath/curriculum.py
Normal file
|
|
@ -0,0 +1,378 @@
|
|||
import random
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import mathgenerator
|
||||
|
||||
|
||||
class MathCurriculum:
    """
    A curriculum manager for the mathgenerator library.

    This class organizes math problems by difficulty (seven levels) and
    provides methods to generate problems of appropriate difficulty based
    on the learner's recorded performance.
    """

    # Map each difficulty level (1-7) to the mathgenerator generator IDs
    # that belong to it.  These lists are class-level constants shared by
    # every instance and must never be mutated (see get_problem).
    DIFFICULTY_LEVELS = {
        # Level 1: Basic arithmetic operations -- Addition, Subtraction,
        # Multiplication, Division, Square, Factorial, Absolute difference,
        # Percentage, IsPrime
        1: [0, 1, 2, 3, 8, 31, 71, 80, 90],
        # Level 2: Basic operations with fractions and pre-algebra --
        # Square Root, Basic Algebra, Fraction to Decimal, Fraction
        # Division, Fraction Multiplication, Compare Fractions, Cube Root,
        # Exponentiation, Power of Powers, Percentage difference/error,
        # Is Composite
        2: [6, 11, 13, 16, 28, 44, 47, 53, 97, 118, 119, 124],
        # Level 3: Basic geometry and more algebra -- Area of Triangle,
        # Triangle exists check, Third Angle of Triangle, Distance between
        # 2 points, Pythagorean Theorem, Fourth Angle of Quadrilateral,
        # Sum of Angles of Polygon, Area of a Sector, Perimeter of
        # Polygons, Circumference, Arc length, Area of Circle
        3: [18, 19, 22, 24, 25, 49, 58, 75, 96, 104, 108, 112, 115],
        # Level 4: More advanced algebra and basic statistics -- LCM, GCD,
        # Midpoint, Factoring Quadratic, System of Equations, Linear
        # Equations, Common Factors, Intersection of Two Lines, Simple
        # Interest, Quadratic Equation, Mean and Median, Compound
        # Interest, Combine Like terms
        4: [9, 10, 20, 21, 23, 26, 40, 41, 45, 50, 76, 78, 105],
        # Level 5: Vectors, matrices, and solid geometry -- Matrix
        # Multiplication, Surface Areas, Volumes, Vector operations, etc.
        5: [17, 32, 33, 34, 35, 36, 37, 38, 39, 43, 46, 60, 61, 70, 72,
            77, 95, 113, 117, 122, 123],
        # Level 6: Advanced topics (calculus, statistics, computer
        # science) -- Binary operations, Calculus, Combinatorics,
        # Probability, etc.
        6: [4, 5, 7, 12, 14, 15, 27, 30, 42, 48, 52, 54, 55, 56, 59, 62,
            64, 73, 79, 84, 88, 89, 91, 103, 107, 110],
        # Level 7: Most complex topics -- Complex numbers, Advanced
        # operations, etc.
        7: [65, 66, 67, 68, 69, 74, 85, 92, 93, 94, 98, 99, 100, 101,
            106, 109, 111, 121],
    }

    def __init__(
        self,
        starting_level: int = 1,
        progress_threshold: float = 0.8,
        min_evaluations: int = 5,
    ):
        """
        Initialize the curriculum manager.

        Args:
            starting_level: The difficulty level to start with (default: 1)
            progress_threshold: The success rate required to advance to
                the next level (default: 0.8)
            min_evaluations: Minimum number of evaluations needed before
                considering level advancement (default: 5)

        Raises:
            ValueError: If starting_level is not a defined difficulty level.
        """
        # Validate first so we never leave a partially-initialized object.
        if starting_level not in self.DIFFICULTY_LEVELS:
            raise ValueError(
                f"Invalid starting level: {starting_level}. Available levels: {list(self.DIFFICULTY_LEVELS.keys())}"
            )

        self.current_level = starting_level
        self.progress_threshold = progress_threshold
        self.min_evaluations = min_evaluations

        # Per-level list of booleans recording correctness of past answers.
        self.performance_history: Dict[int, List[bool]] = {
            level: [] for level in self.DIFFICULTY_LEVELS.keys()
        }

    def get_problem(self) -> Tuple[str, str, int]:
        """
        Generate a math problem at the current difficulty level.

        Returns:
            Tuple containing (problem_text, solution_text, generator_id)
        """
        # Work on a copy: the retry loop below removes broken generators,
        # and the previous code aliased the class-level list, permanently
        # deleting generators for every instance (bug fix).
        available_generators = list(self.DIFFICULTY_LEVELS[self.current_level])

        max_attempts = 5  # Limit the number of attempts to avoid infinite loops
        attempts = 0

        while attempts < max_attempts:
            # Pick a random generator from the current candidate pool.
            generator_id = random.choice(available_generators)

            try:
                problem, solution = mathgenerator.genById(generator_id)
                return problem, solution, generator_id
            except Exception as e:
                # Log the error and try another generator.
                print(f"Error with generator {generator_id}: {str(e)}")
                attempts += 1

                # Drop the problematic generator for this call only.
                if generator_id in available_generators:
                    available_generators.remove(generator_id)

                # If this level is exhausted, borrow an adjacent level's
                # generators (clamped to the valid 1-7 range).
                if not available_generators:
                    fallback_level = max(
                        1, min(7, self.current_level + random.choice([-1, 1]))
                    )
                    available_generators = list(
                        self.DIFFICULTY_LEVELS[fallback_level]
                    )

        # If all attempts fail, return a simple addition problem as fallback.
        return "What is $2 + 2$?", "4", 0

    def record_performance(self, generator_id: int, is_correct: bool) -> None:
        """
        Record the performance on a specific problem.

        Args:
            generator_id: The ID of the generator used
            is_correct: Whether the answer was correct
        """
        # Credit the result to the level owning this generator; unknown
        # generator IDs are silently ignored (there is no level to credit).
        for level, generator_ids in self.DIFFICULTY_LEVELS.items():
            if generator_id in generator_ids:
                self.performance_history[level].append(is_correct)
                return

    def get_success_rate(self, level: int) -> Optional[float]:
        """
        Calculate the success rate for a specific level.

        Args:
            level: The difficulty level

        Returns:
            Success rate as a float between 0 and 1, or None if fewer than
            min_evaluations results have been recorded for that level.
        """
        history = self.performance_history[level]

        if len(history) < self.min_evaluations:
            return None

        # Only the most recent min_evaluations results count, so old
        # failures do not permanently block advancement.
        recent_history = history[-self.min_evaluations:]
        return sum(recent_history) / len(recent_history)

    def should_advance(self) -> bool:
        """
        Determine if the learner should advance to the next level.

        Returns:
            Boolean indicating whether to advance
        """
        success_rate = self.get_success_rate(self.current_level)

        # Not enough data yet, or performance below threshold: stay put.
        if success_rate is None or success_rate < self.progress_threshold:
            return False

        # Only advance if a higher level actually exists.
        return self.current_level < max(self.DIFFICULTY_LEVELS.keys())

    def advance_difficulty(self) -> bool:
        """
        Advance to the next difficulty level if appropriate.

        Returns:
            Boolean indicating whether advancement occurred
        """
        if self.should_advance():
            self.current_level += 1
            return True
        return False

    def get_current_level(self) -> int:
        """
        Get the current difficulty level.

        Returns:
            Current level as an integer
        """
        return self.current_level

    def get_num_levels(self) -> int:
        """
        Get the total number of difficulty levels.

        Returns:
            Total number of levels
        """
        return len(self.DIFFICULTY_LEVELS)

    def get_level_description(self, level: Optional[int] = None) -> str:
        """
        Get a description of the specified difficulty level.

        Args:
            level: The level to describe (default: current level)

        Returns:
            String description of the level
        """
        if level is None:
            level = self.current_level

        level_descriptions = {
            1: "Basic arithmetic operations (addition, subtraction, multiplication, division)",
            2: "Basic operations with fractions and pre-algebra",
            3: "Basic geometry and more algebra",
            4: "More advanced algebra and basic statistics",
            5: "Vectors, matrices, and solid geometry",
            6: "Advanced topics (calculus, statistics, computer science)",
            7: "Most complex topics (complex numbers, advanced operations)",
        }

        return level_descriptions.get(
            level, f"Custom level with IDs: {self.DIFFICULTY_LEVELS.get(level, [])}"
        )

    def reset(self, level: int = 1) -> None:
        """
        Reset the curriculum to a specific level and clear performance history.

        Args:
            level: The level to reset to (default: 1)

        Raises:
            ValueError: If level is not a defined difficulty level.
        """
        if level not in self.DIFFICULTY_LEVELS:
            raise ValueError(
                f"Invalid level: {level}. Available levels: {list(self.DIFFICULTY_LEVELS.keys())}"
            )

        self.current_level = level
        self.performance_history = {lvl: [] for lvl in self.DIFFICULTY_LEVELS.keys()}

    def get_generator_info(self) -> List[Dict[str, Any]]:
        """
        Get information about all available generators.

        Returns:
            List of dictionaries containing generator information
        """
        generators = []
        gen_list = mathgenerator.getGenList()

        for gen in gen_list:
            # Find which difficulty level (if any) owns this generator.
            level = None
            for lvl, generator_ids in self.DIFFICULTY_LEVELS.items():
                if gen[0] in generator_ids:
                    level = lvl
                    break

            # NOTE(review): tuple indices assume mathgenerator's getGenList
            # layout (gen[2] deliberately skipped) -- confirm against the
            # installed mathgenerator version.
            generators.append(
                {
                    "id": gen[0],
                    "name": gen[1],
                    "function": gen[3],
                    "subject": gen[4],
                    "params": gen[5],
                    "difficulty_level": level,
                }
            )

        return generators
|
||||
253
environments/infinimath/infinimath.py
Normal file
253
environments/infinimath/infinimath.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Infinite Math - A reinforcement learning environment for math practice
|
||||
using the mathgenerator library with curriculum-based advancement.
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from .curriculum import MathCurriculum
|
||||
|
||||
|
||||
class InfiniteMath:
    """
    A reinforcement learning environment for practicing math skills with
    curriculum-based advancement.

    Wraps a MathCurriculum instance: it serves problems, checks submitted
    answers against the generated solutions, tracks the attempt budget,
    and lets the curriculum raise the difficulty as recorded accuracy
    improves.
    """

    def __init__(
        self,
        starting_level: int = 1,
        progress_threshold: float = 0.8,
        min_evaluations: int = 5,
        max_attempts_per_problem: int = 3,
    ):
        """
        Initialize the InfiniteMath environment.

        Args:
            starting_level: Initial difficulty level (default: 1)
            progress_threshold: Success rate needed to advance levels (default: 0.8)
            min_evaluations: Minimum evaluations before considering advancement (default: 5)
            max_attempts_per_problem: Maximum attempts allowed per problem (default: 3)
        """
        self.curriculum = MathCurriculum(
            starting_level=starting_level,
            progress_threshold=progress_threshold,
            min_evaluations=min_evaluations,
        )

        self.max_attempts = max_attempts_per_problem
        self.current_problem = None
        self.current_solution = None
        self.current_generator_id = None
        self.attempts_remaining = 0
        # Lifetime counters across all problems in this session.
        self.total_problems = 0
        self.correct_problems = 0

        # Serve the first problem immediately so the env is ready to use.
        self._generate_problem()

    def _generate_problem(self) -> None:
        """Pull a fresh problem from the curriculum and reset the attempt budget."""
        problem, solution, gen_id = self.curriculum.get_problem()
        self.current_problem = problem
        self.current_solution = solution
        self.current_generator_id = gen_id
        self.attempts_remaining = self.max_attempts

    def get_state(self) -> Dict[str, Any]:
        """
        Get the current state of the environment.

        Returns:
            Dictionary with current state information
        """
        attempted = self.total_problems
        return {
            "problem": self.current_problem,
            "attempts_remaining": self.attempts_remaining,
            "current_level": self.curriculum.get_current_level(),
            "total_levels": self.curriculum.get_num_levels(),
            "level_description": self.curriculum.get_level_description(),
            "total_problems": attempted,
            "correct_problems": self.correct_problems,
            # max(1, ...) avoids division by zero before any problem ends.
            "accuracy": self.correct_problems / max(1, attempted),
        }

    def submit_answer(self, answer: str) -> Dict[str, Any]:
        """
        Submit an answer to the current problem.

        Args:
            answer: The learner's answer to the current problem

        Returns:
            Dictionary with the result of the submission and updated state
        """
        if self.current_problem is None:
            return {"error": "No active problem. Call reset() to start a new session."}

        # Case- and whitespace-insensitive string comparison with the solution.
        is_correct = (
            str(answer).strip().lower() == str(self.current_solution).strip().lower()
        )

        self.attempts_remaining -= 1
        finished = is_correct or self.attempts_remaining == 0

        result: Dict[str, Any] = {
            "is_correct": is_correct,
            # The solution is only revealed once the problem is over
            # (either solved or out of attempts).
            "correct_answer": self.current_solution if finished else None,
            "attempts_remaining": self.attempts_remaining,
        }

        if finished:
            self.total_problems += 1
            self.correct_problems += int(is_correct)

            # Feed the outcome back into the curriculum, which may trigger
            # a level advancement.
            self.curriculum.record_performance(self.current_generator_id, is_correct)

            did_advance = self.curriculum.advance_difficulty()
            result["did_advance_level"] = did_advance
            if did_advance:
                result["new_level"] = self.curriculum.get_current_level()
                result["level_description"] = self.curriculum.get_level_description()

            # Move straight on to the next problem.
            self._generate_problem()
            result["new_problem"] = self.current_problem

        return result

    def reset(self, level: Optional[int] = None) -> Dict[str, Any]:
        """
        Reset the environment, optionally to a specific difficulty level.

        Args:
            level: The difficulty level to reset to (default: current level)

        Returns:
            Dictionary with the new state
        """
        # Only touch the curriculum (and its history) when a level is given;
        # otherwise keep the current level and accumulated history.
        if level is not None:
            self.curriculum.reset(level)

        self.total_problems = 0
        self.correct_problems = 0
        self._generate_problem()

        return self.get_state()

    def get_difficulty_stats(self) -> Dict[int, Dict[str, Any]]:
        """
        Get performance statistics for each difficulty level.

        Returns:
            Dictionary with statistics for each level
        """
        current = self.curriculum.get_current_level()
        stats: Dict[int, Dict[str, Any]] = {}

        for level in self.curriculum.DIFFICULTY_LEVELS.keys():
            rate = self.curriculum.get_success_rate(level)
            attempted = len(self.curriculum.performance_history[level])

            stats[level] = {
                "description": self.curriculum.get_level_description(level),
                "problems_attempted": attempted,
                # NaN marks levels without enough data for a rate.
                "success_rate": float("nan") if rate is None else rate,
                "is_current_level": level == current,
            }

        return stats
|
||||
|
||||
|
||||
def main():
    """Example usage of the InfiniteMath environment: an interactive CLI loop."""
    env = InfiniteMath(starting_level=1)

    print("Welcome to InfiniteMath!")
    state = env.get_state()
    print(
        f"Starting at level {state['current_level']}: {state['level_description']}"
    )

    # Interactive loop: show a problem, read an answer, report the outcome.
    while True:
        state = env.get_state()
        print("\n" + "=" * 50)
        print(f"Level {state['current_level']}/{state['total_levels']}")
        print(f"Problem: {state['problem']}")
        print(f"Attempts remaining: {state['attempts_remaining']}")

        answer = input("Your answer (or 'q' to quit, 'r' to reset): ")
        command = answer.lower()

        if command == "q":
            break

        if command == "r":
            lvl = input("Reset to level (1-7, or press Enter for current level): ")
            if lvl and lvl.isdigit() and 1 <= int(lvl) <= 7:
                env.reset(int(lvl))
            else:
                env.reset()
            continue

        result = env.submit_answer(answer)

        print("Correct!" if result.get("is_correct", False) else "Incorrect.")

        if result.get("correct_answer") is not None:
            print(f"The correct answer is: {result['correct_answer']}")

        if result.get("did_advance_level", False):
            print(f"\nCongratulations! You've advanced to level {result['new_level']}!")
            print(f"New level: {result['level_description']}")

        if result.get("new_problem") is not None:
            print("\nNext problem is ready.")

    # Session summary.
    print("\nFinal Statistics:")
    print(f"Total problems attempted: {env.total_problems}")
    print(f"Correct answers: {env.correct_problems}")
    if env.total_problems > 0:
        print(f"Overall accuracy: {env.correct_problems / env.total_problems:.2%}")

    print("\nPerformance by level:")
    for level, stats in env.get_difficulty_stats().items():
        if stats["problems_attempted"] > 0:
            print(
                f"Level {level}: {stats['success_rate']:.2%} success rate ({stats['problems_attempted']} problems)"
            )


if __name__ == "__main__":
    main()
|
||||
745
environments/infinimath/infinimath_env.py
Normal file
745
environments/infinimath/infinimath_env.py
Normal file
|
|
@ -0,0 +1,745 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from trajectoryhandler.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
OpenaiConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from trajectoryhandler.envs.reward_fns import registry
|
||||
from trajectoryhandler.envs.reward_fns.combined_reward import CombinedReward
|
||||
from trajectoryhandler.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
from .curriculum import MathCurriculum
|
||||
|
||||
# Module-level logger. NOTE(review): forcing DEBUG at import time is very
# verbose for a library module -- consider leaving level configuration to
# the application.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# System prompt sent to the model: it asks for <think>...</think> reasoning
# followed by a \boxed{} final answer, the structure the reward functions
# below are configured to score.
system_prompt = """You are an expert mathematician. You need to solve the given math problem step-by-step, showing your reasoning clearly. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.

The problems will be given in a LaTeX format, so be sure to follow the LaTeX syntax when writing your answer (although no $ delimiters are necessary).

Follow these steps:
1. Understand the problem carefully
2. Plan your approach
3. Execute the calculations step-by-step
4. Verify your solution
5. Express the final answer as \\boxed{your answer here}

You may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.

Your answer format should be:
<think>
[Your detailed step-by-step reasoning process here]
</think>

\\boxed{your final answer here}

Remember to format your final answer correctly as this is important for evaluation."""
|
||||
|
||||
|
||||
class InfiniteMathEnvConfig(BaseEnvConfig):
    """Configuration for the InfiniteMath environment.

    Extends the base environment config with curriculum, scoring, length
    penalty, and sampling parameters specific to procedurally generated
    math problems.
    """

    # --- Curriculum parameters ---
    starting_level: int = 1          # Initial difficulty level (1-7)
    progress_threshold: float = 0.8  # Success rate required to advance a level
    min_evaluations: int = 5         # Evaluations required before advancement

    # --- Environment parameters ---
    max_attempts_per_problem: int = 3
    correct_reward: float = 1.0
    incorrect_reward: float = -1.0

    # --- Length penalty parameters ---
    apply_length_penalty: bool = True
    # Fraction of max_token_length after which penalties apply.
    length_threshold_ratio: float = 0.5

    # --- Completion sampling parameters ---
    temperature: float = 0.7
    top_p: float = 0.9

    # --- Reward functions ---
    # Names (or full config dicts) of the reward functions to combine.
    # NOTE(review): a mutable list default is only safe if BaseEnvConfig is
    # a pydantic-style model that copies defaults per instance -- confirm.
    reward_functions: List[Union[str, Dict[str, Any]]] = ["accuracy", "format", "boxed"]
    accuracy_reward_weight: float = 1.0  # Weight for the accuracy reward
    # Weight for the format reward relative to correctness.
    format_reward_weight: float = 0.2
    # Weight for the boxed answer reward relative to correctness.
    boxed_reward_weight: float = 0.3
|
||||
|
||||
|
||||
class InfiniteMathEnv(BaseEnv):
|
||||
"""Environment for procedurally generated math problems with curriculum advancement."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InfiniteMathEnvConfig,
|
||||
server_configs: Union[List[OpenaiConfig], OpenaiConfig],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config = config # Override with our specific config class
|
||||
|
||||
# Initialize tracking metrics
|
||||
self.percent_correct_buffer = []
|
||||
self.level_correct_buffer = {
|
||||
i: [] for i in range(1, 8)
|
||||
} # Track correctness for each level
|
||||
self.eval_metrics = []
|
||||
|
||||
# Curriculum will be initialized in setup()
|
||||
self.curriculum = None
|
||||
|
||||
# Set the system prompt
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
# Initialize reward function
|
||||
self.reward_function = self._initialize_reward_function()
|
||||
|
||||
def _initialize_reward_function(self):
|
||||
"""Initialize the combined reward function for scoring."""
|
||||
if hasattr(self.config, "reward_functions") and self.config.reward_functions:
|
||||
# Configure parameters for specific reward functions
|
||||
reward_configs = []
|
||||
|
||||
for reward_func in self.config.reward_functions:
|
||||
if isinstance(reward_func, str):
|
||||
# String name case - handle known rewards with custom params
|
||||
if reward_func == "accuracy":
|
||||
# Configure accuracy reward
|
||||
accuracy_config = {
|
||||
"type": "accuracy",
|
||||
"weight": self.config.accuracy_reward_weight,
|
||||
"params": {
|
||||
"split_on_think_tag": True, # Only look at what's after </think>
|
||||
"tolerance": 1e-6, # Tolerance for number comparisons
|
||||
},
|
||||
}
|
||||
logger.info(f"Adding accuracy reward with config: {accuracy_config}")
|
||||
reward_configs.append(accuracy_config)
|
||||
elif reward_func == "format":
|
||||
# Configure format reward with think tags and explicit weight
|
||||
format_config = {
|
||||
"type": "format",
|
||||
"weight": self.config.format_reward_weight,
|
||||
"params": {
|
||||
"preferred_tags": ["think"],
|
||||
},
|
||||
}
|
||||
logger.info(f"Adding format reward with config: {format_config}")
|
||||
reward_configs.append(format_config)
|
||||
elif reward_func == "boxed":
|
||||
# Configure boxed reward with proper parameters and explicit weight
|
||||
boxed_config = {
|
||||
"type": "boxed",
|
||||
"weight": self.config.boxed_reward_weight,
|
||||
"params": {
|
||||
"require_outside_think": True,
|
||||
},
|
||||
}
|
||||
logger.info(f"Adding boxed reward with config: {boxed_config}")
|
||||
reward_configs.append(boxed_config)
|
||||
else:
|
||||
# Pass through other reward functions as is
|
||||
logger.info(f"Adding generic reward function: {reward_func}")
|
||||
reward_configs.append(reward_func)
|
||||
else:
|
||||
# Dict case - pass through as is
|
||||
logger.info(f"Adding reward config: {reward_func}")
|
||||
reward_configs.append(reward_func)
|
||||
|
||||
# Create the reward function(s)
|
||||
if len(reward_configs) == 1:
|
||||
logger.info(f"Creating single reward function: {reward_configs[0]}")
|
||||
return registry.create(reward_configs[0])
|
||||
else:
|
||||
logger.info(f"Creating combined reward function with {len(reward_configs)} rewards")
|
||||
# Add explicit normalization to sum to 1.0
|
||||
return CombinedReward(rewards=reward_configs, normalization="none")
|
||||
|
||||
async def setup(self):
    """Initialize the environment and curriculum.

    Creates the ``MathCurriculum`` used for training, then pre-generates a
    fixed pool of evaluation problems (up to 10 per difficulty level 1-7)
    so ``evaluate()`` has stable test items. Levels for which generation
    repeatedly fails receive a single hard-coded fallback problem.
    """
    logger.info("Setting up InfiniteMathEnv")

    # Initialize curriculum from configured starting level / advancement rules.
    self.curriculum = MathCurriculum(
        starting_level=self.config.starting_level,
        progress_threshold=self.config.progress_threshold,
        min_evaluations=self.config.min_evaluations,
    )

    # Generate some test problems for each level for evaluation.
    # A throwaway curriculum pinned at each level is used so generation
    # does not disturb the training curriculum's state.
    self.eval_problems = {}
    for level in range(1, 8):
        self.eval_problems[level] = []
        temp_curriculum = MathCurriculum(starting_level=level)
        # Generate 10 test problems for each level
        attempts = 0
        max_attempts_per_level = 20  # Try at most 20 problems to get 10 valid ones

        while (
            len(self.eval_problems[level]) < 10
            and attempts < max_attempts_per_level
        ):
            try:
                problem, solution, generator_id = temp_curriculum.get_problem()
                # Strip LaTeX delimiters
                problem = self.strip_latex_delimiters(problem)
                solution = self.strip_latex_delimiters(solution)
                self.eval_problems[level].append((problem, solution, generator_id))
            except Exception as e:
                # Generators can fail on some random draws; log and retry.
                logger.warning(
                    f"Error generating evaluation problem for level {level}: {e}"
                )
            attempts += 1

        logger.info(
            f"Generated {len(self.eval_problems[level])} evaluation problems for level {level}"
        )

    # If any levels have no problems, add a simple fallback so evaluate()
    # always has at least one item per level. The third tuple element is a
    # generator id (presumably valid for that level — TODO confirm against
    # MathCurriculum.DIFFICULTY_LEVELS).
    for level in range(1, 8):
        if not self.eval_problems[level]:
            logger.warning(
                f"No valid evaluation problems for level {level}, adding fallback"
            )
            if level == 1:
                self.eval_problems[level].append(("What is 2 + 3?", "5", 0))
            elif level == 2:
                self.eval_problems[level].append(
                    ("What is the square root of 16?", "4", 6)
                )
            elif level == 3:
                self.eval_problems[level].append(
                    (
                        "What is the area of a triangle with base 6 and height 8?",
                        "24",
                        18,
                    )
                )
            elif level == 4:
                self.eval_problems[level].append(
                    ("What is the solution to x + 5 = 12?", "7", 26)
                )
            elif level == 5:
                self.eval_problems[level].append(
                    ("What is the volume of a cube with side length 3?", "27", 33)
                )
            elif level == 6:
                self.eval_problems[level].append(
                    ("What is 5 factorial?", "120", 31)
                )
            else:
                self.eval_problems[level].append(("What is |3 - 10|?", "7", 71))
|
||||
|
||||
def strip_latex_delimiters(self, text: str) -> str:
    """Remove inline LaTeX math delimiters, turning ``$expr$`` into ``expr``.

    Works both for inline fragments and for strings that are entirely a
    single ``$...$`` expression.
    """
    # Non-greedy capture keeps adjacent $...$ pairs independent.
    inline_math = re.compile(r"\$(.*?)\$")
    return inline_math.sub(r"\1", text)
|
||||
|
||||
def save_checkpoint(self, step, data=None):
    """Save curriculum state in checkpoint.

    Augments the base-class checkpoint payload with the curriculum's
    current level and per-generator performance history so that training
    can resume at the same difficulty.

    Args:
        step: Training step the checkpoint is written for.
        data: Optional dict of extra checkpoint fields; created if None.
    """
    if data is None:
        data = {}

    # Save curriculum state
    data["curriculum_level"] = self.curriculum.get_current_level()
    # JSON object keys must be strings, so stringify the integer level keys;
    # load_checkpoint() converts them back with int().
    data["performance_history"] = {
        str(k): v for k, v in self.curriculum.performance_history.items()
    }

    super().save_checkpoint(step, data)
|
||||
|
||||
def load_checkpoint(self):
    """Load curriculum state from checkpoint.

    Restores base-class state first, then re-reads the checkpoint JSON to
    recover the curriculum level and performance history written by
    ``save_checkpoint()``. A missing or malformed checkpoint is logged and
    ignored so a fresh run starts cleanly from defaults.
    """
    super().load_checkpoint()

    # Check if we have curriculum data in the checkpoint.
    # NOTE(review): assumes the base class uses this path layout and has
    # already populated self.curr_step / self.wandb_prepend — confirm.
    checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json"
    try:
        with open(checkpoint_path, "r") as f:
            data = json.load(f)

        # Restore curriculum state if available
        if "curriculum_level" in data:
            level = data["curriculum_level"]
            self.curriculum.current_level = level

        if "performance_history" in data:
            # Convert string keys back to integers (JSON keys are strings)
            self.curriculum.performance_history = {
                int(k): v for k, v in data["performance_history"].items()
            }
    except (FileNotFoundError, json.JSONDecodeError) as e:
        logger.warning(f"Failed to load checkpoint: {e}")
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    """Log metrics to wandb.

    Publishes training accuracy (overall and per level), curriculum
    progress, reward-function configuration, and any queued eval metrics,
    then resets the rolling buffers and delegates remaining metrics to the
    base class.
    """
    if wandb_metrics is None:
        wandb_metrics = {}

    # Log overall correct percentage (max(1, ...) guards the empty buffer).
    try:
        wandb_metrics["train/percent_correct"] = sum(
            self.percent_correct_buffer
        ) / max(1, len(self.percent_correct_buffer))
    except ZeroDivisionError:
        pass

    # Log per-level metrics
    for level, buffer in self.level_correct_buffer.items():
        if buffer:
            wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len(
                buffer
            )
            wandb_metrics[f"train/level_{level}_count"] = len(buffer)

    # Log current level and curriculum progress
    if self.curriculum:
        current_level = self.curriculum.get_current_level()
        max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys())

        wandb_metrics["curriculum/current_level"] = current_level
        wandb_metrics["curriculum/max_level"] = max_level
        wandb_metrics["curriculum/progress_percent"] = (
            current_level / max_level
        ) * 100

        # Log level description
        wandb_metrics["curriculum/level_description"] = (
            self.curriculum.get_level_description()
        )

        # Log performance history for current level. Only the most recent
        # min_evaluations results count toward advancement.
        if current_level in self.curriculum.performance_history:
            history = self.curriculum.performance_history[current_level]
            if history:
                recent_history = history[
                    -min(len(history), self.curriculum.min_evaluations) :
                ]
                if recent_history:
                    success_rate = sum(recent_history) / len(recent_history)
                    wandb_metrics["curriculum/current_level_success_rate"] = (
                        success_rate
                    )
                    wandb_metrics["curriculum/threshold_to_advance"] = (
                        self.curriculum.progress_threshold
                    )
                    wandb_metrics["curriculum/remaining_to_threshold"] = max(
                        0, self.curriculum.progress_threshold - success_rate
                    )

    # Give the reward function a handle to the wandb logger, if it wants one.
    if hasattr(self, "reward_function") and self.wandb:
        if hasattr(self.reward_function, "set_wandb_logger"):
            self.reward_function.set_wandb_logger(self.wandb)

    # Log the reward configurations
    if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
        # Log the reward configuration
        wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
        wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions

        if hasattr(self.config, "format_reward_weight"):
            wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight

        if hasattr(self.config, "boxed_reward_weight"):
            wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight

    # Add eval metrics queued by evaluate() as (name, value) pairs.
    for item in self.eval_metrics:
        wandb_metrics[item[0]] = item[1]

    # Reset buffers so each log call reports a fresh window.
    self.percent_correct_buffer = []
    for level in self.level_correct_buffer:
        self.level_correct_buffer[level] = []
    self.eval_metrics = []

    # Call the parent method to handle remaining metrics
    await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def get_next_item(self):
    """Fetch a fresh problem at the curriculum's current difficulty.

    Returns:
        A ``(prompt, solution, generator_id)`` tuple, where ``prompt`` is a
        hashable tuple-of-frozensets message representation of the user turn.
    """
    problem, solution, generator_id = self.curriculum.get_problem()

    # Problems and solutions may carry inline $...$ math; remove the delimiters.
    problem = self.strip_latex_delimiters(problem)
    solution = self.strip_latex_delimiters(solution)

    # Hashable single-message prompt for downstream bookkeeping.
    user_turn = frozenset({"role": "user", "content": problem}.items())
    prompt = (user_turn,)

    return (prompt, solution, generator_id)
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
    """Evaluate the model on test problems at the current curriculum level.

    Runs all pre-generated eval problems for the current level concurrently,
    records accuracy metrics, feeds each result into the curriculum's
    performance history, and attempts a difficulty advancement.

    Returns:
        The list of ``(metric_name, value)`` tuples accumulated in
        ``self.eval_metrics`` (empty list if no problems were available).
    """
    current_level = self.curriculum.get_current_level()
    logger.info(f"Starting evaluation for curriculum level {current_level}")

    # Only evaluate problems at the current level
    eval_tasks = []
    eval_generator_ids = []
    if current_level in self.eval_problems:
        for problem, solution, generator_id in self.eval_problems[current_level]:
            eval_tasks.append(
                self.evaluate_single_problem(problem, solution, current_level)
            )
            eval_generator_ids.append(generator_id)

    if not eval_tasks:
        logger.warning(
            f"No evaluation problems available for level {current_level}"
        )
        return []

    # Run evaluation tasks concurrently.
    logger.info(f"Evaluating {len(eval_tasks)} problems at level {current_level}")
    results = await asyncio.gather(*eval_tasks)

    # Calculate accuracy for the current level
    correct_count = sum(1 for _, is_correct in results if is_correct)
    total_count = len(results)
    accuracy = correct_count / total_count if total_count > 0 else 0

    logger.info(
        f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})"
    )

    # Record metrics for the current level
    self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy))
    self.eval_metrics.append(("eval/current_level", current_level))

    # Record the actual evaluation results in the curriculum's performance history
    for i, (_, is_correct) in enumerate(results):
        if i < len(eval_generator_ids):
            # Record the actual result
            self.curriculum.record_performance(eval_generator_ids[i], is_correct)
        else:
            # Fallback if somehow the lists are different lengths
            sample_generator_id = random.choice(
                self.curriculum.DIFFICULTY_LEVELS[current_level]
            )
            self.curriculum.record_performance(sample_generator_id, is_correct)

    # Try to advance to the next level
    advanced = self.curriculum.advance_difficulty()
    new_level = self.curriculum.get_current_level()

    if advanced:
        logger.info(f"Advanced from level {current_level} to level {new_level}!")
        self.eval_metrics.append(("eval/advanced_level", 1))
    else:
        logger.info(f"Remaining at level {current_level}")
        self.eval_metrics.append(("eval/advanced_level", 0))

    return self.eval_metrics
|
||||
|
||||
async def evaluate_single_problem(
    self, problem: str, solution: str, level: int
) -> Tuple[int, bool]:
    """Evaluate a single problem.

    Builds a chat prompt, prefixes a ``<think>`` block opener, requests one
    deterministic (temperature 0) completion, and checks the final answer
    against *solution*.

    Returns:
        ``(level, is_correct)``; any exception is logged and counted as an
        incorrect answer rather than propagated.
    """
    try:
        logger.debug(f"Evaluating level {level} problem: {problem[:30]}...")

        # Convert messages to a single prompt using the tokenizer
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": problem},
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)

        # Add prefilled thinking starter so the model begins inside <think>.
        prefill = "\n<think>\n"
        prefilled_prompt = prompt + prefill

        # Generate completion using the prompt
        logger.debug(f"Requesting completion for problem: {problem[:30]}...")
        completion = await self.server.completion(
            prompt=prefilled_prompt,
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,  # Use 0 temperature for deterministic results
            top_p=1.0,
            split="eval",
        )

        # Extract the completion text and prepend the thinking starter so
        # the answer checker sees the full <think>...</think> structure.
        model_answer = prefill + (
            completion.choices[0].text
            if hasattr(completion.choices[0], "text")
            else completion.choices[0].message.content
        )

        # Check if the answer is correct
        is_correct = self.check_answer(model_answer, solution)
        logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")

        return level, is_correct
    except Exception as e:
        logger.error(f"Error evaluating problem: {e}")
        # Return a failed result in case of error
        return level, False
|
||||
|
||||
def check_answer(self, model_answer: str, solution: str) -> bool:
    """Return True when the model's final answer matches *solution*.

    Only text after a closing ``</think>`` tag is considered. A
    ``\\boxed{...}`` answer is preferred; otherwise the last line is used.
    Both sides are normalized via :meth:`clean_for_comparison` before the
    equality check.
    """
    # Discard the reasoning block, if present.
    if "</think>" in model_answer:
        final_part = model_answer.split("</think>")[-1].strip()
    else:
        final_part = model_answer

    # Prefer a boxed answer; fall back to the last line of the output.
    candidate = self.extract_boxed_answer(final_part)
    if not candidate:
        lines = final_part.strip().split("\n")
        if lines:
            candidate = lines[-1].strip()

    # Normalize (drop LaTeX commands, whitespace, commas; lowercase).
    normalized_model = self.clean_for_comparison(
        candidate if candidate else final_part
    )
    normalized_solution = self.clean_for_comparison(solution)

    return normalized_model == normalized_solution
|
||||
|
||||
def extract_boxed_answer(self, text: str) -> Optional[str]:
    """Extract the contents of the first LaTeX ``\\boxed{...}`` in *text*.

    Uses brace matching rather than a regex so nested braces (e.g.
    ``\\boxed{\\frac{1}{2}}``) are captured in full; the previous regex
    ``\\boxed{([^}]*)}`` stopped at the first ``}`` and truncated such
    answers.

    Args:
        text: Text that may contain a boxed expression.

    Returns:
        The inner text of the first balanced ``\\boxed{...}``, or None if
        no boxed expression is present or its braces are unbalanced.
    """
    marker = "\\boxed{"
    start = text.find(marker)
    if start == -1:
        return None

    # Scan forward, tracking brace depth, until the opening brace closes.
    content_start = start + len(marker)
    depth = 1
    for idx in range(content_start, len(text)):
        ch = text[idx]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:idx]
    return None  # Unbalanced braces: no complete boxed expression.
|
||||
|
||||
def clean_for_comparison(self, text: str) -> str:
    """Normalize an answer string for equality comparison.

    Strips LaTeX commands (a backslash followed by letters), removes
    commas and all whitespace, and lowercases the result.
    """
    without_commands = re.sub(r"\\[a-zA-Z]+", "", text)
    without_separators = re.sub(r"[,\s]", "", without_commands)
    return without_separators.lower()
|
||||
|
||||
async def collect_trajectories(self, item) -> Tuple[List, List]:
    """Collect trajectories for the current item.

    Requests ``group_size`` completions for the item's problem (with a
    prefilled ``<think>`` opener) and packages each into a full message
    sequence for :meth:`score`.

    Args:
        item: ``(problem_prompt, solution, generator_id)`` as produced by
            :meth:`get_next_item`.

    Returns:
        ``(to_score, backlog)`` where each ``to_score`` entry is
        ``(full_messages, solution, generator_id, level)`` and ``backlog``
        is always empty.
    """
    # Extract information from the item
    problem_prompt, solution, generator_id = item

    # Create prompt using tokenizer's chat template
    # Add prefilled thinking starter
    prefill = "\n<think>\n"
    messages = [
        {"role": "system", "content": self.system_prompt},
        {"role": "user", "content": dict(problem_prompt[0])["content"]},
        {"role": "assistant", "content": prefill},
    ]

    prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
    # Generate completions using completion API
    completions = await self.server.completion(
        prompt=prompt,
        n=self.config.group_size,
        max_tokens=self.config.max_token_length,
        temperature=self.config.temperature,
        top_p=self.config.top_p,
    )

    # Prepare data for scoring
    to_score = []

    # Track level for metrics: reverse-map generator_id to its difficulty level.
    level = None
    for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items():
        if generator_id in generator_ids:
            level = lvl
            break

    # Process each completion
    for i, completion in enumerate(completions.choices):
        # Get the completion text and prepend the thinking starter
        model_answer = prefill + (
            completion.text
            if hasattr(completion, "text")
            else completion.message.content
        )
        # NOTE(review): debug leftover — consider logger.debug instead of print.
        print("model_answer", model_answer)

        # Build complete message sequence
        full_messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": dict(problem_prompt[0])["content"]},
            {"role": "assistant", "content": model_answer},
        ]

        # Add to scoring list
        to_score.append((full_messages, solution, generator_id, level))

    # Record performance in curriculum for each item we're scoring
    # This will be called again after scoring, but that's fine

    # No additional items for backlog
    backlog = []

    return to_score, backlog
|
||||
|
||||
async def score(self, rollout_group_data) -> ScoredDataGroup:
    """Score the collected trajectories.

    Applies the configured reward function(s) to every completion in the
    group, records per-answer correctness into the curriculum and metric
    buffers, tokenizes each message sequence for the trainer, and attempts
    a curriculum difficulty advancement.

    Args:
        rollout_group_data: List of ``(messages, solution, generator_id,
            level)`` tuples from :meth:`collect_trajectories`.

    Returns:
        A populated ``ScoredDataGroup`` with tokens, masks, scores and
        messages for each completion.
    """
    scored_data = ScoredDataGroup()
    scored_data["tokens"] = []
    scored_data["masks"] = []
    scored_data["scores"] = []
    scored_data["messages"] = []

    # Format completions for reward function evaluation
    format_completions = []

    # Process each item in the rollout data
    for messages, solution, generator_id, level in rollout_group_data:
        # Extract the model's answer
        model_answer = messages[-1]["content"]

        # Add to format completions list for reward function
        format_completions.append([{"role": "assistant", "content": model_answer}])

        # Record performance in curriculum based on the answer and solution
        # This will be updated after the reward functions are applied

    # Apply all reward functions.
    # NOTE(review): `solution` here is the loop variable left over from the
    # loop above, i.e. the LAST item's solution — presumably all rollouts in
    # a group share one problem/solution, but confirm.
    reward_scores = []
    unweighted_scores = []
    if hasattr(self, "reward_function") and self.reward_function:
        try:
            # Apply the reward function (which may be a combined reward)
            reward_scores = self.reward_function(format_completions, solution=solution)
            logger.info(f"Reward scores: {reward_scores}")

            # Debug individual rewards if it's a combined reward
            if hasattr(self.reward_function, "rewards"):
                logger.info(f"Combined reward with {len(self.reward_function.rewards)} components")
                for i, reward in enumerate(self.reward_function.rewards):
                    if hasattr(reward, "compute"):
                        # Get raw unweighted scores
                        raw_scores = reward.compute(format_completions, solution=solution)
                        if hasattr(reward, "weight"):
                            logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}, weight={reward.weight}")
                        else:
                            logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}")
            else:
                logger.info(f"Using single reward: {type(self.reward_function).__name__}")

        except Exception as e:
            logger.error(f"Error applying reward functions: {e}")
            logger.exception(e)
            reward_scores = [0.0] * len(format_completions)

    # Now update curriculum based on accuracy reward results.
    # NOTE(review): reward.compute() below re-evaluates the WHOLE batch on
    # every iteration and indexes the result with i — effectively O(n^2);
    # hoisting it is only safe if all items share one solution (see above).
    for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data):
        # Extract accuracy from the combined reward if available
        is_correct = False
        if reward_scores and hasattr(self.reward_function, "rewards"):
            for reward in self.reward_function.rewards:
                if type(reward).__name__ == "AccuracyReward":
                    # Get raw scores from accuracy reward
                    accuracy_scores = reward.compute(format_completions, solution=solution)
                    is_correct = accuracy_scores[i] > 0
                    break

        # Record answer correctness for tracking
        self.percent_correct_buffer.append(1 if is_correct else 0)
        if level is not None:
            self.level_correct_buffer[level].append(1 if is_correct else 0)

        # Record performance in curriculum
        self.curriculum.record_performance(generator_id, is_correct)

    # Combine scores and add to scored data
    for i, (messages, _, _, _) in enumerate(rollout_group_data):
        # Use the reward score directly (all weights are applied)
        combined_score = reward_scores[i] if reward_scores else 0.0

        logger.info(f"Final score for item {i}: {combined_score}")

        # Tokenize for the trainer
        tokens_dict = tokenize_for_trainer(
            self.tokenizer,
            messages,
            None,
        )

        # Add to scored data
        scored_data["tokens"].append(tokens_dict["tokens"])
        scored_data["masks"].append(tokens_dict["masks"])
        scored_data["scores"].append(combined_score)
        scored_data["messages"].append(messages)

    # Advance difficulty if criteria met
    self.curriculum.advance_difficulty()

    return scored_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual entry point: run the environment end-to-end against a local
    # rollout server with a hand-picked configuration.
    import asyncio

    async def main():
        """Build a default config and run the environment manager loop."""
        config = InfiniteMathEnvConfig(
            tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
            group_size=8,
            use_wandb=True,
            max_num_workers=64,
            rollout_server_url="http://localhost:8000",
            total_steps=10000,
            batch_size=1024,
            steps_per_eval=25,
            max_token_length=4096,
            inference_weight=1.0,
            wandb_name="infinite_math",
            data_path_to_save_groups="data/infinite_math_groups.jsonl",
            # InfiniteMath specific config
            starting_level=1,
            progress_threshold=0.8,
            min_evaluations=10,
            correct_reward=1.0,
            incorrect_reward=-0.5,
            apply_length_penalty=True,
            length_threshold_ratio=0.6,
            # Completion parameters
            temperature=0.7,
            top_p=0.9,
            # Reward function configuration - use name directly
            reward_functions=["accuracy", "format", "boxed"],
            accuracy_reward_weight=1.0,
            format_reward_weight=0.2,
            boxed_reward_weight=0.3,
        )

        openai_config = OpenaiConfig(
            model_name="NousResearch/Nous-Hermes-2-Yi-34B",
            base_url="http://localhost:9004/v1",
            api_key="x",
            num_requests_for_eval=64,
        )

        env = InfiniteMathEnv(
            config=config,
            server_configs=[openai_config],
            slurm=False,
        )

        await env.env_manager()

    asyncio.run(main())
|
||||
242
environments/infinimath/infinimath_local_server.py
Normal file
242
environments/infinimath/infinimath_local_server.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from environments.infinimath.infinimath_env import (
|
||||
InfiniteMathEnv,
|
||||
InfiniteMathEnvConfig,
|
||||
)
|
||||
from atroposlib.envs.base import OpenaiConfig
|
||||
from atroposlib.utils.config_handler import ConfigHandler
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_arguments():
    """Parse command-line options for the InfiniteMath server.

    Returns:
        An argparse.Namespace with a single ``config`` attribute
        (defaults to "infinimath").
    """
    parser = argparse.ArgumentParser(description="InfiniteMath environment server")
    config_help = (
        "Configuration file name (without .yaml extension or path for "
        "configs/envs/ directory, or full path)"
    )
    parser.add_argument("--config", type=str, default="infinimath", help=config_help)
    return parser.parse_args()
|
||||
|
||||
|
||||
async def main():
    """Smoke-test driver for the InfiniteMath environment.

    Loads a YAML config (falling back to the shared ConfigHandler), builds
    the environment and its inference server config, runs one full
    problem -> trajectories -> scoring cycle, exercises evaluate(), and
    finally simulates enough correct answers to test curriculum
    advancement. Intended for interactive local testing, not training.
    """
    logger.info("Starting InfiniteMath environment server")

    # Parse command line arguments
    args = parse_arguments()

    # Initialize config handler and load configuration
    config_handler = ConfigHandler()

    # Determine config path: absolute/relative paths and .yaml names are
    # used as-is; bare names resolve under configs/envs/.
    if os.path.isabs(args.config) or "/" in args.config or args.config.endswith(".yaml"):
        config_path = args.config
    else:
        # short form that defaults to the envs directory
        config_path = os.path.join(
            config_handler.config_dir, f"envs/{args.config}.yaml"
        )

    logger.info(f"Loading configuration from: {config_path}")

    try:
        with open(config_path, "r") as f:
            import yaml
            raw_config = yaml.safe_load(f)
        logger.info(f"Loaded configuration successfully")
    except Exception as e:
        logger.error(f"Error loading config directly: {e}")
        logger.info("Falling back to default config handler")
        raw_config = config_handler.load_config(args)

    # Configure the InfiniteMath environment with values from config
    config = InfiniteMathEnvConfig(
        # Base environment parameters
        tokenizer_name=raw_config.get("tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"),
        group_size=raw_config.get("group_size", 1),
        use_wandb=raw_config.get("use_wandb", False),
        max_num_workers=raw_config.get("max_num_workers", 1),
        rollout_server_url=raw_config.get("rollout_server_url", "http://localhost:8000"),
        total_steps=raw_config.get("total_steps", 1),
        batch_size=raw_config.get("batch_size", 1),
        steps_per_eval=raw_config.get("steps_per_eval", 2),
        max_token_length=raw_config.get("max_token_length", 4096),
        wandb_name=raw_config.get("wandb_name", "infinite_math_test"),
        ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False),

        # InfiniteMath specific parameters (nested under the "infinimath" key)
        starting_level=raw_config.get("infinimath", {}).get("starting_level", 1),
        progress_threshold=raw_config.get("infinimath", {}).get("progress_threshold", 0.7),
        min_evaluations=raw_config.get("infinimath", {}).get("min_evaluations", 3),
        correct_reward=raw_config.get("infinimath", {}).get("correct_reward", 1.0),
        incorrect_reward=raw_config.get("infinimath", {}).get("incorrect_reward", -0.5),
        apply_length_penalty=raw_config.get("infinimath", {}).get("apply_length_penalty", True),
        length_threshold_ratio=raw_config.get("infinimath", {}).get("length_threshold_ratio", 0.6),
        temperature=raw_config.get("infinimath", {}).get("temperature", 0.7),
        top_p=raw_config.get("infinimath", {}).get("top_p", 0.9),
        reward_functions=raw_config.get("infinimath", {}).get("reward_functions", ["accuracy", "format", "boxed"]),
        accuracy_reward_weight=raw_config.get("infinimath", {}).get("accuracy_reward_weight", 1.0),
        format_reward_weight=raw_config.get("infinimath", {}).get("format_reward_weight", 0.2),
        boxed_reward_weight=raw_config.get("infinimath", {}).get("boxed_reward_weight", 0.3),
    )

    # Server configuration from config file or defaults
    server_configs = []

    if "server_configs" in raw_config:
        for server_config in raw_config["server_configs"]:
            api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
            # Handle environment variable references like ${OPENAI_API_KEY}
            if isinstance(api_key, str) and api_key.startswith("${") and api_key.endswith("}"):
                env_var = api_key[2:-1]
                api_key = os.environ.get(env_var, "")

            server_configs.append(
                OpenaiConfig(
                    model_name=server_config.get("model_name", "gpt-4.1-nano"),
                    base_url=server_config.get("base_url", None),
                    api_key=api_key,
                    num_requests_for_eval=server_config.get("num_requests_for_eval", 70),
                )
            )
    else:
        # Default configuration if not specified in config file
        server_configs.append(
            OpenaiConfig(
                model_name="gpt-4.1-nano",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=70,
            )
        )

    # Create the environment
    env = InfiniteMathEnv(
        config=config,
        server_configs=server_configs,
        slurm=False,
    )

    # Setup the environment
    await env.setup()
    logger.info("Environment setup complete")

    # Log the number of evaluation problems
    total_problems = sum(len(probs) for probs in env.eval_problems.values())
    logger.info(
        f"Using {total_problems} evaluation problems across {len(env.eval_problems)} difficulty levels"
    )

    # Get a math problem
    item = await env.get_next_item()
    problem_prompt, solution, generator_id = item

    logger.info(f"Problem: {dict(problem_prompt[0])['content']}")
    logger.info(f"Solution: {solution}")

    # Collect trajectories
    logger.info("Collecting trajectories...")
    trajectories_data, backlog = await env.collect_trajectories(item)

    # Score the collected trajectories
    logger.info("Scoring trajectories...")
    scored_data = await env.score(trajectories_data)

    # Interactive pause so the operator can inspect output before metrics.
    input("Press Enter to continue...")
    # Print scores
    logger.info(f"Scores: {scored_data['scores']}")

    # Log the correct/incorrect counts
    correct_count = sum(1 for score in scored_data["scores"] if score > 0)
    logger.info(f"Correct answers: {correct_count}/{len(scored_data['scores'])}")

    # Test evaluation function specifically
    logger.info("\n=== Testing Evaluation Function ===")

    # Record the current level
    initial_level = env.curriculum.get_current_level()
    logger.info(f"Current level before evaluation: {initial_level}")
    logger.info(f"Level description: {env.curriculum.get_level_description()}")
    logger.info(f"Progress threshold: {env.curriculum.progress_threshold}")
    logger.info(f"Min evaluations needed: {env.curriculum.min_evaluations}")

    # Run the evaluate method
    eval_metrics = await env.evaluate()

    # Display evaluation results
    logger.info("Evaluation metrics:")
    for metric_name, metric_value in eval_metrics:
        logger.info(f" - {metric_name}: {metric_value}")

    # Check if the level advanced
    new_level = env.curriculum.get_current_level()
    if new_level > initial_level:
        logger.info(f"Successfully advanced to level {new_level}!")
        logger.info(f"New level description: {env.curriculum.get_level_description()}")
    else:
        # Show current progress toward advancement
        current_level = env.curriculum.get_current_level()
        if current_level in env.curriculum.performance_history:
            history = env.curriculum.performance_history[current_level]
            if len(history) >= env.curriculum.min_evaluations:
                recent_history = history[-env.curriculum.min_evaluations :]
                success_rate = sum(recent_history) / len(recent_history)
                logger.info(
                    f"Current success rate: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)"
                )
            else:
                logger.info(
                    f"Need more evaluations: {len(history)}/{env.curriculum.min_evaluations}"
                )

    # Show all levels and their performance history
    logger.info("\nPerformance history by level:")
    for level in sorted(env.curriculum.performance_history.keys()):
        history = env.curriculum.performance_history[level]
        if history:
            success_rate = sum(history) / len(history)
            logger.info(
                f" Level {level}: {success_rate:.2f} ({sum(history)}/{len(history)} correct)"
            )
        else:
            logger.info(f" Level {level}: No data")

    # Test curriculum advancement with simulated performance history
    logger.info("\n=== Testing Curriculum Advancement ===")

    # Simulate good performance at current level
    for _ in range(env.config.min_evaluations):
        # Get a problem from current level
        item = await env.get_next_item()
        _, _, generator_id = item

        # Record positive performance
        env.curriculum.record_performance(generator_id, True)

    # Try to advance difficulty
    did_advance = env.curriculum.advance_difficulty()
    new_level = env.curriculum.get_current_level()

    logger.info(f"Curriculum advancement test:")
    logger.info(f" - Starting level: {initial_level}")
    logger.info(f" - Recorded {env.config.min_evaluations} correct answers")
    logger.info(f" - Did advance: {did_advance}")
    logger.info(f" - New level: {new_level}")

    logger.info("Test server completed successfully")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the async smoke test to completion.
    asyncio.run(main())
|
||||
242
environments/infinimath/infinimath_server.py
Normal file
242
environments/infinimath/infinimath_server.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from environments.infinimath.infinimath_env import (
|
||||
InfiniteMathEnv,
|
||||
InfiniteMathEnvConfig,
|
||||
)
|
||||
from atroposlib.envs.base import OpenaiConfig
|
||||
from atroposlib.utils.config_handler import ConfigHandler
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
    """Parse command-line options for the InfiniteMath environment server.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default,
            preserving the previous behavior), arguments are read from
            ``sys.argv``. Accepting an explicit list makes the parser
            callable from tests without touching process arguments.

    Returns:
        argparse.Namespace with a single ``config`` attribute.
    """
    parser = argparse.ArgumentParser(description="InfiniteMath environment server")
    parser.add_argument(
        "--config",
        type=str,
        default="infinimath",
        help="Configuration file name (without .yaml extension or path for configs/envs/ directory, or full path)",
    )
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def _build_env_config(raw_config):
    """Translate the raw YAML dict into an InfiniteMathEnvConfig.

    Every value falls back to the same default the server used previously,
    so a partial (or missing) YAML file still produces a working config.
    """
    # Hoist the nested section lookup so it is done once, not per key.
    im = raw_config.get("infinimath", {})
    return InfiniteMathEnvConfig(
        # Base environment parameters
        tokenizer_name=raw_config.get("tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"),
        group_size=raw_config.get("group_size", 1),
        use_wandb=raw_config.get("use_wandb", False),
        max_num_workers=raw_config.get("max_num_workers", 1),
        rollout_server_url=raw_config.get("rollout_server_url", "http://localhost:8000"),
        total_steps=raw_config.get("total_steps", 1),
        batch_size=raw_config.get("batch_size", 1),
        steps_per_eval=raw_config.get("steps_per_eval", 2),
        max_token_length=raw_config.get("max_token_length", 4096),
        wandb_name=raw_config.get("wandb_name", "infinite_math_test"),
        ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False),
        # InfiniteMath specific parameters
        starting_level=im.get("starting_level", 1),
        progress_threshold=im.get("progress_threshold", 0.7),
        min_evaluations=im.get("min_evaluations", 3),
        correct_reward=im.get("correct_reward", 1.0),
        incorrect_reward=im.get("incorrect_reward", -0.5),
        apply_length_penalty=im.get("apply_length_penalty", True),
        length_threshold_ratio=im.get("length_threshold_ratio", 0.6),
        temperature=im.get("temperature", 0.7),
        top_p=im.get("top_p", 0.9),
        reward_functions=im.get("reward_functions", ["accuracy", "format", "boxed"]),
        accuracy_reward_weight=im.get("accuracy_reward_weight", 1.0),
        format_reward_weight=im.get("format_reward_weight", 0.2),
        boxed_reward_weight=im.get("boxed_reward_weight", 0.3),
    )


def _build_server_configs(raw_config):
    """Build OpenaiConfig entries from the config file.

    Falls back to a single default server when the config file specifies
    no servers at all — including the previously-unhandled case of a
    present-but-empty ``server_configs`` list, which used to yield an
    empty list.
    """
    server_configs = []
    for server_config in raw_config.get("server_configs", []):
        api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
        # Expand ${ENV_VAR}-style references so secrets stay out of the YAML.
        if isinstance(api_key, str) and api_key.startswith("${") and api_key.endswith("}"):
            env_var = api_key[2:-1]
            api_key = os.environ.get(env_var, "")

        server_configs.append(
            OpenaiConfig(
                model_name=server_config.get("model_name", "gpt-4.1-nano"),
                base_url=server_config.get("base_url", None),
                api_key=api_key,
                num_requests_for_eval=server_config.get("num_requests_for_eval", 70),
            )
        )

    if not server_configs:
        # Default configuration if not specified in config file
        server_configs.append(
            OpenaiConfig(
                model_name="gpt-4.1-nano",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=70,
            )
        )
    return server_configs


async def main():
    """Smoke-test driver for the InfiniteMath environment.

    Loads configuration, builds the environment, then exercises problem
    generation, trajectory collection and scoring, the evaluation pass,
    and curriculum advancement, logging results along the way.
    """
    logger.info("Starting InfiniteMath environment server")

    # Parse command line arguments
    args = parse_arguments()

    # Initialize config handler and load configuration
    config_handler = ConfigHandler()

    # Determine config path: absolute paths, paths containing a slash, and
    # explicit .yaml files are used verbatim; bare names resolve to configs/envs/.
    if os.path.isabs(args.config) or "/" in args.config or args.config.endswith(".yaml"):
        config_path = args.config
    else:
        config_path = os.path.join(
            config_handler.config_dir, f"envs/{args.config}.yaml"
        )

    logger.info(f"Loading configuration from: {config_path}")

    try:
        import yaml  # local import: only needed on the direct-load path

        with open(config_path, "r") as f:
            raw_config = yaml.safe_load(f)
        logger.info("Loaded configuration successfully")
    except Exception as e:
        logger.error(f"Error loading config directly: {e}")
        logger.info("Falling back to default config handler")
        raw_config = config_handler.load_config(args)

    # Configure the InfiniteMath environment with values from config
    config = _build_env_config(raw_config)

    # Server configuration from config file or defaults
    server_configs = _build_server_configs(raw_config)

    # Create the environment
    env = InfiniteMathEnv(
        config=config,
        server_configs=server_configs,
        slurm=False,
    )

    # Setup the environment
    await env.setup()
    logger.info("Environment setup complete")

    # Log the number of evaluation problems
    total_problems = sum(len(probs) for probs in env.eval_problems.values())
    logger.info(
        f"Using {total_problems} evaluation problems across {len(env.eval_problems)} difficulty levels"
    )

    # Get a math problem
    item = await env.get_next_item()
    problem_prompt, solution, generator_id = item

    logger.info(f"Problem: {dict(problem_prompt[0])['content']}")
    logger.info(f"Solution: {solution}")

    # Collect trajectories
    logger.info("Collecting trajectories...")
    trajectories_data, backlog = await env.collect_trajectories(item)

    # Score the collected trajectories
    logger.info("Scoring trajectories...")
    scored_data = await env.score(trajectories_data)

    # NOTE(review): blocking input() stalls the event loop; kept so a human
    # can inspect the scores — remove for unattended runs.
    input("Press Enter to continue...")
    # Print scores
    logger.info(f"Scores: {scored_data['scores']}")

    # Log the correct/incorrect counts
    correct_count = sum(1 for score in scored_data["scores"] if score > 0)
    logger.info(f"Correct answers: {correct_count}/{len(scored_data['scores'])}")

    # Test evaluation function specifically
    logger.info("\n=== Testing Evaluation Function ===")

    # Record the current level
    initial_level = env.curriculum.get_current_level()
    logger.info(f"Current level before evaluation: {initial_level}")
    logger.info(f"Level description: {env.curriculum.get_level_description()}")
    logger.info(f"Progress threshold: {env.curriculum.progress_threshold}")
    logger.info(f"Min evaluations needed: {env.curriculum.min_evaluations}")

    # Run the evaluate method
    eval_metrics = await env.evaluate()

    # Display evaluation results
    logger.info("Evaluation metrics:")
    for metric_name, metric_value in eval_metrics:
        logger.info(f"  - {metric_name}: {metric_value}")

    # Check if the level advanced
    new_level = env.curriculum.get_current_level()
    if new_level > initial_level:
        logger.info(f"Successfully advanced to level {new_level}!")
        logger.info(f"New level description: {env.curriculum.get_level_description()}")
    else:
        # Show current progress toward advancement
        current_level = env.curriculum.get_current_level()
        if current_level in env.curriculum.performance_history:
            history = env.curriculum.performance_history[current_level]
            if len(history) >= env.curriculum.min_evaluations:
                recent_history = history[-env.curriculum.min_evaluations:]
                success_rate = sum(recent_history) / len(recent_history)
                logger.info(
                    f"Current success rate: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)"
                )
            else:
                logger.info(
                    f"Need more evaluations: {len(history)}/{env.curriculum.min_evaluations}"
                )

    # Show all levels and their performance history
    logger.info("\nPerformance history by level:")
    for level in sorted(env.curriculum.performance_history.keys()):
        history = env.curriculum.performance_history[level]
        if history:
            success_rate = sum(history) / len(history)
            logger.info(
                f"  Level {level}: {success_rate:.2f} ({sum(history)}/{len(history)} correct)"
            )
        else:
            logger.info(f"  Level {level}: No data")

    # Test curriculum advancement with simulated performance history
    logger.info("\n=== Testing Curriculum Advancement ===")

    # Simulate good performance at current level
    for _ in range(env.config.min_evaluations):
        # Get a problem from current level
        item = await env.get_next_item()
        _, _, generator_id = item

        # Record positive performance
        env.curriculum.record_performance(generator_id, True)

    # Try to advance difficulty
    did_advance = env.curriculum.advance_difficulty()
    new_level = env.curriculum.get_current_level()

    logger.info("Curriculum advancement test:")
    logger.info(f"  - Starting level: {initial_level}")
    logger.info(f"  - Recorded {env.config.min_evaluations} correct answers")
    logger.info(f"  - Did advance: {did_advance}")
    logger.info(f"  - New level: {new_level}")

    logger.info("Test server completed successfully")
|
||||
|
||||
|
||||
# Script entry point: run the async smoke test under a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
265
environments/infinimath/test_curriculum.py
Normal file
265
environments/infinimath/test_curriculum.py
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for the Infinite Math curriculum manager.
|
||||
This script tests the core functionality of both the MathCurriculum and InfiniteMath classes.
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
# Use relative imports
|
||||
from .curriculum import MathCurriculum
|
||||
from .infinimath import InfiniteMath
|
||||
|
||||
|
||||
def test_curriculum_initialization():
    """Test that the curriculum initializes correctly with different levels.

    Raises:
        AssertionError: if any initialization check fails, including when an
            out-of-range starting level is silently accepted.
    """
    print("\n=== Testing Curriculum Initialization ===")

    # Test default initialization
    curriculum = MathCurriculum()
    assert curriculum.get_current_level() == 1, "Default level should be 1"
    print("✓ Default initialization successful")

    # Test initialization with specific level
    curriculum = MathCurriculum(starting_level=3)
    assert curriculum.get_current_level() == 3, "Starting level should be 3"
    print("✓ Custom level initialization successful")

    # Test initialization with invalid level: must raise ValueError.
    try:
        MathCurriculum(starting_level=10)
    except ValueError:
        print("✓ Invalid level initialization correctly raises ValueError")
    else:
        # The original version only printed a warning here, so a missing
        # ValueError would let the test suite pass; now it actually fails.
        raise AssertionError("Invalid level initialization should raise ValueError")
|
||||
|
||||
|
||||
def test_problem_generation():
    """Test that problems are generated correctly at different difficulty levels."""
    print("\n=== Testing Problem Generation ===")

    # Exercise each of the seven difficulty levels in turn.
    for difficulty in range(1, 8):
        curriculum = MathCurriculum(starting_level=difficulty)
        problem, solution, generator_id = curriculum.get_problem()

        # A usable problem is a non-empty string paired with a concrete solution.
        assert (
            isinstance(problem, str) and len(problem) > 0
        ), f"Problem at level {difficulty} should be a non-empty string"
        assert solution is not None, f"Solution at level {difficulty} should not be None"

        # The generator that produced it must be registered for this level.
        assert (
            generator_id in curriculum.DIFFICULTY_LEVELS[difficulty]
        ), f"Generator ID {generator_id} should be in level {difficulty}"

        preview = problem[:50] + ("..." if len(problem) > 50 else "")
        print(f"✓ Level {difficulty} problem generated: {preview}")
        print(f"   Solution: {solution}")
        print(f"   Generator ID: {generator_id}")
|
||||
|
||||
|
||||
def test_performance_tracking():
    """Test performance tracking and level advancement."""
    print("\n=== Testing Performance Tracking and Advancement ===")

    def record_all(curr, gen_id, outcomes):
        # Record a sequence of correct/incorrect results for one generator.
        for outcome in outcomes:
            curr.record_performance(gen_id, outcome)

    # Two correct answers out of a required three: not enough data yet.
    curriculum = MathCurriculum(
        starting_level=1, progress_threshold=0.7, min_evaluations=3
    )
    gen_id = curriculum.DIFFICULTY_LEVELS[1][0]  # a generator from level 1
    record_all(curriculum, gen_id, [True, True])

    assert not curriculum.advance_difficulty(), "Should not advance with only 2 evaluations"
    assert curriculum.get_current_level() == 1, "Level should still be 1"
    print("✓ Correctly did not advance with insufficient evaluations")

    # A third correct answer gives a 100% rate over three evaluations: advance.
    record_all(curriculum, gen_id, [True])
    assert curriculum.advance_difficulty(), "Should advance with 3 correct evaluations (100% success rate)"
    assert curriculum.get_current_level() == 2, "Level should be 2 after advancement"
    print("✓ Correctly advanced to level 2 after sufficient success")

    # A 33% success rate sits below the 70% threshold: no advancement.
    curriculum = MathCurriculum(
        starting_level=1, progress_threshold=0.7, min_evaluations=3
    )
    gen_id = curriculum.DIFFICULTY_LEVELS[1][0]
    record_all(curriculum, gen_id, [True, False, False])

    assert (
        not curriculum.advance_difficulty()
    ), "Should not advance with 33% success rate when threshold is 70%"
    print("✓ Correctly did not advance with insufficient success rate")

    # Perfect performance at the top level still cannot advance further.
    curriculum = MathCurriculum(
        starting_level=7, progress_threshold=0.7, min_evaluations=3
    )
    gen_id = curriculum.DIFFICULTY_LEVELS[7][0]
    record_all(curriculum, gen_id, [True, True, True])

    assert not curriculum.advance_difficulty(), "Should not advance beyond the highest level"
    print("✓ Correctly did not advance beyond the highest level")
|
||||
|
||||
|
||||
def test_level_descriptions():
    """Test that level descriptions are correct."""
    print("\n=== Testing Level Descriptions ===")

    curriculum = MathCurriculum()

    # Every one of the seven levels must expose a human-readable description.
    for lvl in range(1, 8):
        text = curriculum.get_level_description(lvl)
        assert (
            isinstance(text, str) and len(text) > 0
        ), f"Description for level {lvl} should be a non-empty string"
        print(f"✓ Level {lvl}: {text}")
|
||||
|
||||
|
||||
def test_infinite_math_environment():
    """Test the InfiniteMath environment functionality.

    Covers state retrieval, incorrect and correct answer handling,
    resetting to a specific level, and difficulty statistics.
    """
    print("\n=== Testing InfiniteMath Environment ===")

    # Initialize the environment
    env = InfiniteMath(starting_level=1, progress_threshold=0.7, min_evaluations=3)

    # Get the initial state
    state = env.get_state()
    assert "problem" in state, "State should include a problem"
    assert "current_level" in state, "State should include current level"
    print(
        f"✓ Initial state: Level {state['current_level']}, Problem: {state['problem'][:50]}{'...' if len(state['problem']) > 50 else ''}"
    )

    # Test answering a problem incorrectly
    result = env.submit_answer("wrong answer")
    assert "is_correct" in result, "Result should indicate correctness"
    assert not result["is_correct"], "Result should be incorrect"
    print("✓ Incorrect answer handled correctly")

    # Test answering a problem correctly
    # Note: We can't predict the correct answer, so we'll get it from the environment
    correct_solution = env.current_solution
    result = env.submit_answer(correct_solution)
    assert result["is_correct"], "Result should be correct"
    print("✓ Correct answer handled successfully")

    # Test resetting the environment
    env.reset(level=3)
    state = env.get_state()
    assert state["current_level"] == 3, "After reset, level should be 3"
    # Was an f-string with no placeholders (F541); plain literal is equivalent.
    print("✓ Reset to level 3 successful")

    # Test getting difficulty stats
    stats = env.get_difficulty_stats()
    assert len(stats) == 7, "Should have stats for all 7 difficulty levels"
    print("✓ Difficulty statistics retrieved successfully")
|
||||
|
||||
|
||||
def simulate_learning():
    """Simulate a learning agent improving performance over time.

    Runs a fixed number of episodes against the InfiniteMath environment.
    The simulated agent's chance of a correct answer rises with episode
    count and falls with difficulty level, so the curriculum is expected
    to advance during the run.

    NOTE(review): draws from the unseeded global `random` stream, so the
    exact trajectory differs between runs — confirm this is intentional.
    """

    env = InfiniteMath(starting_level=1, progress_threshold=0.7, min_evaluations=5)
    episodes = 30

    print(f"Starting simulation with {episodes} episodes")
    print(f"Initial level: {env.get_state()['current_level']}")

    for i in range(episodes):
        state = env.get_state()
        current_level = state["current_level"]

        # Simulated agent gradually improves - higher chance of correct answer over time
        # and higher chance at lower levels; capped at 0.95 so some failures remain.
        success_probability = min(0.5 + (i / episodes) + (1 / current_level), 0.95)

        # Simulate an answer
        is_correct = random.random() < success_probability

        # If we decide to be correct, use the actual solution
        if is_correct:
            answer = env.current_solution
        else:
            # Otherwise provide a wrong answer
            answer = "wrong answer"

        # Submit the answer
        result = env.submit_answer(answer)

        # Check for level advancement
        if result.get("did_advance_level", False):
            new_level = result["new_level"]
            print(
                f"Episode {i+1}: Advanced to level {new_level}! (success probability: {success_probability:.2f})"
            )
        elif i % 5 == 0:  # Print status occasionally
            print(
                f"Episode {i+1}: Still at level {current_level} (success probability: {success_probability:.2f})"
            )

    # Print final stats
    final_state = env.get_state()
    print(f"\nFinal level: {final_state['current_level']}")
    print(
        f"Overall accuracy: {final_state['correct_problems'] / final_state['total_problems']:.2%}"
    )

    # Print level-by-level stats
    stats = env.get_difficulty_stats()
    print("\nPerformance by level:")
    for level, level_stats in stats.items():
        if level_stats["problems_attempted"] > 0:
            success_rate = level_stats["success_rate"]
            if success_rate is not None:
                print(
                    f"Level {level}: {success_rate:.2%} success rate ({level_stats['problems_attempted']} problems)"
                )
            else:
                print(
                    f"Level {level}: Not enough data ({level_stats['problems_attempted']} problems)"
                )
|
||||
|
||||
|
||||
def main():
    """Run every test in sequence; return 0 on success, 1 on failure."""
    print("=== Starting Curriculum Manager Tests ===")

    # The suite runs in declaration order; the first failure aborts the rest.
    suite = (
        test_curriculum_initialization,
        test_problem_generation,
        test_performance_tracking,
        test_level_descriptions,
        test_infinite_math_environment,
        simulate_learning,
    )

    try:
        for check in suite:
            check()
        print("\n=== All tests completed successfully! ===")
    except AssertionError as e:
        print(f"\n❌ Test failed: {e}")
        return 1
    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        return 1

    return 0
|
||||
|
||||
|
||||
# Script entry point: the process exit status mirrors the test result
# (0 = all tests passed, 1 = a test failed or errored).
if __name__ == "__main__":
    sys.exit(main())
|
||||
Loading…
Add table
Add a link
Reference in a new issue