mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
252 lines
8.4 KiB
Python
252 lines
8.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Infinite Math - A reinforcement learning environment for math practice
|
|
using the mathgenerator library with curriculum-based advancement.
|
|
"""
|
|
|
|
from typing import Any, Dict, Optional
|
|
|
|
from .curriculum import MathCurriculum
|
|
|
|
|
|
class InfiniteMath:
|
|
"""
|
|
A reinforcement learning environment for practicing math skills with
|
|
curriculum-based advancement.
|
|
|
|
This class uses the MathCurriculum to generate appropriate math problems
|
|
and track performance, advancing difficulty as the learner improves.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
starting_level: int = 1,
|
|
progress_threshold: float = 0.8,
|
|
min_evaluations: int = 5,
|
|
max_attempts_per_problem: int = 3,
|
|
):
|
|
"""
|
|
Initialize the InfiniteMath environment.
|
|
|
|
Args:
|
|
starting_level: Initial difficulty level (default: 1)
|
|
progress_threshold: Success rate needed to advance levels (default: 0.8)
|
|
min_evaluations: Minimum evaluations before considering advancement (default: 5)
|
|
max_attempts_per_problem: Maximum attempts allowed per problem (default: 3)
|
|
"""
|
|
self.curriculum = MathCurriculum(
|
|
starting_level=starting_level,
|
|
progress_threshold=progress_threshold,
|
|
min_evaluations=min_evaluations,
|
|
)
|
|
|
|
self.max_attempts = max_attempts_per_problem
|
|
self.current_problem = None
|
|
self.current_solution = None
|
|
self.current_generator_id = None
|
|
self.attempts_remaining = 0
|
|
self.total_problems = 0
|
|
self.correct_problems = 0
|
|
|
|
# Generate the first problem
|
|
self._generate_problem()
|
|
|
|
def _generate_problem(self) -> None:
|
|
"""Generate a new problem from the curriculum."""
|
|
self.current_problem, self.current_solution, self.current_generator_id = (
|
|
self.curriculum.get_problem()
|
|
)
|
|
self.attempts_remaining = self.max_attempts
|
|
|
|
def get_state(self) -> Dict[str, Any]:
|
|
"""
|
|
Get the current state of the environment.
|
|
|
|
Returns:
|
|
Dictionary with current state information
|
|
"""
|
|
return {
|
|
"problem": self.current_problem,
|
|
"attempts_remaining": self.attempts_remaining,
|
|
"current_level": self.curriculum.get_current_level(),
|
|
"total_levels": self.curriculum.get_num_levels(),
|
|
"level_description": self.curriculum.get_level_description(),
|
|
"total_problems": self.total_problems,
|
|
"correct_problems": self.correct_problems,
|
|
"accuracy": self.correct_problems / max(1, self.total_problems),
|
|
}
|
|
|
|
def submit_answer(self, answer: str) -> Dict[str, Any]:
|
|
"""
|
|
Submit an answer to the current problem.
|
|
|
|
Args:
|
|
answer: The learner's answer to the current problem
|
|
|
|
Returns:
|
|
Dictionary with the result of the submission and updated state
|
|
"""
|
|
if self.current_problem is None:
|
|
return {"error": "No active problem. Call reset() to start a new session."}
|
|
|
|
# Clean up the answer for comparison (strip whitespace, convert to lowercase)
|
|
cleaned_answer = str(answer).strip().lower()
|
|
cleaned_solution = str(self.current_solution).strip().lower()
|
|
|
|
# Check if the answer is correct
|
|
is_correct = cleaned_answer == cleaned_solution
|
|
|
|
# Update attempts
|
|
self.attempts_remaining -= 1
|
|
|
|
result = {
|
|
"is_correct": is_correct,
|
|
"correct_answer": (
|
|
self.current_solution
|
|
if self.attempts_remaining == 0 or is_correct
|
|
else None
|
|
),
|
|
"attempts_remaining": self.attempts_remaining,
|
|
}
|
|
|
|
# If correct or out of attempts, record performance and generate a new problem
|
|
if is_correct or self.attempts_remaining == 0:
|
|
self.total_problems += 1
|
|
if is_correct:
|
|
self.correct_problems += 1
|
|
|
|
# Record performance in the curriculum
|
|
self.curriculum.record_performance(self.current_generator_id, is_correct)
|
|
|
|
# Check if we should advance to the next level
|
|
did_advance = self.curriculum.advance_difficulty()
|
|
result["did_advance_level"] = did_advance
|
|
|
|
if did_advance:
|
|
result["new_level"] = self.curriculum.get_current_level()
|
|
result["level_description"] = self.curriculum.get_level_description()
|
|
|
|
# Generate a new problem
|
|
self._generate_problem()
|
|
result["new_problem"] = self.current_problem
|
|
|
|
return result
|
|
|
|
def reset(self, level: Optional[int] = None) -> Dict[str, Any]:
|
|
"""
|
|
Reset the environment, optionally to a specific difficulty level.
|
|
|
|
Args:
|
|
level: The difficulty level to reset to (default: current level)
|
|
|
|
Returns:
|
|
Dictionary with the new state
|
|
"""
|
|
if level is not None:
|
|
self.curriculum.reset(level)
|
|
|
|
self.total_problems = 0
|
|
self.correct_problems = 0
|
|
|
|
# Generate a new problem
|
|
self._generate_problem()
|
|
|
|
return self.get_state()
|
|
|
|
def get_difficulty_stats(self) -> Dict[int, Dict[str, Any]]:
|
|
"""
|
|
Get performance statistics for each difficulty level.
|
|
|
|
Returns:
|
|
Dictionary with statistics for each level
|
|
"""
|
|
stats = {}
|
|
|
|
for level in self.curriculum.DIFFICULTY_LEVELS.keys():
|
|
success_rate = self.curriculum.get_success_rate(level)
|
|
history = self.curriculum.performance_history[level]
|
|
|
|
stats[level] = {
|
|
"description": self.curriculum.get_level_description(level),
|
|
"problems_attempted": len(history),
|
|
"success_rate": (
|
|
success_rate if success_rate is not None else float("nan")
|
|
),
|
|
"is_current_level": level == self.curriculum.get_current_level(),
|
|
}
|
|
|
|
return stats
|
|
|
|
|
|
def main():
|
|
"""Example usage of the InfiniteMath environment."""
|
|
# Create the environment
|
|
env = InfiniteMath(starting_level=1)
|
|
|
|
print("Welcome to InfiniteMath!")
|
|
print(
|
|
f"Starting at level {env.get_state()['current_level']}: {env.get_state()['level_description']}"
|
|
)
|
|
|
|
playing = True
|
|
while playing:
|
|
# Display current problem
|
|
state = env.get_state()
|
|
print("\n" + "=" * 50)
|
|
print(f"Level {state['current_level']}/{state['total_levels']}")
|
|
print(f"Problem: {state['problem']}")
|
|
print(f"Attempts remaining: {state['attempts_remaining']}")
|
|
|
|
# Get user input
|
|
answer = input("Your answer (or 'q' to quit, 'r' to reset): ")
|
|
|
|
if answer.lower() == "q":
|
|
playing = False
|
|
continue
|
|
elif answer.lower() == "r":
|
|
level = input("Reset to level (1-7, or press Enter for current level): ")
|
|
if level and level.isdigit() and 1 <= int(level) <= 7:
|
|
env.reset(int(level))
|
|
else:
|
|
env.reset()
|
|
continue
|
|
|
|
# Submit the answer
|
|
result = env.submit_answer(answer)
|
|
|
|
# Display the result
|
|
if result.get("is_correct", False):
|
|
print("Correct!")
|
|
else:
|
|
print("Incorrect.")
|
|
|
|
if result.get("correct_answer") is not None:
|
|
print(f"The correct answer is: {result['correct_answer']}")
|
|
|
|
# Check if we advanced to a new level
|
|
if result.get("did_advance_level", False):
|
|
print(f"\nCongratulations! You've advanced to level {result['new_level']}!")
|
|
print(f"New level: {result['level_description']}")
|
|
|
|
# If we have a new problem, show it
|
|
if result.get("new_problem") is not None:
|
|
print("\nNext problem is ready.")
|
|
|
|
# Show final statistics
|
|
print("\nFinal Statistics:")
|
|
print(f"Total problems attempted: {env.total_problems}")
|
|
print(f"Correct answers: {env.correct_problems}")
|
|
if env.total_problems > 0:
|
|
print(f"Overall accuracy: {env.correct_problems / env.total_problems:.2%}")
|
|
|
|
level_stats = env.get_difficulty_stats()
|
|
print("\nPerformance by level:")
|
|
for level, stats in level_stats.items():
|
|
if stats["problems_attempted"] > 0:
|
|
print(
|
|
f"Level {level}: {stats['success_rate']:.2%} success rate ({stats['problems_attempted']} problems)"
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|