diff --git a/reasoning_gym/geometry/__init__.py b/reasoning_gym/geometry/__init__.py index 6e4e2d1a..42a8731a 100644 --- a/reasoning_gym/geometry/__init__.py +++ b/reasoning_gym/geometry/__init__.py @@ -1,9 +1,10 @@ from .advanced_geometry import AdvancedGeometryConfig, AdvancedGeometryDataset -from .simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset +from .simple_geometry import SimpleGeometryConfig, SimpleGeometryCurriculum, SimpleGeometryDataset __all__ = [ "SimpleGeometryConfig", "SimpleGeometryDataset", + "SimpleGeometryCurriculum", "AdvancedGeometryConfig", "AdvancedGeometryDataset", ] diff --git a/reasoning_gym/geometry/simple_geometry.py b/reasoning_gym/geometry/simple_geometry.py index 665a440f..b6ac5819 100644 --- a/reasoning_gym/geometry/simple_geometry.py +++ b/reasoning_gym/geometry/simple_geometry.py @@ -2,6 +2,7 @@ import random from dataclasses import dataclass from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -114,6 +115,7 @@ class SimpleGeometryDataset(ProceduralDataset): "missing_angle_raw": missing_angle, "missing_angle_rounded": missing_angle_rounded, "total_interior_sum": total_sum, + "difficulty": {"sides": n_sides}, }, } @@ -142,5 +144,24 @@ class SimpleGeometryDataset(ProceduralDataset): ) +class SimpleGeometryCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(SimpleGeometryCurriculum.__name__, SimpleGeometryConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="sides", + levels=[5, 10, 25, 50], + default_level=1, + description="Number of sides in the polygon.", + attr_type=AttributeType.APPEND, + min_value=3, + lower_field_name="min_sides", + upper_field_name="max_sides", + ) + ) + + # Register the dataset so it can be accessed similarly to the others -register_dataset("simple_geometry", SimpleGeometryDataset, SimpleGeometryConfig) +register_dataset("simple_geometry", SimpleGeometryDataset, SimpleGeometryConfig, SimpleGeometryCurriculum) diff --git a/tests/test_simple_geometry.py b/tests/test_simple_geometry.py index 804cf15a..36b45d10 100644 --- a/tests/test_simple_geometry.py +++ b/tests/test_simple_geometry.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.geometry.simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset +from reasoning_gym.geometry.simple_geometry import SimpleGeometryConfig, SimpleGeometryCurriculum, SimpleGeometryDataset def test_simple_geometry_config_validation(): @@ -78,3 +78,24 @@ def test_simple_geometry_dataset_iteration(): first_items = list(dataset) second_items = list(dataset) assert first_items == second_items, "Multiple iterations should yield the same items." + + +def test_simple_geometry_curriculum(): + curriculum = SimpleGeometryCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: SimpleGeometryConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_sides == 5 and base_cfg.max_sides == 10 + + # test incrementing attribute levels + curriculum.increment_attr_level("sides") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_sides == 5 and increased_cfg.max_sides == 25 + + # test decrementing attribute level for sides again + curriculum.decrement_attr_level("sides") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_sides == 5 and partially_decreased_cfg.max_sides == 10