mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-27 17:23:19 +00:00
Merge branch 'main' into koko/gsm-symbolic-task-1
This commit is contained in:
commit
a80339a0e6
49 changed files with 6334 additions and 147 deletions
|
|
@ -4,9 +4,11 @@ Arithmetic tasks for training reasoning capabilities:
|
|||
- Chain sums
|
||||
- Word problems
|
||||
- Leg counting
|
||||
- Time intervals
|
||||
"""
|
||||
|
||||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
||||
from .chain_sum import ChainSum, ChainSumConfig
|
||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||
from .gcd import GCDConfig, GCDDataset
|
||||
|
|
@ -14,6 +16,7 @@ from .lcm import LCMConfig, LCMDataset
|
|||
from .leg_counting import LegCountingConfig, LegCountingDataset
|
||||
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
||||
from .gsm_symbolic.gsm_symbolic_datasets import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
||||
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
|
||||
|
||||
__all__ = [
|
||||
"BasicArithmeticDataset",
|
||||
|
|
@ -21,6 +24,10 @@ __all__ = [
|
|||
"basic_arithmetic_dataset",
|
||||
"ChainSum",
|
||||
"ChainSumConfig",
|
||||
"CalendarArithmeticConfig",
|
||||
"CalendarArithmeticDataset",
|
||||
"Weekday",
|
||||
"CalendarTask",
|
||||
"FractionSimplificationConfig",
|
||||
"FractionSimplificationDataset",
|
||||
"GCDConfig",
|
||||
|
|
@ -33,4 +40,6 @@ __all__ = [
|
|||
"PrimeFactorizationDataset",
|
||||
"GSMSymbolicDatasetConfig",
|
||||
"GSMSymbolicDataset",
|
||||
"TimeIntervalsConfig",
|
||||
"TimeIntervalsDataset",
|
||||
]
|
||||
|
|
|
|||
490
reasoning_gym/arithmetic/calendar_arithmetic.py
Normal file
490
reasoning_gym/arithmetic/calendar_arithmetic.py
Normal file
|
|
@ -0,0 +1,490 @@
|
|||
import calendar
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
class Weekday(Enum):
    """Days of the week, valued 1 (Monday) through 7 (Sunday).

    The zero-based position of a member in definition order equals the value
    returned by ``datetime.date.weekday()`` for that day.
    """

    MONDAY = auto()
    TUESDAY = auto()
    WEDNESDAY = auto()
    THURSDAY = auto()
    FRIDAY = auto()
    SATURDAY = auto()
    SUNDAY = auto()

    @classmethod
    def from_date(cls, d: date) -> "Weekday":
        """Return the member matching the weekday on which ``d`` falls."""
        members = [*cls]
        return members[d.weekday()]

    @classmethod
    def random(cls, rng: random.Random) -> "Weekday":
        """Pick a member using exactly one ``rng.randint(0, 6)`` draw."""
        position = rng.randint(0, 6)
        return [*cls][position]

    @classmethod
    def __getitem__(cls, idx) -> "Weekday":
        """Return the member at zero-based position ``idx`` in definition order."""
        ordered = [*cls]
        return ordered[idx]

    @property
    def index(self) -> int:
        """Zero-based position of this day (Monday is 0, Sunday is 6)."""
        one_based = self.value
        return one_based - 1

    def __str__(self) -> str:
        """Return the day name with only its first letter upper-cased, e.g. ``"Monday"``."""
        return self.name.capitalize()
|
||||
|
||||
|
||||
class CalendarTask(Enum):
    """Identifiers of the calendar question families.

    Each value is the string used in ``CalendarArithmeticConfig.tasks`` and
    stored under the ``"task"`` key of an item's metadata.
    """

    # Weekday reached after moving a number of days forward or backward.
    WEEKDAY_OFFSET = "weekday_offset"
    # Weekday of a concrete calendar date.
    WEEKDAY_OF_DATE = "weekday_of_date"
    # Weekday of a date, given a hypothetical weekday for January 1.
    WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_day"
    # Day of the month of an "n-th <weekday>" style recurring event.
    RECURRING_EVENT_CALCULATIONS = "recurring_event_day"
    # Number of occurrences of one weekday inside a date range.
    COUNT_DAYS = "count_days"
    # Number of Monday-Friday days inside a date range.
    COUNT_BUSINESS_DAYS = "count_business_days"
    # Yes/No leap-year question.
    IS_LEAP_YEAR = "is_leap_year"
|
||||
|
||||
|
||||
@dataclass
class CalendarArithmeticConfig:
    """Configuration for the calendar arithmetic dataset.

    Attributes:
        year: Calendar year that generated dates are drawn from.
        tasks: Task identifiers (``CalendarTask`` values) to sample from.
            ``None`` selects every task. Names are matched case-insensitively
            and ``CalendarTask`` members are accepted as well.
        offset_upper_bound: Largest day offset / range length used by the
            offset and counting tasks.
        leap_year_range: Width of the window of years, centred on ``year``,
            that leap-year questions are drawn from.
        seed: Base seed for item generation, or ``None``.
        size: Number of items in the dataset.
    """

    year: int = 2022
    tasks: Optional[List[str]] = None
    offset_upper_bound: int = 100
    leap_year_range: int = 200
    seed: Optional[int] = 42
    size: int = 500

    def __post_init__(self):
        """Normalise ``tasks`` to lower-case task names and reject unknown ones.

        Raises:
            ValueError: If any requested task is not a ``CalendarTask`` value.
        """
        if self.tasks is None:
            self.tasks = [task.value for task in CalendarTask]
            return

        # Accept enum members in addition to plain strings.
        self.tasks = [task.value if isinstance(task, CalendarTask) else task.lower() for task in self.tasks]
        valid_tasks = {task.value for task in CalendarTask}
        invalid_tasks = set(self.tasks) - valid_tasks
        if invalid_tasks:
            valid_task_list = ", ".join(sorted(valid_tasks))
            raise ValueError(
                f"Invalid tasks: {', '.join(sorted(invalid_tasks))}. " f"Valid tasks are: {valid_task_list}"
            )

    def validate(self) -> None:
        """Validate the configuration parameters.

        Raises:
            ValueError: If any parameter has a value that item generation
                cannot work with.
        """
        if not isinstance(self.year, int) or self.year <= 0:
            raise ValueError(f"year must be a positive integer, got {self.year}")

        # datetime.date cannot represent years beyond date.max.year.
        if self.year > date.max.year:
            raise ValueError(f"year must be at most {date.max.year}, got {self.year}")

        # An empty task list would make task sampling fail on every item.
        if not self.tasks:
            raise ValueError("tasks must contain at least one task")

        # Offsets are drawn with randint(1, offset_upper_bound), which needs an upper bound >= 1.
        if not isinstance(self.offset_upper_bound, int) or self.offset_upper_bound < 1:
            raise ValueError(f"offset_upper_bound must be a positive integer, got {self.offset_upper_bound}")

        # A negative range would give randint an empty interval of years.
        if not isinstance(self.leap_year_range, int) or self.leap_year_range < 0:
            raise ValueError(f"leap_year_range must be a non-negative integer, got {self.leap_year_range}")

        if self.seed is not None and not isinstance(self.seed, int):
            raise ValueError(f"seed must be an integer or None, got {type(self.seed)}")

        if not isinstance(self.size, int) or self.size <= 0:
            raise ValueError(f"size must be a positive integer, got {self.size}")
|
||||
|
||||
|
||||
class CalendarArithmeticDataset(ProceduralDataset):
    """Generates calendar questions (weekdays, day counts, leap years).

    Every item is produced from its own ``random.Random(seed + idx)`` stream:
    one of the configured task handlers is chosen and returns the question
    text, the expected answer and task-specific metadata.
    """

    DAY_QUESTION_TEMPLATES = [
        "Answer with the weekday's name (e.g., Monday, Tuesday, etc.).",
        "Provide the full name of the weekday.",
        "State the weekday (Monday through Sunday).",
        "Give the weekday name in full.",
        "Reply with just the weekday name.",
        "Write out the full weekday name.",
        "Respond with the weekday (Monday-Sunday).",
        "Answer using the complete weekday name.",
        "Name the day of the week in full.",
    ]

    COUNT_QUESTION_TEMPLATES = [
        "Answer with a number.",
        "Provide the count as a number.",
        "Respond with just the number.",
        "Write the total number.",
        "Give the count numerically.",
        "State the amount as a number.",
        "Reply with the numerical value.",
        "Express your answer as a number.",
    ]

    def __init__(self, config: CalendarArithmeticConfig):
        """Bind each configured task name to the method that generates it."""
        super().__init__(config=config, seed=config.seed, size=config.size)

        self.task_handlers = {
            CalendarTask.WEEKDAY_OFFSET.value: self._weekday_offset,
            CalendarTask.WEEKDAY_OF_DATE.value: self._weekday_of_date,
            CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_day,
            CalendarTask.RECURRING_EVENT_CALCULATIONS.value: self._recurring_event_day,
            CalendarTask.COUNT_DAYS.value: self._count_days,
            CalendarTask.COUNT_BUSINESS_DAYS.value: self._count_business_days,
            CalendarTask.IS_LEAP_YEAR.value: self._is_leap_year,
        }

        # Duplicated names in config.tasks are kept, which weights the sampling.
        self.tasks = [self.task_handlers[task] for task in self.config.tasks]

    def __getitem__(self, idx: int) -> dict:
        """Generate item ``idx`` as a dict with ``question``, ``answer`` and ``metadata``."""
        item_rng = random.Random(self.seed + idx)
        task = item_rng.choice(self.tasks)
        question, answer, metadata = task(item_rng)
        return {
            "question": question,
            "answer": str(answer),
            "metadata": metadata,
        }

    def _weekday_offset(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        Task: Given a starting date and a day offset (which may be positive or negative),
        ask what day of the week it will be.
        Examples:
        - "If today is Wednesday, March 13, 2024, what day of the week will it be in 10 days? Answer with the weekday's name."
        - "If today is Wednesday, March 13, 2024, what day of the week was it 10 days ago? Answer with the weekday's name."
        """
        year = self.config.year
        start_date = self._random_date_for_year(rng, year)
        offset = rng.randint(1, self.config.offset_upper_bound)
        sign = rng.choice([-1, 1])
        offset_days = sign * offset
        target_date = start_date + timedelta(days=offset_days)
        # Use the Weekday enum rather than strftime("%A") so the expected answer
        # is always one of the names score_answer compares against.
        target_weekday = str(Weekday.from_date(target_date))

        date_str = f"{start_date.strftime('%A')}, {start_date.strftime('%B')} {start_date.day}, {start_date.year}"
        if offset_days >= 0:
            templates = [
                f"If today is {date_str}, what day of the week will it be in {offset_days} days? ",
                f"Starting from {date_str}, which weekday falls after a {offset_days}-day jump? ",
                f"Count forward {offset_days} days from {date_str} - what's the weekday? ",
            ]
        else:
            days_back = abs(offset_days)
            templates = [
                f"If today is {date_str}, what day of the week was it {days_back} days ago? ",
                f"Starting from {date_str}, which weekday was it {days_back} days before? ",
                f"Count backward {days_back} days from {date_str} - what's the weekday? ",
            ]

        question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES)
        metadata = {
            "task": CalendarTask.WEEKDAY_OFFSET.value,
            "start_date": start_date.isoformat(),
            "offset_days": offset_days,
            "target_date": target_date.isoformat(),
        }
        return question, target_weekday, metadata

    def _weekday_of_date(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Ask what day of the week a given date was.
        example:
        "What day of the week was January 15, 2024?
        Answer with the weekday's name."
        """
        year = self.config.year
        target_date = self._random_date_for_year(rng, year)
        answer_weekday = str(Weekday.from_date(target_date))
        templates = [
            f"What day of the week was {target_date.strftime('%B')} {target_date.day}, {year}?",
            f"On which weekday did {target_date.strftime('%B')} {target_date.day}, {year} fall?",
            f"Name the day of the week for {target_date.strftime('%m/%d/%Y')}.",
        ]

        question = f"{rng.choice(templates)} {rng.choice(self.DAY_QUESTION_TEMPLATES)}"
        metadata = {
            "task": CalendarTask.WEEKDAY_OF_DATE.value,
            "target_date": target_date.isoformat(),
        }
        return question, answer_weekday, metadata

    def _weekday_of_date_from_first_day(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Given an hypothetical weekday for January 1, ask what weekday a later date in the year falls on.
        example:
        "If the first day of the year was a Monday, what day of the week will December 31 be?
        Answer with the weekday's name."

        Raises:
            ValueError: If ``offset_upper_bound`` leaves no date after January 1 to pick.
        """
        year = self.config.year
        first_day = Weekday.random(rng)
        year_start = date(year, 1, 1)
        year_end = date(year, 12, 31)
        max_delta = timedelta(days=self.config.offset_upper_bound)
        max_date = min(year_start + max_delta, year_end)
        if max_date <= year_start:
            # Without this guard the rejection loop below could never terminate.
            raise ValueError("offset_upper_bound must be at least 1 to pick a date after January 1")

        # Rejection-sample until the target date is not January 1 itself.
        target_date = year_start
        while target_date == year_start:
            target_date = self._random_date_between(rng, year_start, max_date)

        delta_days = (target_date - year_start).days
        answer_index = (first_day.index + delta_days) % 7
        answer_weekday = Weekday(answer_index + 1)

        templates = [
            f"If the first day of the year was a {first_day}, what day of the week will "
            f"{target_date.strftime('%B')} {target_date.day} be? ",
            f"Given that January 1 fell on a {first_day}, which weekday occurs on "
            f"{target_date.strftime('%B')} {target_date.day}? ",
            f"In a year where {first_day} is January 1st, name the weekday of "
            f"{target_date.strftime('%B')} {target_date.day}. ",
        ]

        question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES)
        metadata = {
            "task": CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
            "year": year,
            "first_day": str(first_day),
            "target_date": target_date.isoformat(),
            "delta_days": delta_days,
        }
        # Return the name, not the enum member, to honour the declared str answer type.
        return question, str(answer_weekday), metadata

    def _recurring_event_day(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: For a recurring event defined by an ordinal weekday pattern in a month,
        ask on which day of the month the event occurs.
        example:
        "If a meeting is scheduled on the second Tuesday of May 2024, on which day does it fall?
        Answer with a number."
        """
        year = self.config.year
        month = rng.randint(1, 12)
        ordinals = ["first", "second", "third", "fourth", "last"]
        ordinal = rng.choice(ordinals)
        weekday = Weekday.random(rng)
        month_name = calendar.month_name[month]
        _, last_day = calendar.monthrange(year, month)

        if ordinal == "last":
            # Step back from the month's final day to the closest matching weekday.
            final_weekday_index = date(year, month, last_day).weekday()
            event_day = last_day - (final_weekday_index - weekday.index) % 7
        else:
            ordinal_number = {"first": 1, "second": 2, "third": 3, "fourth": 4}[ordinal]
            # First matching weekday of the month, then jump whole weeks.
            first_occurrence = 1 + (weekday.index - date(year, month, 1).weekday()) % 7
            event_day = first_occurrence + 7 * (ordinal_number - 1)
            if event_day > last_day:
                # Sentinel announced in the question text for a non-existent ordinal.
                event_day = -1

        templates = [
            f"If a meeting is scheduled on the {ordinal} {weekday} of {month_name} {year}, on which day of the month does it occur? ",
            f"In {month_name} {year}, if an event recurs on the {ordinal} {weekday}, what is the date (day of the month) of the event? ",
            f"Determine the day of the month for the {ordinal} {weekday} in {month_name} {year}. ",
        ]
        question = (
            rng.choice(templates)
            + rng.choice(self.COUNT_QUESTION_TEMPLATES)
            + " Answer with -1 if the ordinal does not exist in the month."
        )
        metadata = {
            "task": CalendarTask.RECURRING_EVENT_CALCULATIONS.value,
            "year": year,
            "month": month,
            "ordinal": ordinal,
            "weekday": str(weekday),
        }
        return question, str(event_day), metadata

    def _count_days(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Ask how many times a given weekday occurs in a specified range.
        example:
        "How many Mondays are there from Friday, March 1, 2024 to Friday, March 15, 2024 (inclusive of both dates)?
        Answer with a number."
        """
        year = self.config.year
        year_start = date(year, 1, 1)
        year_end = date(year, 12, 31)
        start_date = self._random_date_between(rng, year_start, year_end)
        max_delta = timedelta(days=self.config.offset_upper_bound)
        # The range is clamped to the configured year.
        end_date = self._random_date_between(rng, start_date, min(year_end, start_date + max_delta))
        weekday = Weekday.random(rng)

        # Inclusive range: every full week contains the weekday once; the
        # leftover days contain it if it lies within them, counted from start_date.
        span_days = (end_date - start_date).days + 1
        full_weeks, leftover = divmod(span_days, 7)
        days_until_first = (weekday.index - start_date.weekday()) % 7
        count = full_weeks + (1 if days_until_first < leftover else 0)

        templates = [
            f"How many {weekday}s are there from {start_date.strftime('%A, %B')} {start_date.day}, {year} to "
            f"{end_date.strftime('%A, %B')} {end_date.day}, {year} (inclusive of both dates)? ",
            f"Count the occurrences of {weekday} from {start_date.strftime('%A, %B')} {start_date.day} "
            f"to {end_date.strftime('%A, %B')} {end_date.day}, {year} (including both start and end dates). ",
            f"Between {start_date.strftime('%A, %B')} {start_date.day}, {year} and "
            f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
            f"(counting both dates), how many times does {weekday} occur? ",
        ]

        question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES)
        metadata = {
            "task": CalendarTask.COUNT_DAYS.value,
            "year": year,
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
        }
        return question, str(count), metadata

    @staticmethod
    def _business_days_between(d1: date, d2: date) -> int:
        """Count Monday-Friday days in the inclusive range ``d1``..``d2``."""
        days = (d2 - d1).days + 1
        weeks, remainder = divmod(days, 7)
        start_weekday = d1.weekday()
        partial = sum(1 for i in range(remainder) if (start_weekday + i) % 7 < 5)
        return weeks * 5 + partial

    def _count_business_days(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Count the number of business days (Monday-Friday) between two dates.
        example:
        "How many business days (Monday-Friday) are there between March 1, 2024 and March 15, 2024?
        Answer with a number."
        """
        year = self.config.year
        year_start = date(year, 1, 1)
        year_end = date(year, 12, 31)
        start_date = self._random_date_between(rng, year_start, year_end)
        max_delta = timedelta(days=self.config.offset_upper_bound)
        # The end date is not clamped to the year, so it may fall in the next one.
        end_date = self._random_date_between(rng, start_date, start_date + max_delta)

        count = self._business_days_between(start_date, end_date)

        # Label each date with its own year; printing the configured year for an
        # end date in the following year would describe a different day.
        start_label = f"{start_date.strftime('%A, %B')} {start_date.day}"
        end_label = f"{end_date.strftime('%A, %B')} {end_date.day}"
        start_full = f"{start_label}, {start_date.year}"
        end_full = f"{end_label}, {end_date.year}"
        # The second template omits the start year unless the range crosses years.
        start_short = start_label if start_date.year == end_date.year else start_full

        templates = [
            f"How many business days (Monday-Friday) are there from "
            f"{start_full} to "
            f"{end_full} "
            f"(inclusive of both dates)? ",
            f"Count the weekdays (excluding weekends) from "
            f"{start_short} to "
            f"{end_full} "
            f"(including both start and end dates). ",
            f"Between {start_full} and "
            f"{end_full} "
            f"(counting both dates), what's the total count of business days "
            f"(Monday through Friday)? ",
        ]

        question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES)
        metadata = {
            "task": CalendarTask.COUNT_BUSINESS_DAYS.value,
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
        }
        return question, str(count), metadata

    def _is_leap_year(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Given a year, determine whether it is a leap year.
        example:
        "Is 2024 a leap year? Answer with Yes or No."
        """
        semirange = self.config.leap_year_range // 2
        year = rng.randint(self.config.year - semirange, self.config.year + semirange)
        is_leap = calendar.isleap(year)
        answer = "Yes" if is_leap else "No"
        templates = [
            f"Determine if the year {year} is a leap year. ",
            f"Is {year} a leap year? ",
            f"Tell me whether {year} is a leap year. ",
        ]
        question = rng.choice(templates) + "Answer with Yes or No."
        metadata = {
            "task": CalendarTask.IS_LEAP_YEAR.value,
            "year": year,
            "is_leap": is_leap,
        }
        return question, answer, metadata

    def _random_date_for_year(self, rng: random.Random, year: int) -> date:
        """Return a random date within the given year (month first, then a valid day)."""
        month = rng.randint(1, 12)
        _, last_day = calendar.monthrange(year, month)
        day = rng.randint(1, last_day)
        return date(year, month, day)

    def _random_date_between(self, rng: random.Random, start_date: date, end_date: date) -> date:
        """
        Return a random date between start_date and end_date (inclusive).

        Raises:
            ValueError: If start_date is after end_date.
        """
        if start_date > end_date:
            raise ValueError("start_date must be <= end_date")
        delta = (end_date - start_date).days
        random_days = rng.randint(0, delta)
        return start_date + timedelta(days=random_days)

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Score ``answer`` against ``entry["answer"]`` for the entry's task.

        Weekday tasks: 1.0 for an exact match, 0.1 for a different correctly
        capitalised weekday name, 0.05 for a weekday name in another letter
        case, otherwise 0.0.
        Numeric tasks: ``exp(-5 * relative_error)``, so 1.0 for an exact match
        and smoothly less for near misses; 0.0 if the answer is not an integer.
        Leap-year task: 1.0 for a case-insensitive match, otherwise 0.0.

        Returns:
            A score in the range [0.0, 1.0].
        """
        if answer is None:
            return 0.0

        oracle_answer = entry["answer"]
        task = entry["metadata"]["task"]

        weekday_tasks = {
            CalendarTask.WEEKDAY_OFFSET.value,
            CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
            CalendarTask.WEEKDAY_OF_DATE.value,
        }
        numeric_tasks = {
            CalendarTask.COUNT_BUSINESS_DAYS.value,
            CalendarTask.COUNT_DAYS.value,
            CalendarTask.RECURRING_EVENT_CALCULATIONS.value,
        }

        if task in weekday_tasks:
            if not answer:
                return 0.0

            answer = answer.strip()
            weekdays = {d.name.title() for d in Weekday}

            if answer == oracle_answer:
                return 1.0
            if answer in weekdays:
                return 0.1
            if answer.title() in weekdays:
                return 0.05
            return 0.0

        # denser reward for numerical tasks
        if task in numeric_tasks:
            try:
                ans_num = int(answer.strip())
                oracle_num = int(oracle_answer.strip())
            except (ValueError, AttributeError):
                return 0.0

            if oracle_num == 0:
                return 1.0 if ans_num == 0 else 0.0

            # Divide by the magnitude: the oracle can be the -1 sentinel, and a
            # negative divisor would turn the penalty into a score above 1.
            relative_error = abs(ans_num - oracle_num) / abs(oracle_num)
            return math.exp(-5 * relative_error)

        if task == CalendarTask.IS_LEAP_YEAR.value:
            if answer.strip().lower() == oracle_answer.lower():
                return 1.0
            return 0.0

        return 0.0
|
||||
|
||||
|
||||
# Register the dataset and its config class with the factory under the name "calendar_arithmetic".
register_dataset("calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig)
|
||||
323
reasoning_gym/arithmetic/time_intervals.py
Normal file
323
reasoning_gym/arithmetic/time_intervals.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
import pytz
|
||||
from dateutil import parser
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class TimeIntervalsConfig:
    """Configuration for time interval calculation tasks

    Attributes:
        min_time: Earliest time of day (declared here; not read by the code in this file).
        max_time: Latest time of day (declared here; not read by the code in this file).
        max_time_difference_seconds: Upper bound, in seconds, on the generated
            time-of-day difference.
        min_date: Earliest start date for date and datetime tasks.
        max_date: Upper end of the date window that start dates are drawn from.
        max_date_difference_days: Upper bound, in days, on the generated date difference.
        task_types: Task types to sample from; one is chosen at random per item.
        seed: Base seed for item generation, or ``None``.
        size: Number of items in the dataset.
    """

    min_time: time = time.min
    max_time: time = time.max
    max_time_difference_seconds: int = 24 * 60 * 60
    min_date: date = date(1900, 1, 1)
    max_date: date = date(3000, 1, 1)
    max_date_difference_days: int = 100
    task_types: List[str] = field(
        default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
    )
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Validate configuration parameters

        Raises:
            AssertionError: If a parameter is out of range.
        """
        assert self.size > 0, "size must be positive"
        assert self.max_time_difference_seconds > 0, "max_time_difference_seconds must be positive"
        assert self.max_date_difference_days > 0, "max_date_difference_days must be positive"
        assert self.min_date < self.max_date, "min_date must be before max_date"
        # A task type is picked with rng.choice for every item, which fails on an empty list.
        assert self.task_types, "task_types must not be empty"
|
||||
|
||||
|
||||
class TimeIntervalsDataset(ProceduralDataset):
    """Generates time interval calculation tasks with various formats and complexities

    Each item asks for the duration between a start and an end value that are
    rendered as text. The expected answer is computed from that rendered text,
    so it always matches what the question shows.
    """

    TEMPLATES = [
        "What is the duration between {start} and {end}? Please answer in {format}.",
        "Calculate the time difference between {start} and {end}. Express the result in {format}.",
        "How much time elapsed from {start} to {end}? Give your answer in {format}.",
        "A meeting started at {start} and ended at {end}. How long was the meeting? Answer in {format}.",
        "A system operation started at {start} and completed at {end}. What was the operation duration? Answer in {format}.",
        "A database query started at {start} and ended at {end}. How long did the query take? Answer in {format}.",
        "A flight departed at {start} and arrived at {end}. How long was the flight? Answer in {format}.",
        "A video call started at {start} and ended at {end}. How long was the call? Answer in {format}.",
        "A system backup started at {start} and completed at {end}. What was the total backup duration? Answer in {format}.",
        "A conference call began at {start} and ended at {end}. How long was the conference? Answer in {format}.",
    ]

    TIME_FORMATS = [
        "%H:%M",
        "%H:%M:%S",
        "%H:%M:%S.%f",
    ]

    DATE_FORMATS = [
        "%Y-%m-%d",
        "%B %d, %Y",
        "%m/%d/%Y",
        "%A, %B %d, %Y",  # e.g. Monday, January 15, 2024
        "%a %b %d %Y",  # e.g. Mon Jan 15 2024
        "%d %B %Y",  # e.g. 15 January 2024
        "%Y-%m-%d (%A)",  # e.g. 2024-01-15 (Monday)
    ]

    DATETIME_FORMATS = [
        "%Y-%m-%d %H:%M",
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M %z",  # For UTC offset format
        "%Y-%m-%d %H:%M:%S %z",  # For UTC offset with seconds
        "%A, %B %d, %Y at %H:%M",  # e.g. Monday, January 15, 2024 at 14:30
        "%a %b %d %Y %H:%M:%S",  # e.g. Mon Jan 15 2024 14:30:45
        "%d %B %Y, %H:%M",  # e.g. 15 January 2024, 14:30
        "%d %B %Y, %H:%M %z",  # e.g. 15 January 2024, 14:30 +0000
        "%Y-%m-%d (%A) %H:%M:%S %z",  # e.g. 2024-01-15 (Monday) 14:30:45 +0000
    ]

    def __init__(self, config: TimeIntervalsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single time interval calculation task

        Returns:
            A dict with ``question``, ``answer`` and ``metadata``. The metadata
            holds the parsed start/end datetimes, the strftime format used and
            the answer format requested in the question.
        """
        item_rng = random.Random(self.seed + idx)

        # Randomly choose task type from config
        task_type = item_rng.choice(self.config.task_types)

        start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type)

        template = item_rng.choice(self.TEMPLATES)
        question = template.format(start=start_time, end=end_time, format=expected_format)

        # Calculate the actual difference from the text shown in the question
        if isinstance(start_time, str):
            # Drop only the "(Weekday)" annotation; any time of day or UTC
            # offset that follows it must survive for the difference to be right.
            start_dt = parser.parse(self._strip_weekday_note(start_time))
            end_dt = parser.parse(self._strip_weekday_note(end_time))
        else:
            start_dt = start_time
            end_dt = end_time

        difference = end_dt - start_dt
        answer = self._format_difference(difference, expected_format)

        return {
            "question": question,
            "answer": answer,
            "metadata": {
                "task_type": task_type,
                "start_time": start_dt,
                "end_time": end_dt,
                "format": format_str,
                "expected_format": expected_format,
            },
        }

    @staticmethod
    def _strip_weekday_note(text: str) -> str:
        """Remove the first " (...)" group from ``text`` and keep everything around it."""
        head, opener, tail = text.partition(" (")
        if not opener:
            return text
        _, _, remainder = tail.partition(")")
        return head + remainder

    @staticmethod
    def _format_difference(difference: timedelta, expected_format: str) -> str:
        """Render ``difference`` in the answer format named by ``expected_format``."""
        if expected_format == "D days":
            return f"{difference.days} days"

        if expected_format in ("HH:MM", "HH:MM:SS", "HH:MM:SS.mmm"):
            # Integer arithmetic on the timedelta fields avoids the rounding
            # error of taking the fractional part of a float second count.
            whole_seconds = difference.days * 86400 + difference.seconds
            hours, within_hour = divmod(whole_seconds, 3600)
            minutes, seconds = divmod(within_hour, 60)
            if expected_format == "HH:MM":
                return f"{hours:02d}:{minutes:02d}"
            if expected_format == "HH:MM:SS":
                return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
            ms = difference.microseconds // 1000
            return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{ms:03d}"

        # "D days, HH:MM" or "D days, HH:MM:SS"
        days = difference.days
        hours = difference.seconds // 3600
        minutes = (difference.seconds % 3600) // 60
        seconds = difference.seconds % 60
        if expected_format == "D days, HH:MM:SS":
            return f"{days} days, {hours:02d}:{minutes:02d}:{seconds:02d}"
        return f"{days} days, {hours:02d}:{minutes:02d}"

    def _generate_times(self, rng: random.Random, task_type: str):
        """Generate start and end times based on task type

        Returns:
            ``(start_text, end_text, format_str, expected_format)`` where the
            texts are the rendered values placed into the question.
        """
        if task_type.startswith("time"):
            if task_type == "time_ms":
                format_str = self.TIME_FORMATS[2]  # Get milliseconds format
                expected_format = "HH:MM:SS.mmm"
            else:
                with_seconds = "seconds" in task_type
                colon_count = 2 if with_seconds else 1
                format_str = next(f for f in self.TIME_FORMATS if f.count(":") == colon_count)
                expected_format = "HH:MM:SS" if with_seconds else "HH:MM"

            # Generate random start time
            start_hour = rng.randint(0, 23)
            start_minute = rng.randint(0, 59)
            start_second = rng.randint(0, 59)
            base = datetime.combine(date.today(), time(start_hour, start_minute, start_second))

            # Calculate seconds remaining until midnight
            seconds_until_midnight = ((24 - start_hour) * 3600) - (start_minute * 60) - start_second
            # Use the minimum of config max and seconds until midnight
            max_seconds = min(self.config.max_time_difference_seconds, seconds_until_midnight)
            diff_seconds = rng.randint(1, max_seconds) if max_seconds > 0 else 0

            # The rendered text carries no date, so an end at or past midnight
            # would read as earlier than the start; clamp it to the same day.
            end_of_day = datetime.combine(base.date(), time(23, 59, 59, 999000))

            if task_type == "time_ms":
                # Add microseconds for millisecond precision
                base = base.replace(microsecond=rng.randint(0, 999) * 1000)
                end_dt = base + timedelta(seconds=diff_seconds, microseconds=rng.randint(0, 999) * 1000)
                end_dt = min(end_dt, end_of_day)
                # Format with exactly 3 decimal places for milliseconds
                start_time = base.strftime(format_str)[:-3]  # Remove extra microsecond digits
                end_time = end_dt.strftime(format_str)[:-3]  # Remove extra microsecond digits
            else:
                end_dt = min(base + timedelta(seconds=diff_seconds), end_of_day)
                start_time = base.strftime(format_str)
                end_time = end_dt.strftime(format_str)

        elif task_type == "date":
            format_str = rng.choice(self.DATE_FORMATS)
            expected_format = "D days"  # Always return number of days for date tasks

            # Generate random start date within configured range, leaving room for end date
            window_days = (self.config.max_date - self.config.min_date).days
            max_date_difference_days = min(self.config.max_date_difference_days, window_days)
            max_start_days = window_days - max_date_difference_days
            # When the window is no wider than the maximum difference the only
            # possible start is min_date; randint(0, -1) would raise otherwise.
            start_days = rng.randint(0, max(0, max_start_days - 1))
            start_date = self.config.min_date + timedelta(days=start_days)

            # Non-negative difference between dates
            diff_days = rng.randint(0, max_date_difference_days)
            end_date = start_date + timedelta(days=diff_days)

            start_time = start_date.strftime(format_str)
            end_time = end_date.strftime(format_str)

        else:  # datetime or datetime_tz
            format_str = rng.choice(self.DATETIME_FORMATS)
            # Choose between HH:MM and HH:MM:SS format for datetime answers
            expected_format = rng.choice(["D days, HH:MM", "D days, HH:MM:SS"])

            # Generate random start datetime
            days_range = (self.config.max_date - self.config.min_date).days
            start_days = rng.randint(0, days_range)
            start_hour = rng.randint(0, 23)
            start_minute = rng.randint(0, 59)
            start_second = rng.randint(0, 59)

            # Generate random time differences first
            diff_days = rng.randint(0, self.config.max_date_difference_days)
            diff_seconds = rng.randint(1, self.config.max_time_difference_seconds)
            elapsed = timedelta(days=diff_days, seconds=diff_seconds)

            base = datetime.combine(
                self.config.min_date + timedelta(days=start_days), time(start_hour, start_minute, start_second)
            )

            if "%z" in format_str:
                # Generate timezone offsets (whole hours)
                start_offset = rng.randint(-12, 12)
                end_offset = rng.randint(-12, 12)

                # Apply start timezone
                base = base.replace(tzinfo=pytz.FixedOffset(start_offset * 60))
                start_format = format_str.replace("%z", "%+05d" % (start_offset * 100))

                # Convert the end instant into the end timezone. Merely swapping
                # tzinfo would shift the instant by the offset difference and
                # could place the end before the start.
                end_dt = (base + elapsed).astimezone(pytz.FixedOffset(end_offset * 60))
                end_format = format_str.replace("%z", "%+05d" % (end_offset * 100))

                # Format times with their respective timezone offsets
                start_time = base.strftime(start_format).rstrip()
                end_time = end_dt.strftime(end_format).rstrip()
            else:
                # For non-timezone aware times, both use same format
                start_time = base.strftime(format_str).rstrip()
                end_time = (base + elapsed).strftime(format_str).rstrip()

        return start_time, end_time, format_str, expected_format

    @staticmethod
    def _parse_clock(text: str) -> float:
        """Convert "HH:MM", "HH:MM:SS" or "HH:MM:SS.fff" into seconds."""
        parts = text.strip().split(":")
        seconds = int(parts[0]) * 3600 + int(parts[1]) * 60
        if len(parts) > 2:
            if "." in parts[2]:  # Has a fractional part
                whole, fraction = parts[2].split(".")
                # Scale by the number of digits so ".5" means half a second.
                seconds += int(whole) + int(fraction) / 10 ** len(fraction)
            else:
                seconds += int(parts[2])
        return seconds

    @staticmethod
    def _parse_day_clock(text: str) -> int:
        """Convert "X days, HH:MM" or "X days, HH:MM:SS" into seconds."""
        days = int(text.split(" days,")[0])
        time_part = text.split(",")[1].strip()
        parts = time_part.split(":")
        seconds = int(parts[0]) * 3600 + int(parts[1]) * 60
        if len(parts) > 2:
            seconds += int(parts[2])
        return days * 86400 + seconds

    def score_answer(self, answer: Optional[str], entry: dict) -> float:
        """Score an answer based on how close it is to the expected duration

        Returns a score between 0 and 1, with partial credit for answers that are
        close to correct in the appropriate units/format. Answers that cannot be
        parsed in the expected format score 0.
        """
        if not answer:
            return 0.0

        expected = entry["answer"]
        task_type = entry["metadata"]["task_type"]

        try:
            if task_type == "date":
                # Parse "X days" format: the number before "days"
                actual_value = int(answer.strip().split()[0])
                expected_value = int(expected.strip().split()[0])
                max_diff = self.config.max_date_difference_days
            elif task_type.startswith("time"):
                # Compare times as total seconds
                actual_value = self._parse_clock(answer)
                expected_value = self._parse_clock(expected)
                max_diff = self.config.max_time_difference_seconds
            else:  # datetime or datetime_tz
                actual_value = self._parse_day_clock(answer)
                expected_value = self._parse_day_clock(expected)
                max_diff = self.config.max_date_difference_days * 86400
        except (ValueError, IndexError):
            return 0.0  # Invalid format

        if actual_value == expected_value:
            return 1.0

        # Partial credit falls off linearly with the distance from the expected value
        diff = abs(actual_value - expected_value)
        return max(0.0, 1.0 - (diff / max_diff))
|
||||
|
||||
|
||||
# Register the dataset and its config class with the factory under the name "time_intervals".
register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig)
|
||||
Loading…
Add table
Add a link
Reference in a new issue