mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* calendar arithmetic curriculum * add difficulty to metadata * register CalendarArithmeticCurriculum --------- Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
218 lines
7.8 KiB
Python
218 lines
7.8 KiB
Python
import calendar
|
|
import math
|
|
from datetime import date
|
|
|
|
import pytest
|
|
|
|
from reasoning_gym.arithmetic import CalendarArithmeticConfig, CalendarArithmeticCurriculum, CalendarArithmeticDataset
|
|
|
|
# English weekday names, Monday first; the accepted answers for weekday-type tasks.
WEEKDAYS = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
|
|
|
|
# Task names grouped by the kind of answer the tests below expect for them.
WEEKDAY_TASKS = {"weekday_offset", "weekday_of_date_from_first_date", "weekday_of_date"}  # a weekday name
NUMERIC_TASKS = {"count_days", "count_business_days"}  # a non-negative integer
DAY_TASKS = {"recurring_event_day"}  # a day-of-month number
BOOLEAN_TASKS = {"is_leap_year"}  # "Yes" / "No"

# Every task name the tests recognise: the union of the four groups above.
CALENDAR_TASKS = {*WEEKDAY_TASKS, *NUMERIC_TASKS, *DAY_TASKS, *BOOLEAN_TASKS}
|
|
|
|
|
|
def test_calendar_config_validation():
    """Check that bad CalendarArithmeticConfig values are rejected with ValueError."""
    # Each of these configs is built and then explicitly validated inside the
    # raises-context, so an error from either step satisfies the expectation.
    for bad_kwargs in ({"year": 0}, {"size": 0}, {"seed": "not_an_int"}):
        with pytest.raises(ValueError):
            rejected = CalendarArithmeticConfig(**bad_kwargs)
            rejected.validate()

    # For an unknown task name no validate() call is made: the ValueError is
    # expected from constructing the config alone.
    with pytest.raises(ValueError):
        CalendarArithmeticConfig(tasks=["invalid_task"])
|
|
|
|
|
|
def test_calendar_deterministic():
    """Two datasets built from one fixed-seed config must agree item by item."""
    shared_config = CalendarArithmeticConfig(year=2024, seed=42, size=10)
    first = CalendarArithmeticDataset(shared_config)
    second = CalendarArithmeticDataset(shared_config)

    # Compare by index so both datasets are queried at exactly the same positions.
    for position in range(len(first)):
        left, right = first[position], second[position]
        assert left == right
|
|
|
|
|
|
def test_calendar_item_structure():
    """Test that dataset items have the correct structure and fields.

    Every item must be a dict with "question", "answer" and "metadata" keys,
    the question and answer must be non-empty strings, and the metadata must
    carry a "task" entry naming one of the known CALENDAR_TASKS.
    """
    config = CalendarArithmeticConfig(year=2024, seed=42, size=50)
    dataset = CalendarArithmeticDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        missing = [key for key in ("question", "answer", "metadata") if key not in item]
        assert not missing, f"item {i} is missing keys: {missing}"

        assert isinstance(item["question"], str) and len(item["question"]) > 0
        assert isinstance(item["answer"], str) and len(item["answer"]) > 0
        assert "task" in item["metadata"]

        # The task name is reported in the assertion message on failure rather
        # than printed unconditionally for every item.
        task = item["metadata"]["task"]
        assert task in CALENDAR_TASKS, f"item {i} has unknown task {task!r}"
|
|
|
|
|
|
def test_calendar_answer_format():
    """Test that answers have the correct format based on task type.

    Expected formats, keyed on ``item["metadata"]["task"]``:

    * WEEKDAY_TASKS: one of the names in WEEKDAYS.
    * NUMERIC_TASKS: a string that parses as a non-negative integer.
    * BOOLEAN_TASKS: exactly "Yes" or "No".
    * DAY_TASKS: a string that parses as an integer day number valid for the
      month given by the item's "year" and "month" metadata.
    """
    config = CalendarArithmeticConfig(year=2024, seed=42, size=100)
    dataset = CalendarArithmeticDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        task = item["metadata"]["task"]
        answer = item["answer"]

        if task in WEEKDAY_TASKS:
            assert answer in WEEKDAYS

        elif task in NUMERIC_TASKS:
            # Only the parse is guarded, so the "non-integer" message can only
            # be produced by an actual parse failure.
            try:
                num = int(answer)
            except ValueError:
                pytest.fail(f"task {task} produced a non-integer answer: {answer}")
            assert num >= 0, f"task {task} produced a negative count: {num}"

        elif task in BOOLEAN_TASKS:
            assert answer in ["Yes", "No"]

        elif task in DAY_TASKS:
            # Resolve the month length before parsing so that it is always
            # bound when a failure message is formatted.
            year = item["metadata"]["year"]
            month = item["metadata"]["month"]
            _, last_day = calendar.monthrange(year, month)

            try:
                num = int(answer)
            except ValueError:
                pytest.fail(f"task {task} produced a non-integer answer: {answer}")

            # NOTE(review): test_calendar_date_consistency tolerates an answer
            # of -1 for this task, while this check does not -- confirm whether
            # -1 is a legitimate answer that should be accepted here as well.
            assert (
                1 <= num <= last_day
            ), f"task {task} produced a day outside expected range (1-{last_day}): {answer}"
