simple geometry curriculum (#342)

This commit is contained in:
Zafir Stojanovski 2025-03-13 21:00:05 +01:00 committed by GitHub
parent 1e2808889c
commit c3554c65f0
3 changed files with 46 additions and 3 deletions

View file

@ -1,9 +1,10 @@
from .advanced_geometry import AdvancedGeometryConfig, AdvancedGeometryDataset from .advanced_geometry import AdvancedGeometryConfig, AdvancedGeometryDataset
from .simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset from .simple_geometry import SimpleGeometryConfig, SimpleGeometryCurriculum, SimpleGeometryDataset
__all__ = [ __all__ = [
"SimpleGeometryConfig", "SimpleGeometryConfig",
"SimpleGeometryDataset", "SimpleGeometryDataset",
"SimpleGeometryCurriculum",
"AdvancedGeometryConfig", "AdvancedGeometryConfig",
"AdvancedGeometryDataset", "AdvancedGeometryDataset",
] ]

View file

@ -2,6 +2,7 @@ import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -114,6 +115,7 @@ class SimpleGeometryDataset(ProceduralDataset):
"missing_angle_raw": missing_angle, "missing_angle_raw": missing_angle,
"missing_angle_rounded": missing_angle_rounded, "missing_angle_rounded": missing_angle_rounded,
"total_interior_sum": total_sum, "total_interior_sum": total_sum,
"difficulty": {"sides": n_sides},
}, },
} }
@ -142,5 +144,24 @@ class SimpleGeometryDataset(ProceduralDataset):
) )
class SimpleGeometryCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(SimpleGeometryCurriculum.__name__, SimpleGeometryConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="sides",
levels=[5, 10, 25, 50],
default_level=1,
description="Number of sides in the polygon.",
attr_type=AttributeType.APPEND,
min_value=3,
lower_field_name="min_sides",
upper_field_name="max_sides",
)
)
# Register the dataset so it can be accessed similarly to the others # Register the dataset so it can be accessed similarly to the others
register_dataset("simple_geometry", SimpleGeometryDataset, SimpleGeometryConfig) register_dataset("simple_geometry", SimpleGeometryDataset, SimpleGeometryConfig, SimpleGeometryCurriculum)

View file

@ -1,6 +1,6 @@
import pytest import pytest
from reasoning_gym.geometry.simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset from reasoning_gym.geometry.simple_geometry import SimpleGeometryConfig, SimpleGeometryCurriculum, SimpleGeometryDataset
def test_simple_geometry_config_validation(): def test_simple_geometry_config_validation():
@ -78,3 +78,24 @@ def test_simple_geometry_dataset_iteration():
first_items = list(dataset) first_items = list(dataset)
second_items = list(dataset) second_items = list(dataset)
assert first_items == second_items, "Multiple iterations should yield the same items." assert first_items == second_items, "Multiple iterations should yield the same items."
def test_simple_geometry_curriculum():
curriculum = SimpleGeometryCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: SimpleGeometryConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_sides == 5 and base_cfg.max_sides == 10
# test incrementing attribute levels
curriculum.increment_attr_level("sides")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_sides == 5 and increased_cfg.max_sides == 25
# test decrementing attribute level for sides again
curriculum.decrement_attr_level("sides")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_sides == 5 and partially_decreased_cfg.max_sides == 10