diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 30dcb0c4..01d387a6 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -2,6 +2,7 @@ import random from dataclasses import dataclass from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -112,5 +113,36 @@ class ChainSum(ProceduralDataset): return expression, result +class ChainSumCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(ChainSumCurriculum.__name__, ChainSumConfig) + + # Define attributes + self._define_attributes( + ( + RangeAttributeDefinition( + name="num_terms", + levels=[2, 3, 4, 5], + default_level=0, # Start with 2 terms + description="Maximum number of terms in the expression", + attr_type=AttributeType.APPEND, + min_value=2, # Ensure at least 2 terms + lower_field_name="min_terms", + upper_field_name="max_terms", + ), + RangeAttributeDefinition( + name="num_digits", + levels=[1, 2, 4, 10], + default_level=0, # Start with 1-digit numbers + description="Number of digits in each operand", + attr_type=AttributeType.APPEND, + min_value=1, # Ensure numbers are at least 1 digit + lower_field_name="min_digits", + upper_field_name="max_digits", + ), + ) + ) + + # Register the dataset register_dataset("chain_sum", ChainSum, ChainSumConfig) diff --git a/reasoning_gym/coaching/__init__.py b/reasoning_gym/coaching/__init__.py new file mode 100644 index 00000000..50683d66 --- /dev/null +++ b/reasoning_gym/coaching/__init__.py @@ -0,0 +1,14 @@ +from .attributes import AttributeDefinition, AttributeType, RangeAttributeDefinition +from .base_curriculum import BaseCurriculum +from .coach import Coach, GroupedScores, ScoreBoard, ScoreStats + +__all__ = [ + "AttributeType", + "AttributeDefinition", + "RangeAttributeDefinition", + "BaseCurriculum", + "Coach", + "ScoreBoard", + "GroupedScores", + "ScoreStats", +] diff --git a/reasoning_gym/coaching/attributes.py b/reasoning_gym/coaching/attributes.py new file mode 100644 index 00000000..33fc053b --- /dev/null +++ b/reasoning_gym/coaching/attributes.py @@ -0,0 +1,73 @@ +from collections import abc +from dataclasses import dataclass +from enum import StrEnum +from typing import Any, Optional + + +class AttributeType(StrEnum): + """Defines how attribute levels should be interpreted""" + + STATIC = "static" # Each level is independent + UBOUND = "ubound" # Each level is an upper bound + APPEND = "append" # Each level includes all previous levels + + +@dataclass(kw_only=True) +class AttributeDefinition: + name: str + levels: list + default_level: int + description: Optional[str] = None + attr_type: AttributeType = AttributeType.STATIC # Default to static + min_value: Optional[int | float] = None # Minimum value for numeric attributes + + def validate_level(self, level: int, curriculum: str) -> None: + """ + Validate that a level is valid for an attribute. + Args: + level: Level to validate + curriculum: Name of the curriculum + Raises: + ValueError: If level is invalid + """ + # TODO: if > set as [-1], if <0 set as [0] + if not 0 <= level < len(self.levels): + raise ValueError( + f"Invalid level: {level} for attribute '{curriculum}.{self.name}'. " + f"Must be between 0 and {len(self.levels)-1}" + ) + + def get_level_value(self, level: int, curriculum: str) -> Any: + """ + Get the value for an attribute at a specific level based on its type. + Args: + attr: The attribute definition + level: Level to get value for + Returns: + Value for the attribute based on its level and type + """ + if self.attr_type == AttributeType.STATIC: + return self.levels[level] + elif self.attr_type == AttributeType.UBOUND: + return self.levels[level] + elif self.attr_type == AttributeType.APPEND: + return self.levels[: level + 1] + + raise ValueError(f"Unknown attribute type: {self.attr_type} for attribute '{curriculum}.{self.name}'") + + +@dataclass(kw_only=True) +class ScalarAttributeDefinition(AttributeDefinition): + field_name: str + + +@dataclass(kw_only=True) +class RangeAttributeDefinition(AttributeDefinition): + lower_field_name: str + upper_field_name: str + + def get_level_value(self, level: int, curriculum: str) -> Any: + v = super().get_level_value(level, curriculum) + if not isinstance(v, abc.Iterable): + return [v] + return v diff --git a/reasoning_gym/coaching/base_curriculum.py b/reasoning_gym/coaching/base_curriculum.py new file mode 100644 index 00000000..8d619869 --- /dev/null +++ b/reasoning_gym/coaching/base_curriculum.py @@ -0,0 +1,108 @@ +from typing import Any, Iterable, Optional + +from ..factory import ConfigT +from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition + + +class BaseCurriculum: + def __init__(self, name: str, config_cls: ConfigT): + self.name = name + self._config_cls = config_cls + self._attributes: dict[str, AttributeDefinition] = {} + self._current_levels: dict[str, int] = {} + + def generate_configuration(self, defaults: Optional[dict[str, any]] = None) -> ConfigT: + config_args = defaults.copy() if defaults is not None else {} + for attr in self._attributes.values(): + if isinstance(attr, RangeAttributeDefinition): + vals = self.get_attr_value(attr.name) + config_args[attr.lower_field_name] = min(vals) + config_args[attr.upper_field_name] = max(vals) + elif isinstance(attr, ScalarAttributeDefinition): + val = self.get_attr_value(attr.name) + config_args[attr.field_name] = val + print(config_args) + return self._config_cls(**config_args) + + @property + def attributes(self) -> dict[str, AttributeDefinition]: + """Get the curriculum's attributes""" + return self._attributes + + def get_attribute(self, attr_name: str) -> AttributeDefinition: + if attr_name not in self._attributes: + raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist") + return self._attributes[attr_name] + + def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None: + for attr in attrs: + if attr.name in self.attributes: + raise RuntimeError(f"Attribute with name {attr.name} is already defined.") + self.attributes[attr.name] = attr + + def get_attr_level(self, attr_name: str) -> int: + """ + Get the current level for an attribute. + Args: + attr_name: Name of the attribute + Returns: + Current level index for the attribute + """ + attr = self.get_attribute(attr_name) + return self._current_levels.get(attr_name, attr.default_level) + + def get_attr_value(self, attr_name: str) -> Any: + """ + Get the current value for an attribute based on its level. + Args: + attr_name: Name of the attribute + Returns: + Current value for the attribute based on its level and type + """ + attr = self.get_attribute(attr_name) + level = self.get_attr_level(attr_name) + return attr.get_level_value(level, curriculum=self.name) + + def set_attr_level(self, attr_name: str, level: int) -> None: + """ + Set the level for an attribute. + Args: + attr_name: Name of the attribute + level: New level index + """ + attr = self.get_attribute(attr_name) + attr.validate_level(level, curriculum=self.name) + self._current_levels[attr_name] = level + + def increment_attr_level(self, attr_name: str) -> bool: + """ + Increment the level of an attribute if possible. + Args: + attr_name: Name of the attribute to increment + Returns: + bool: True if level was incremented, False if already at max level + Raises: + KeyError: If attribute doesn't exist + """ + attr = self.get_attribute(attr_name) + current_level = self.get_attr_level(attr_name) + if current_level < len(attr.levels) - 1: + self.set_attr_level(attr_name, current_level + 1) + return True + return False + + def decrement_attr_level(self, attr_name: str) -> bool: + """ + Decrement the level of an attribute if possible. + Args: + attr_name: Name of the attribute to decrement + Returns: + bool: True if level was decremented, False if already at min level + Raises: + KeyError: If attribute doesn't exist + """ + current_level = self.get_attr_level(attr_name) + if current_level > 0: + self.set_attr_level(attr_name, current_level - 1) + return True + return False diff --git a/reasoning_gym/coaching.py b/reasoning_gym/coaching/coach.py similarity index 99% rename from reasoning_gym/coaching.py rename to reasoning_gym/coaching/coach.py index ad14077a..eeeab5d1 100644 --- a/reasoning_gym/coaching.py +++ b/reasoning_gym/coaching/coach.py @@ -8,7 +8,7 @@ from pathlib import Path from statistics import mean, stdev from typing import Any, Dict, List, Optional, Tuple, Union -from .dataset import ProceduralDataset +from ..dataset import ProceduralDataset @dataclass diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index c1ddf641..ed90429f 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -1,6 +1,7 @@ import pytest from reasoning_gym.arithmetic import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum def test_chain_sum_config_validation(): @@ -127,3 +128,20 @@ def test_chain_sum_iteration(): first_items = list(dataset) second_items = list(dataset) assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_chain_sum_curriculum(): + c = ChainSumCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: ChainSumConfig = c.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1 + assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2 + + c.increment_attr_level("num_terms") + c.increment_attr_level("num_digits") + + config2 = c.generate_configuration()