# Complete definitions added by the patch, reconstructed from a diff whose
# line breaks were collapsed. Grouped by the file each hunk targets.

# --- reasoning_gym/arithmetic/time_intervals.py ---


class TimeIntervalsCurriculum(BaseCurriculum):
    """Curriculum for ``TimeIntervalsConfig`` built from two scalar attributes.

    ``max_time_difference_seconds`` and ``max_date_difference_days`` are each
    declared with five levels and ``default_level=0`` (first levels: 60 and 1).
    """

    def __init__(self):
        super().__init__(TimeIntervalsCurriculum.__name__, TimeIntervalsConfig)

        seconds_per_day = 24 * 60 * 60

        # Levels: one minute, one day, 7 days, 30 days, 365 days (in seconds).
        time_difference_attr = ScalarAttributeDefinition(
            name="max_time_difference_seconds",
            field_name="max_time_difference_seconds",
            description="Maximum time difference in seconds",
            attr_type=AttributeType.STATIC,
            levels=[
                60,
                seconds_per_day,
                7 * seconds_per_day,
                30 * seconds_per_day,
                365 * seconds_per_day,
            ],
            default_level=0,
            min_value=1,
        )

        # Levels: 1, 7, 30, 365 and 5 * 365 days.
        date_difference_attr = ScalarAttributeDefinition(
            name="max_date_difference_days",
            field_name="max_date_difference_days",
            description="Maximum date difference in days",
            attr_type=AttributeType.STATIC,
            levels=[1, 7, 30, 365, 5 * 365],
            default_level=0,
            min_value=1,
        )

        self._define_attributes(time_difference_attr, date_difference_attr)


# Register the dataset
register_dataset(
    "time_intervals",
    TimeIntervalsDataset,
    TimeIntervalsConfig,
    TimeIntervalsCurriculum,
)


# --- tests/test_time_intervals.py ---


def test_time_intervals_curriculum():
    """Step the curriculum up and back down and check each generated config."""
    curriculum = TimeIntervalsCurriculum()
    overrides = {"size": 150, "seed": 1}

    def difficulty(cfg):
        # The two curriculum-controlled fields, as a pair for compact checks.
        return cfg.max_time_difference_seconds, cfg.max_date_difference_days

    # Default levels: overrides are passed through, both attributes at level 0.
    initial_cfg: TimeIntervalsConfig = curriculum.generate_configuration(overrides)
    assert (initial_cfg.seed, initial_cfg.size) == (1, 150)
    assert difficulty(initial_cfg) == (60, 1)

    # Raise both attributes by one level.
    for attr_name in ("max_time_difference_seconds", "max_date_difference_days"):
        curriculum.increment_attr_level(attr_name)
    raised_cfg = curriculum.generate_configuration(overrides)
    assert difficulty(raised_cfg) == (24 * 60 * 60, 7)

    # Lower only the time attribute; the date attribute keeps its raised level.
    curriculum.decrement_attr_level("max_time_difference_seconds")
    lowered_cfg = curriculum.generate_configuration(overrides)
    assert difficulty(lowered_cfg) == (60, 7)