mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Calendar arithmetic curriculum (#283)
* calendar arithmetic curriculum * add difficulty to metadata * register CalendarArithmeticCurriculum --------- Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
parent
775a42e9e4
commit
8c80bf6bec
3 changed files with 73 additions and 7 deletions
|
|
@ -4,7 +4,7 @@ Arithmetic tasks for training reasoning capabilities:
|
|||
|
||||
from .basic_arithmetic import BasicArithmeticCurriculum, BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||
from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticCurriculum, BitwiseArithmeticDataset
|
||||
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
||||
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticCurriculum, CalendarArithmeticDataset
|
||||
from .chain_sum import ChainSumConfig, ChainSumDataset
|
||||
from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
|
||||
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset
|
||||
|
|
@ -29,6 +29,7 @@ __all__ = [
|
|||
"ChainSumConfig",
|
||||
"CalendarArithmeticConfig",
|
||||
"CalendarArithmeticDataset",
|
||||
"CalendarArithmeticCurriculum",
|
||||
"FractionSimplificationConfig",
|
||||
"FractionSimplificationDataset",
|
||||
"GCDConfig",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from datetime import date, timedelta
|
|||
from enum import Enum, StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -41,7 +42,7 @@ class Weekday(Enum):
|
|||
class CalendarTask(StrEnum):
|
||||
WEEKDAY_OFFSET = "weekday_offset"
|
||||
WEEKDAY_OF_DATE = "weekday_of_date"
|
||||
WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_day"
|
||||
WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_date"
|
||||
RECURRING_EVENT_CALCULATIONS = "recurring_event_day"
|
||||
COUNT_DAYS = "count_days"
|
||||
COUNT_BUSINESS_DAYS = "count_business_days"
|
||||
|
|
@ -112,7 +113,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
self.task_handlers = {
|
||||
CalendarTask.WEEKDAY_OFFSET.value: self._weekday_offset,
|
||||
CalendarTask.WEEKDAY_OF_DATE.value: self._weekday_of_date,
|
||||
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_day,
|
||||
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_date,
|
||||
CalendarTask.RECURRING_EVENT_CALCULATIONS.value: self._recurring_event_day,
|
||||
CalendarTask.COUNT_DAYS.value: self._count_days,
|
||||
CalendarTask.COUNT_BUSINESS_DAYS.value: self._count_business_days,
|
||||
|
|
@ -125,6 +126,10 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
rng = random.Random(self.seed + idx)
|
||||
task = rng.choice(self.tasks)
|
||||
question, answer, metadata = task(rng)
|
||||
metadata["difficulty"] = {
|
||||
"task_complexity": self.tasks.index(task),
|
||||
"date_range": self.config.offset_upper_bound,
|
||||
}
|
||||
return {
|
||||
"question": question,
|
||||
"answer": str(answer),
|
||||
|
|
@ -193,7 +198,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
}
|
||||
return question, answer_weekday, metadata
|
||||
|
||||
def _weekday_of_date_from_first_day(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
def _weekday_of_date_from_first_date(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
"""
|
||||
task: Given an hypothetical weekday for January 1, ask what weekday a later date in the year falls on.
|
||||
example:
|
||||
|
|
@ -484,4 +489,44 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
return 0.0
|
||||
|
||||
|
||||
register_dataset("calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig)
|
||||
class CalendarArithmeticCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(CalendarArithmeticCurriculum.__name__, CalendarArithmeticConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
ScalarAttributeDefinition(
|
||||
name="task_complexity",
|
||||
levels=[
|
||||
["weekday_of_date"],
|
||||
["weekday_of_date", "is_leap_year", "weekday_offset"],
|
||||
["weekday_of_date", "is_leap_year", "weekday_offset", "count_days", "count_business_days"],
|
||||
[
|
||||
"weekday_of_date",
|
||||
"is_leap_year",
|
||||
"weekday_offset",
|
||||
"count_days",
|
||||
"count_business_days",
|
||||
"weekday_of_date_from_first_date",
|
||||
"recurring_event_day",
|
||||
],
|
||||
],
|
||||
default_level=0,
|
||||
description="Controls which calendar tasks are included",
|
||||
attr_type=AttributeType.STATIC,
|
||||
field_name="tasks",
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="date_range",
|
||||
levels=[30, 100, 250, 365],
|
||||
default_level=0,
|
||||
description="Maximum day range for offset and counting tasks",
|
||||
attr_type=AttributeType.STATIC,
|
||||
field_name="offset_upper_bound",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig, CalendarArithmeticCurriculum
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from datetime import date
|
|||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
||||
from reasoning_gym.arithmetic import CalendarArithmeticConfig, CalendarArithmeticCurriculum, CalendarArithmeticDataset
|
||||
|
||||
WEEKDAYS = [
|
||||
"Monday",
|
||||
|
|
@ -18,7 +18,7 @@ WEEKDAYS = [
|
|||
|
||||
WEEKDAY_TASKS = {
|
||||
"weekday_offset",
|
||||
"weekday_of_date_from_first_day",
|
||||
"weekday_of_date_from_first_date",
|
||||
"weekday_of_date",
|
||||
}
|
||||
NUMERIC_TASKS = {
|
||||
|
|
@ -71,6 +71,7 @@ def test_calendar_item_structure():
|
|||
assert isinstance(item["question"], str) and len(item["question"]) > 0
|
||||
assert isinstance(item["answer"], str) and len(item["answer"]) > 0
|
||||
assert "task" in item["metadata"]
|
||||
print(item["metadata"]["task"])
|
||||
assert item["metadata"]["task"] in CALENDAR_TASKS
|
||||
|
||||
|
||||
|
|
@ -196,3 +197,22 @@ def test_task_case_sensitivity():
|
|||
|
||||
for item in dataset:
|
||||
assert item["metadata"]["task"] in [t.lower() for t in tasks]
|
||||
|
||||
|
||||
def test_calendar_curriculum():
|
||||
"""Test that the curriculum generates correct configurations."""
|
||||
curriculum = CalendarArithmeticCurriculum()
|
||||
|
||||
base_value = {"size": 150, "seed": 1}
|
||||
|
||||
base_cfg: CalendarArithmeticConfig = curriculum.generate_configuration(base_value)
|
||||
assert base_cfg.size == 150
|
||||
assert base_cfg.seed == 1
|
||||
assert base_cfg.tasks == ["weekday_of_date"]
|
||||
assert base_cfg.offset_upper_bound == 30
|
||||
|
||||
curriculum.increment_attr_level("task_complexity")
|
||||
curriculum.increment_attr_level("date_range")
|
||||
increased_cfg: CalendarArithmeticConfig = curriculum.generate_configuration()
|
||||
assert increased_cfg.tasks == ["weekday_of_date", "is_leap_year", "weekday_offset"]
|
||||
assert increased_cfg.offset_upper_bound == 100
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue