Calendar arithmetic curriculum (#283)

* calendar arithmetic curriculum
* add difficulty to metadata
* register CalendarArithmeticCurriculum

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
vncntt 2025-03-07 16:38:22 -08:00 committed by GitHub
parent 775a42e9e4
commit 8c80bf6bec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 73 additions and 7 deletions

View file

@ -4,7 +4,7 @@ Arithmetic tasks for training reasoning capabilities:
from .basic_arithmetic import BasicArithmeticCurriculum, BasicArithmeticDataset, BasicArithmeticDatasetConfig
from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticCurriculum, BitwiseArithmeticDataset
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticCurriculum, CalendarArithmeticDataset
from .chain_sum import ChainSumConfig, ChainSumDataset
from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset
@ -29,6 +29,7 @@ __all__ = [
"ChainSumConfig",
"CalendarArithmeticConfig",
"CalendarArithmeticDataset",
"CalendarArithmeticCurriculum",
"FractionSimplificationConfig",
"FractionSimplificationDataset",
"GCDConfig",

View file

@ -6,6 +6,7 @@ from datetime import date, timedelta
from enum import Enum, StrEnum, auto
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -41,7 +42,7 @@ class Weekday(Enum):
class CalendarTask(StrEnum):
WEEKDAY_OFFSET = "weekday_offset"
WEEKDAY_OF_DATE = "weekday_of_date"
WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_day"
WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_date"
RECURRING_EVENT_CALCULATIONS = "recurring_event_day"
COUNT_DAYS = "count_days"
COUNT_BUSINESS_DAYS = "count_business_days"
@ -112,7 +113,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
self.task_handlers = {
CalendarTask.WEEKDAY_OFFSET.value: self._weekday_offset,
CalendarTask.WEEKDAY_OF_DATE.value: self._weekday_of_date,
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_day,
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_date,
CalendarTask.RECURRING_EVENT_CALCULATIONS.value: self._recurring_event_day,
CalendarTask.COUNT_DAYS.value: self._count_days,
CalendarTask.COUNT_BUSINESS_DAYS.value: self._count_business_days,
@ -125,6 +126,10 @@ class CalendarArithmeticDataset(ProceduralDataset):
rng = random.Random(self.seed + idx)
task = rng.choice(self.tasks)
question, answer, metadata = task(rng)
metadata["difficulty"] = {
"task_complexity": self.tasks.index(task),
"date_range": self.config.offset_upper_bound,
}
return {
"question": question,
"answer": str(answer),
@ -193,7 +198,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
}
return question, answer_weekday, metadata
def _weekday_of_date_from_first_day(self, rng: random.Random) -> tuple[str, str, dict]:
def _weekday_of_date_from_first_date(self, rng: random.Random) -> tuple[str, str, dict]:
"""
task: Given an hypothetical weekday for January 1, ask what weekday a later date in the year falls on.
example:
@ -484,4 +489,44 @@ class CalendarArithmeticDataset(ProceduralDataset):
return 0.0
register_dataset("calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig)
class CalendarArithmeticCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(CalendarArithmeticCurriculum.__name__, CalendarArithmeticConfig)
# Define attributes
self._define_attributes(
ScalarAttributeDefinition(
name="task_complexity",
levels=[
["weekday_of_date"],
["weekday_of_date", "is_leap_year", "weekday_offset"],
["weekday_of_date", "is_leap_year", "weekday_offset", "count_days", "count_business_days"],
[
"weekday_of_date",
"is_leap_year",
"weekday_offset",
"count_days",
"count_business_days",
"weekday_of_date_from_first_date",
"recurring_event_day",
],
],
default_level=0,
description="Controls which calendar tasks are included",
attr_type=AttributeType.STATIC,
field_name="tasks",
),
ScalarAttributeDefinition(
name="date_range",
levels=[30, 100, 250, 365],
default_level=0,
description="Maximum day range for offset and counting tasks",
attr_type=AttributeType.STATIC,
field_name="offset_upper_bound",
),
)
register_dataset(
"calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig, CalendarArithmeticCurriculum
)

View file

@ -4,7 +4,7 @@ from datetime import date
import pytest
from reasoning_gym.arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from reasoning_gym.arithmetic import CalendarArithmeticConfig, CalendarArithmeticCurriculum, CalendarArithmeticDataset
WEEKDAYS = [
"Monday",
@ -18,7 +18,7 @@ WEEKDAYS = [
WEEKDAY_TASKS = {
"weekday_offset",
"weekday_of_date_from_first_day",
"weekday_of_date_from_first_date",
"weekday_of_date",
}
NUMERIC_TASKS = {
@ -71,6 +71,7 @@ def test_calendar_item_structure():
assert isinstance(item["question"], str) and len(item["question"]) > 0
assert isinstance(item["answer"], str) and len(item["answer"]) > 0
assert "task" in item["metadata"]
print(item["metadata"]["task"])
assert item["metadata"]["task"] in CALENDAR_TASKS
@ -196,3 +197,22 @@ def test_task_case_sensitivity():
for item in dataset:
assert item["metadata"]["task"] in [t.lower() for t in tasks]
def test_calendar_curriculum():
"""Test that the curriculum generates correct configurations."""
curriculum = CalendarArithmeticCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: CalendarArithmeticConfig = curriculum.generate_configuration(base_value)
assert base_cfg.size == 150
assert base_cfg.seed == 1
assert base_cfg.tasks == ["weekday_of_date"]
assert base_cfg.offset_upper_bound == 30
curriculum.increment_attr_level("task_complexity")
curriculum.increment_attr_level("date_range")
increased_cfg: CalendarArithmeticConfig = curriculum.generate_configuration()
assert increased_cfg.tasks == ["weekday_of_date", "is_leap_year", "weekday_offset"]
assert increased_cfg.offset_upper_bound == 100