reasoning-gym/reasoning_gym/coaching/base_curriculum.py
Oliver Stanley 1a727ecf4e
support python 3.10 (#450)
* support python 3.10

* add 3.10 to tests

* new StrEnum
2025-06-04 10:34:01 +01:00

254 lines
9.4 KiB
Python

import abc
from collections.abc import Iterable
from typing import Any, Optional, TypeVar, Union

from ..utils import StrEnum
from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition
# Type of the config object that BaseCurriculum.generate_configuration returns.
ConfigT = TypeVar("ConfigT")
class CurriculumContext(abc.ABC):
    """Abstract strategy that resolves an attribute's value for a curriculum."""

    @abc.abstractmethod
    def get_attr_value(self, curriculum, attr: AttributeDefinition) -> Any:
        """Return the value of ``attr`` given the state of ``curriculum``.

        Concrete contexts must override this method.
        """
class RangeAttributeMode(StrEnum):
    """How a curriculum level is expanded into a ``[lo, hi]`` range of level values."""

    UPPER_BOUND = "upper_bound"  # only use the highest range segment
    INCLUSIVE = "inclusive"  # include all previous levels
class DefaultCurriculumContext(CurriculumContext):
    """Default context that maps curriculum levels to attribute values.

    Range attributes are expanded into a ``[lo, hi]`` pair of level values whose
    span depends on ``mode``. Scalar attributes are looked up directly on the
    curriculum.
    """

    def __init__(self, mode: RangeAttributeMode = RangeAttributeMode.INCLUSIVE):
        """
        Args:
            mode: Strategy for choosing the lower end of a range attribute's interval.
        """
        self.mode = mode

    def get_range_attr_value(self, curriculum, attr: RangeAttributeDefinition) -> Any:
        """Return the value for a range attribute at the curriculum's current level.

        If the value stored for the current level is itself iterable, it is
        returned unchanged. Otherwise a ``[lo, hi]`` list of level values is
        returned:

        * ``hi`` is the value at the current level, or at the next level when
          ``attr.ensure_interval`` is truthy. The index is clamped to the last
          defined level.
        * ``lo`` is the value at level 0 in INCLUSIVE mode. In UPPER_BOUND mode
          it is the value one level below ``hi`` with ``ensure_interval``, or
          the same level as ``hi`` without it.

        Raises:
            ValueError: If ``self.mode`` is not a recognized RangeAttributeMode.
        """
        level = curriculum.get_attr_level(attr.name)
        v = attr.get_level_value(level)
        if isinstance(v, Iterable):
            return v

        # ensure_interval widens the window by one level so lo and hi can differ.
        offset = 1 if attr.ensure_interval else 0
        hi_index = min(level + offset, len(attr.levels) - 1)
        if self.mode == RangeAttributeMode.UPPER_BOUND:
            lo_index = max(0, hi_index - offset)
        elif self.mode == RangeAttributeMode.INCLUSIVE:
            lo_index = 0
        else:
            # Fail loudly instead of leaving lo_index/hi_index unbound.
            raise ValueError(f"Unsupported range attribute mode: {self.mode!r}")

        lo = attr.get_level_value(lo_index)
        hi = attr.get_level_value(hi_index)
        return [lo, hi]

    def get_attr_value(self, curriculum, attr: AttributeDefinition) -> Any:
        """Resolve ``attr`` for ``curriculum``.

        Range attributes go through ``get_range_attr_value``. Scalar attributes
        use ``curriculum.get_attr_value``. Any other attribute type yields None.
        """
        if isinstance(attr, RangeAttributeDefinition):
            return self.get_range_attr_value(curriculum, attr)
        elif isinstance(attr, ScalarAttributeDefinition):
            return curriculum.get_attr_value(attr.name)
        return None
class BaseCurriculum:
    """Curriculum that tracks a difficulty level for each registered attribute.

    Subclasses register attributes with ``_define_attributes``. An attribute whose
    level was never set reports its ``default_level``. ``generate_configuration``
    turns the current levels into a config object.
    """

    def __init__(self, name: str, config_cls: type[ConfigT]):
        """
        Args:
            name: Curriculum name; used in error messages.
            config_cls: Called with keyword arguments by ``generate_configuration``
                to build the config object.
        """
        self.name = name
        self._config_cls = config_cls
        self._attributes: dict[str, AttributeDefinition] = {}
        # Only explicitly set levels are stored; others fall back to default_level.
        self._current_levels: dict[str, int] = {}

    def generate_configuration(
        self, defaults: Optional[dict[str, Any]] = None, context: Optional[CurriculumContext] = None
    ) -> ConfigT:
        """Build a config object from the current attribute levels.

        Args:
            defaults: Extra keyword arguments for the config. Attribute-derived
                values override matching keys. The dict is copied, not mutated.
            context: Resolves attribute values. Defaults to a
                ``DefaultCurriculumContext`` in INCLUSIVE mode.

        Returns:
            The result of ``config_cls(**config_args)``.
        """
        config_args = defaults.copy() if defaults is not None else {}
        if context is None:
            context = DefaultCurriculumContext(mode=RangeAttributeMode.INCLUSIVE)
        for attr in self._attributes.values():
            if isinstance(attr, RangeAttributeDefinition):
                v = context.get_attr_value(self, attr)
                # Materialize so min() and max() both see every value, even for
                # one-shot iterators. A single value gives lower == upper.
                values = list(v) if isinstance(v, Iterable) else [v]
                config_args[attr.lower_field_name] = min(values)
                config_args[attr.upper_field_name] = max(values)
            elif isinstance(attr, ScalarAttributeDefinition):
                config_args[attr.field_name] = context.get_attr_value(self, attr)
        return self._config_cls(**config_args)

    @property
    def attributes(self) -> dict[str, AttributeDefinition]:
        """Get the curriculum's attributes"""
        return self._attributes

    def get_attribute(self, attr_name: str) -> AttributeDefinition:
        """Return the attribute named ``attr_name``.

        Raises:
            KeyError: If no such attribute has been defined.
        """
        if attr_name not in self._attributes:
            raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist")
        return self._attributes[attr_name]

    def _define_attributes(self, *attrs: AttributeDefinition) -> None:
        """Register attributes on this curriculum.

        Raises:
            RuntimeError: If an attribute with the same name is already defined.
        """
        for attr in attrs:
            if attr.name in self._attributes:
                raise RuntimeError(f"Attribute with name {attr.name} is already defined.")
            self._attributes[attr.name] = attr

    def get_attr_level(self, attr_name: str) -> int:
        """
        Get the current level for an attribute.

        Args:
            attr_name: Name of the attribute

        Returns:
            Current level index for the attribute (its ``default_level`` if never set)

        Raises:
            KeyError: If attribute doesn't exist
        """
        attr = self.get_attribute(attr_name)
        return self._current_levels.get(attr_name, attr.default_level)

    def get_attr_value(self, attr_name: str) -> Any:
        """
        Get the current value for an attribute based on its level.

        Args:
            attr_name: Name of the attribute

        Returns:
            ``attr.get_level_value(level)`` for the attribute's current level

        Raises:
            KeyError: If attribute doesn't exist
        """
        attr = self.get_attribute(attr_name)
        level = self.get_attr_level(attr_name)
        return attr.get_level_value(level)

    def set_attr_level(self, attr_name: str, level: int) -> None:
        """
        Set the level for an attribute.

        Args:
            attr_name: Name of the attribute
            level: New level index; checked by ``attr.validate_level`` before storing

        Raises:
            KeyError: If attribute doesn't exist
        """
        attr = self.get_attribute(attr_name)
        attr.validate_level(level, curriculum=self.name)
        self._current_levels[attr_name] = level

    def increment_attr_level(self, attr_name: str) -> bool:
        """
        Increment the level of an attribute if possible.

        Args:
            attr_name: Name of the attribute to increment

        Returns:
            bool: True if level was incremented, False if already at max level

        Raises:
            KeyError: If attribute doesn't exist
        """
        attr = self.get_attribute(attr_name)
        current_level = self.get_attr_level(attr_name)
        if current_level < len(attr.levels) - 1:
            self.set_attr_level(attr_name, current_level + 1)
            return True
        return False

    def decrement_attr_level(self, attr_name: str) -> bool:
        """
        Decrement the level of an attribute if possible.

        Args:
            attr_name: Name of the attribute to decrement

        Returns:
            bool: True if level was decremented, False if already at min level

        Raises:
            KeyError: If attribute doesn't exist
        """
        current_level = self.get_attr_level(attr_name)
        if current_level > 0:
            self.set_attr_level(attr_name, current_level - 1)
            return True
        return False

    def get_max_level(self) -> int:
        """
        Get the maximum level currently set across all attributes.

        Returns:
            int: The maximum current level, or 0 if there are no attributes
        """
        if not self._attributes:
            return 0
        return max(self.get_attr_level(attr_name) for attr_name in self._attributes)

    def set_global_level(self, level: int) -> None:
        """
        Set all attributes to the specified level.

        If the level exceeds the number of defined levels for an attribute,
        use the highest defined level for that attribute.

        Args:
            level: The level to set for all attributes
        """
        for attr_name, attr in self._attributes.items():
            # Use the highest defined level if the requested level exceeds available levels
            attr_level = min(level, len(attr.levels) - 1)
            self.set_attr_level(attr_name, attr_level)

    def increment_global_level(self) -> bool:
        """
        Raise every non-maxed attribute to one above the current maximum level.

        Attributes already at their highest level are left untouched. Others are
        set to ``get_max_level() + 1``, clamped to their own highest level.

        Returns:
            bool: True if at least one attribute's level was incremented, False otherwise
        """
        target_level = self.get_max_level() + 1
        # Setting one attribute's level never affects another's, so this set
        # can be computed once up front.
        incrementable = [
            (attr_name, attr)
            for attr_name, attr in self._attributes.items()
            if self.get_attr_level(attr_name) < len(attr.levels) - 1
        ]
        if not incrementable:
            return False
        for attr_name, attr in incrementable:
            # Don't exceed the attribute's maximum level
            self.set_attr_level(attr_name, min(target_level, len(attr.levels) - 1))
        return True

    def decrement_global_level(self) -> bool:
        """
        Decrement, by one, every attribute that sits at the current maximum level.

        Returns:
            bool: True if at least one attribute's level was decremented, False otherwise
        """
        current_max = self.get_max_level()
        if current_max > 0:
            target_level = current_max - 1
            for attr_name in self._attributes:
                # Only decrement if the attribute is at the current maximum level
                if self.get_attr_level(attr_name) == current_max:
                    self.set_attr_level(attr_name, target_level)
            return True
        return False

    def get_global_level(self) -> Union[int, dict[str, Any]]:
        """Return the current value of every attribute, keyed by config field name.

        Range attributes are keyed by their ``upper_field_name`` and scalar
        attributes by their ``field_name``. Each value is ``get_attr_value`` for
        that attribute. Attributes of any other type are omitted.

        Returns:
            The mapping described above, or ``0`` if the curriculum has no attributes.
            NOTE(review): the name and the ``0`` sentinel suggest an integer level
            was once intended; the return shape is kept because callers may rely on it.
        """
        if not self._attributes:
            return 0
        attr_dict: dict[str, Any] = {}
        for attr_name, attr in self._attributes.items():
            if isinstance(attr, RangeAttributeDefinition):
                attr_dict[attr.upper_field_name] = self.get_attr_value(attr_name)
            elif isinstance(attr, ScalarAttributeDefinition):
                attr_dict[attr.field_name] = self.get_attr_value(attr_name)
        return attr_dict