|
|
|
|
|
|
def test_scoring_function():
    """Exercise score_answer against hand-built weekday, numeric and yes/no items."""
    dataset = CalendarArithmeticDataset(CalendarArithmeticConfig(year=2024, seed=42, size=1))

    # Weekday task: a different weekday name scores 0.1; a sentence that merely
    # contains the right name, text with no weekday, and None all score 0.0.
    weekday_entry = {"answer": "Monday", "metadata": {"task": "weekday_offset"}}
    weekday_cases = [
        ("Monday", 1.0),
        ("Tuesday", 0.1),
        ("It is Monday", 0.0),
        ("no weekday here", 0.0),
        (None, 0.0),
    ]
    for response, expected_score in weekday_cases:
        assert dataset.score_answer(response, weekday_entry) == expected_score

    # Numeric task: a wrong number (15 against an expected 10) is compared
    # approximately with exp(-5 * 0.5); unparsable input and None score 0.0.
    numeric_entry = {"answer": "10", "metadata": {"task": "count_business_days"}}
    numeric_cases = [
        ("10", 1.0),
        ("15", pytest.approx(math.exp(-5 * 0.5))),
        ("no number", 0.0),
        (None, 0.0),
    ]
    for response, expected_score in numeric_cases:
        assert dataset.score_answer(response, numeric_entry) == expected_score

    # Yes/No task: "yes" in lower case still scores 1.0, while "nyes" does not count.
    boolean_entry = {"answer": "Yes", "metadata": {"task": "is_leap_year"}}
    boolean_cases = [
        ("Yes", 1.0),
        ("yes", 1.0),
        ("nyes", 0.0),
        (None, 0.0),
    ]
    for response, expected_score in boolean_cases:
        assert dataset.score_answer(response, boolean_entry) == expected_score
|
|
|
|
|
|
def test_calendar_date_consistency():
    """Test that dates in metadata are consistent with config year.

    Per task, the checks are:

    * weekday_offset: "start_date" falls in the configured year.
    * weekday_of_date_from_first_date / weekday_of_date: "target_date" falls
      in the configured year.
    * count_business_days / count_days: both "start_date" and "end_date" fall
      in the configured year.
    * recurring_event_day: "year" equals the configured year, "month" is 1-12,
      and the answer is either -1 or a valid day of that month.
    * is_leap_year: "year" lies within 200 years of the configured year and
      "is_leap" agrees with calendar.isleap.
    """
    config = CalendarArithmeticConfig(year=2024, seed=42, size=50)
    dataset = CalendarArithmeticDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        task = item["metadata"]["task"]

        if task == "weekday_offset":
            start_date = date.fromisoformat(item["metadata"]["start_date"])
            assert start_date.year == config.year

        # The task is named "weekday_of_date_from_first_date" (see WEEKDAY_TASKS);
        # the previous "..._first_day" spelling never matched, so those items
        # were silently skipped.
        # NOTE(review): assumes that task's metadata carries "target_date" -- confirm.
        elif task in {"weekday_of_date_from_first_date", "weekday_of_date"}:
            target_date = date.fromisoformat(item["metadata"]["target_date"])
            assert target_date.year == config.year

        elif task in {"count_business_days", "count_days"}:
            start_date = date.fromisoformat(item["metadata"]["start_date"])
            end_date = date.fromisoformat(item["metadata"]["end_date"])
            assert start_date.year == config.year
            assert end_date.year == config.year

        elif task == "recurring_event_day":
            meta_year = item["metadata"]["year"]
            month = item["metadata"]["month"]
            answer = int(item["answer"])
            assert meta_year == config.year
            assert 1 <= month <= 12
            # An answer of -1 is exempt from the day-of-month range check.
            if answer != -1:
                _, last_day = calendar.monthrange(meta_year, month)
                assert 1 <= answer <= last_day

        elif task == "is_leap_year":
            year = item["metadata"]["year"]
            assert config.year - 200 <= year <= config.year + 200
            is_leap_metadata = item["metadata"]["is_leap"]
            computed_is_leap = calendar.isleap(year)
            assert is_leap_metadata == computed_is_leap
|
|
|
|
|
|
def test_calendar_iteration():
    """Iterating the dataset yields config.size items and repeats identically."""
    config = CalendarArithmeticConfig(year=2024, seed=42, size=5)
    dataset = CalendarArithmeticDataset(config)

    # A single full pass must produce exactly as many items as configured.
    collected = list(dataset)
    assert len(collected) == config.size

    # Two further independent passes over the same object must match each other.
    pass_one = [entry for entry in dataset]
    pass_two = [entry for entry in dataset]
    assert pass_one == pass_two
|
|
|
|
|
|
def test_task_case_sensitivity():
    """Task names given in mixed case must yield items tagged with the lower-case names."""
    tasks = ["WEEKDAY_OFFSET", "Count_Business_Days"]
    dataset = CalendarArithmeticDataset(CalendarArithmeticConfig(tasks=tasks, size=10))

    # Lower-cased forms of the requested names; every generated item must use one of them.
    accepted_names = [name.lower() for name in tasks]
    for entry in dataset:
        assert entry["metadata"]["task"] in accepted_names
|
|
|
|
|
|
def test_calendar_curriculum():
    """Check the configs the curriculum produces before and after raising two attribute levels."""
    curriculum = CalendarArithmeticCurriculum()

    # Initial configuration: the supplied size/seed are passed through, and the
    # task list and offset bound are the values asserted below.
    starting_cfg = curriculum.generate_configuration({"size": 150, "seed": 1})
    assert starting_cfg.size == 150
    assert starting_cfg.seed == 1
    assert starting_cfg.tasks == ["weekday_of_date"]
    assert starting_cfg.offset_upper_bound == 30

    # Raise both attributes by one level, then regenerate without overrides.
    for attr_name in ("task_complexity", "date_range"):
        curriculum.increment_attr_level(attr_name)

    harder_cfg = curriculum.generate_configuration()
    assert harder_cfg.tasks == ["weekday_of_date", "is_leap_year", "weekday_offset"]
    assert harder_cfg.offset_upper_bound == 100
|