reasoning-gym/reasoning_gym/core/base_curriculum.py
2025-02-07 10:35:44 +00:00

141 lines
No EOL
4.7 KiB
Python

"""
Base class for exercise curricula that defines the interface and common functionality.
"""
from typing import Dict, List, Any, Callable
from reasoning_gym.core.attributes import AttributeDefinition, AttributeType
from reasoning_gym.core.template import Template
import random
class BaseCurriculum:
    """Shared scaffolding for every exercise curriculum.

    The constructor creates empty containers (attributes, templates, accepted
    attribute types and per-attribute level overrides), hands control to the
    subclass hook ``_init_curriculum()`` and then runs ``_validate()``.
    """

    def __init__(self, name: str):
        """
        Create the curriculum, then populate and validate it.
        Args:
            name: Curriculum name; passed to the attribute helpers and used in error messages
        """
        # Every container starts empty; the subclass hook is expected to fill them.
        self._current_levels: Dict[str, int] = {}
        self._valid_types: set[AttributeType] = set()
        self._templates: List[Template] = []
        self._attributes: Dict[str, AttributeDefinition] = {}
        self.name = name

        # Order matters: the subclass populates first, the checks run second.
        self._init_curriculum()
        self._validate()
        # TODO(review): the note originally left here only asked "Why?" —
        # presumably about running both hooks from the constructor; confirm and document.

    def _init_curriculum(self) -> None:
        """
        Hook in which a subclass fills in its attributes and templates.
        The base implementation never succeeds; subclasses must override it.
        Raises:
            NotImplementedError: Always, when the hook has not been overridden
        """
        raise NotImplementedError("Subclasses must implement _init_curriculum()")

    def _validate(self) -> None:
        """
        Check the structure produced by ``_init_curriculum()``.
        Attribute checks are delegated to ``AttributeDefinition.validate_attributes``;
        afterwards at least one template is required.
        Raises:
            ValueError: If no templates were defined
        """
        # Attribute checks are implemented by AttributeDefinition, not here.
        AttributeDefinition.validate_attributes(self._attributes, self._valid_types, curriculum=self.name)

        # A curriculum with no templates is rejected.
        if not self._templates:
            raise ValueError(f"Curriculum {self.name} has no templates defined")

    @property
    def attributes(self) -> Dict[str, AttributeDefinition]:
        """Attribute definitions of this curriculum (the internal dict itself, not a copy)"""
        return self._attributes

    @property
    def templates(self) -> List[Template]:
        """Templates of this curriculum (the internal list itself, not a copy)"""
        return self._templates

    def get_template(self, rng: random.Random) -> Template:
        """
        Pick one template at random.
        Args:
            rng: Random number generator used for the choice
        Returns:
            One of the curriculum's templates
        Raises:
            ValueError: If the curriculum has no templates
        """
        pool = self._templates
        if pool:
            return rng.choice(pool)
        raise ValueError("No templates defined")

    def get_attr_level(self, attr_name: str) -> int:
        """
        Look up the level currently in effect for an attribute.
        Args:
            attr_name: Name of the attribute
        Returns:
            The level stored for the attribute, or the definition's default
            level when none has been stored
        """
        definition = AttributeDefinition.check_attribute_exists(self._attributes, attr_name, self.name)
        fallback = definition.default_level
        return self._current_levels.get(attr_name, fallback)

    def get_attr_value(self, attr_name: str) -> Any:
        """
        Resolve an attribute to the value that belongs to its current level.
        Args:
            attr_name: Name of the attribute
        Returns:
            Whatever ``AttributeDefinition.get_level_value`` yields for the
            attribute at its current level
        """
        definition = AttributeDefinition.check_attribute_exists(self._attributes, attr_name, self.name)
        active_level = self.get_attr_level(attr_name)
        return AttributeDefinition.get_level_value(definition, active_level, attr_name, self.name)

    def set_attr_level(self, attr_name: str, level: int) -> None:
        """
        Store a new level for an attribute.
        Args:
            attr_name: Name of the attribute
            level: New level index
        """
        definition = AttributeDefinition.check_attribute_exists(self._attributes, attr_name, self.name)
        # The level is only stored once validate_level has returned without raising.
        AttributeDefinition.validate_level(definition, level, attr_name, self.name)
        self._current_levels[attr_name] = level

    def increment_attr_level(self, attr_name: str) -> bool:
        """
        Move an attribute one level up, unless it is already at the top.
        Args:
            attr_name: Name of the attribute to increment
        Returns:
            bool: True if the level was raised, False if it was already the last level
        Raises:
            Whatever ``AttributeDefinition.check_attribute_exists`` raises for an
            unknown attribute (earlier documentation named KeyError; not verifiable here)
        """
        definition = AttributeDefinition.check_attribute_exists(self._attributes, attr_name, self.name)
        active_level = self.get_attr_level(attr_name)
        highest_level = len(definition.levels) - 1
        if active_level >= highest_level:
            return False
        self.set_attr_level(attr_name, active_level + 1)
        return True

    def decrement_attr_level(self, attr_name: str) -> bool:
        """
        Move an attribute one level down, unless it is already at level 0.
        Args:
            attr_name: Name of the attribute to decrement
        Returns:
            bool: True if the level was lowered, False if it was already at the minimum
        Raises:
            Whatever ``AttributeDefinition.check_attribute_exists`` raises for an
            unknown attribute (earlier documentation named KeyError; not verifiable here)
        """
        # Called only for its existence check; the definition itself is not needed.
        AttributeDefinition.check_attribute_exists(self._attributes, attr_name, self.name)
        active_level = self.get_attr_level(attr_name)
        if active_level <= 0:
            return False
        self.set_attr_level(attr_name, active_level - 1)
        return True