Merge branch 'main' into koko/gsm-symbolic-task-1

This commit is contained in:
Adefioye 2025-02-03 01:23:26 -06:00 committed by GitHub
commit a80339a0e6
49 changed files with 6334 additions and 147 deletions

View file

@ -4,9 +4,11 @@ Arithmetic tasks for training reasoning capabilities:
- Chain sums
- Word problems
- Leg counting
- Time intervals
"""
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSum, ChainSumConfig
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset
@ -14,6 +16,7 @@ from .lcm import LCMConfig, LCMDataset
from .leg_counting import LegCountingConfig, LegCountingDataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
from .gsm_symbolic.gsm_symbolic_datasets import GSMSymbolicDataset, GSMSymbolicDatasetConfig
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
__all__ = [
"BasicArithmeticDataset",
@ -21,6 +24,10 @@ __all__ = [
"basic_arithmetic_dataset",
"ChainSum",
"ChainSumConfig",
"CalendarArithmeticConfig",
"CalendarArithmeticDataset",
"Weekday",
"CalendarTask",
"FractionSimplificationConfig",
"FractionSimplificationDataset",
"GCDConfig",
@ -33,4 +40,6 @@ __all__ = [
"PrimeFactorizationDataset",
"GSMSymbolicDatasetConfig",
"GSMSymbolicDataset",
"TimeIntervalsConfig",
"TimeIntervalsDataset",
]

View file

@ -0,0 +1,490 @@
import calendar
import math
import random
from dataclasses import dataclass
from datetime import date, timedelta
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
class Weekday(Enum):
MONDAY = auto()
TUESDAY = auto()
WEDNESDAY = auto()
THURSDAY = auto()
FRIDAY = auto()
SATURDAY = auto()
SUNDAY = auto()
@classmethod
def from_date(cls, d: date) -> "Weekday":
return list(cls)[d.weekday()]
@classmethod
def random(cls, rng: random.Random) -> "Weekday":
return list(cls)[rng.randint(0, 6)]
@classmethod
def __getitem__(cls, idx) -> "Weekday":
return list(cls)[idx]
@property
def index(self) -> int:
return self.value - 1
def __str__(self) -> str:
return self.name.capitalize()
class CalendarTask(Enum):
WEEKDAY_OFFSET = "weekday_offset"
WEEKDAY_OF_DATE = "weekday_of_date"
WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_day"
RECURRING_EVENT_CALCULATIONS = "recurring_event_day"
COUNT_DAYS = "count_days"
COUNT_BUSINESS_DAYS = "count_business_days"
IS_LEAP_YEAR = "is_leap_year"
@dataclass
class CalendarArithmeticConfig:
year: int = 2022
tasks: Optional[List[str]] = None
offset_upper_bound: int = 100
leap_year_range: int = 200
seed: Optional[int] = 42
size: int = 500
def __post_init__(self):
if self.tasks is None:
self.tasks = [task.value for task in CalendarTask]
else:
self.tasks = [task.lower() for task in self.tasks]
valid_tasks = {task.value for task in CalendarTask}
invalid_tasks = set(self.tasks) - valid_tasks
if invalid_tasks:
valid_task_list = ", ".join(sorted(valid_tasks))
raise ValueError(
f"Invalid tasks: {', '.join(sorted(invalid_tasks))}. " f"Valid tasks are: {valid_task_list}"
)
def validate(self) -> None:
"""Validate the configuration parameters."""
if not isinstance(self.year, int) or self.year <= 0:
raise ValueError(f"year must be a positive integer, got {self.year}")
if self.seed is not None and not isinstance(self.seed, int):
raise ValueError(f"seed must be an integer or None, got {type(self.seed)}")
if not isinstance(self.size, int) or self.size <= 0:
raise ValueError(f"size must be a positive integer, got {self.size}")
class CalendarArithmeticDataset(ProceduralDataset):
DAY_QUESTION_TEMPLATES = [
"Answer with the weekday's name (e.g., Monday, Tuesday, etc.).",
"Provide the full name of the weekday.",
"State the weekday (Monday through Sunday).",
"Give the weekday name in full.",
"Reply with just the weekday name.",
"Write out the full weekday name.",
"Respond with the weekday (Monday-Sunday).",
"Answer using the complete weekday name.",
"Name the day of the week in full.",
]
COUNT_QUESTION_TEMPLATES = [
"Answer with a number.",
"Provide the count as a number.",
"Respond with just the number.",
"Write the total number.",
"Give the count numerically.",
"State the amount as a number.",
"Reply with the numerical value.",
"Express your answer as a number.",
]
def __init__(self, config: CalendarArithmeticConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.task_handlers = {
CalendarTask.WEEKDAY_OFFSET.value: self._weekday_offset,
CalendarTask.WEEKDAY_OF_DATE.value: self._weekday_of_date,
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_day,
CalendarTask.RECURRING_EVENT_CALCULATIONS.value: self._recurring_event_day,
CalendarTask.COUNT_DAYS.value: self._count_days,
CalendarTask.COUNT_BUSINESS_DAYS.value: self._count_business_days,
CalendarTask.IS_LEAP_YEAR.value: self._is_leap_year,
}
self.tasks = [self.task_handlers[task] for task in self.config.tasks]
def __getitem__(self, idx: int) -> dict:
item_rng = random.Random(self.seed + idx)
task = item_rng.choice(self.tasks)
question, answer, metadata = task(item_rng)
return {
"question": question,
"answer": str(answer),
"metadata": metadata,
}
def _weekday_offset(self, rng: random.Random) -> Tuple[str, str, dict]:
"""
Task: Given a starting date and a day offset (which may be positive or negative),
ask what day of the week it will be.
Examples:
- "If today is Wednesday, March 13, 2024, what day of the week will it be in 10 days? Answer with the weekday's name."
- "If today is Wednesday, March 13, 2024, what day of the week was it 10 days ago? Answer with the weekday's name."
"""
year = self.config.year
start_date = self._random_date_for_year(rng, year)
offset = rng.randint(1, self.config.offset_upper_bound)
sign = rng.choice([-1, 1])
offset_days = sign * offset
target_date = start_date + timedelta(days=offset_days)
target_weekday = target_date.strftime("%A")
date_str = f"{start_date.strftime('%A')}, {start_date.strftime('%B')} {start_date.day}, {start_date.year}"
if offset_days >= 0:
templates = [
f"If today is {date_str}, what day of the week will it be in {offset_days} days? ",
f"Starting from {date_str}, which weekday falls after a {offset_days}-day jump? ",
f"Count forward {offset_days} days from {date_str} - what's the weekday? ",
]
else:
templates = [
f"If today is {date_str}, what day of the week was it {abs(offset_days)} days ago? ",
f"Starting from {date_str}, which weekday was it {abs(offset_days)} days before? ",
f"Count backward {abs(offset_days)} days from {date_str} - what's the weekday? ",
]
question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES)
metadata = {
"task": CalendarTask.WEEKDAY_OFFSET.value,
"start_date": start_date.isoformat(),
"offset_days": offset_days,
"target_date": target_date.isoformat(),
}
return question, target_weekday, metadata
def _weekday_of_date(self, rng: random.Random) -> Tuple[str, str, dict]:
"""
task: Ask what day of the week a given date was.
example:
"What day of the week was January 15, 2024?
Answer with the weekday's name."
"""
year = self.config.year
target_date = self._random_date_for_year(rng, year)
answer_weekday = target_date.strftime("%A")
templates = [
f"What day of the week was {target_date.strftime('%B')} {target_date.day}, {year}?",
f"On which weekday did {target_date.strftime('%B')} {target_date.day}, {year} fall?",
f"Name the day of the week for {target_date.strftime('%m/%d/%Y')}.",
]
question = f"{rng.choice(templates)} {rng.choice(self.DAY_QUESTION_TEMPLATES)}"
metadata = {
"task": CalendarTask.WEEKDAY_OF_DATE.value,
"target_date": target_date.isoformat(),
}
return question, answer_weekday, metadata
def _weekday_of_date_from_first_day(self, rng: random.Random) -> Tuple[str, str, dict]:
"""
task: Given an hypothetical weekday for January 1, ask what weekday a later date in the year falls on.
example:
"If the first day of the year was a Monday, what day of the week will December 31 be?
Answer with the weekday's name."
"""
year = self.config.year
first_day = Weekday.random(rng)
first_day_index = first_day.index
# Ensure target date is not January 1.
year_start = date(year, 1, 1)
year_end = date(year, 12, 31)
max_delta = timedelta(days=self.config.offset_upper_bound)
max_date = min(year_start + max_delta, year_end)
while True:
target_date = self._random_date_between(rng, year_start, max_date)
if target_date != date(year, 1, 1):
break
delta_days = (target_date - date(year, 1, 1)).days
answer_index = (first_day_index + delta_days) % 7
answer_weekday = Weekday(answer_index + 1)
templates = [
f"If the first day of the year was a {first_day}, what day of the week will "
f"{target_date.strftime('%B')} {target_date.day} be? ",
f"Given that January 1 fell on a {first_day}, which weekday occurs on "
f"{target_date.strftime('%B')} {target_date.day}? ",
f"In a year where {first_day} is January 1st, name the weekday of "
f"{target_date.strftime('%B')} {target_date.day}. ",
]
question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES)
metadata = {
"task": CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
"year": year,
"first_day": str(first_day),
"target_date": target_date.isoformat(),
"delta_days": delta_days,
}
return question, answer_weekday, metadata
def _recurring_event_day(self, rng: random.Random) -> Tuple[str, str, dict]:
"""
task: For a recurring event defined by an ordinal weekday pattern in a month,
ask on which day of the month the event occurs.
example:
"If a meeting is scheduled on the second Tuesday of May 2024, on which day does it fall?
Answer with a number."
"""
year = self.config.year
month = rng.randint(1, 12)
ordinals = ["first", "second", "third", "fourth", "last"]
ordinal = rng.choice(ordinals)
weekday = Weekday.random(rng)
month_name = calendar.month_name[month]
_, last_day = calendar.monthrange(year, month)
if ordinal != "last":
ordinal_number = {"first": 1, "second": 2, "third": 3, "fourth": 4}[ordinal]
count = 0
event_day = None
for day in range(1, last_day + 1):
d = date(year, month, day)
if d.strftime("%A") == str(weekday):
count += 1
if count == ordinal_number:
event_day = day
break
if event_day is None:
# This should rarely happen but in some months the ordinal may not exist.
event_day = -1
else:
event_day = None
for day in range(last_day, 0, -1):
d = date(year, month, day)
if d.strftime("%A") == str(weekday):
event_day = day
break
if event_day is None:
event_day = -1
templates = [
f"If a meeting is scheduled on the {ordinal} {weekday} of {month_name} {year}, on which day of the month does it occur? ",
f"In {month_name} {year}, if an event recurs on the {ordinal} {weekday}, what is the date (day of the month) of the event? ",
f"Determine the day of the month for the {ordinal} {weekday} in {month_name} {year}. ",
]
question = (
rng.choice(templates)
+ rng.choice(self.COUNT_QUESTION_TEMPLATES)
+ " Answer with -1 if the ordinal does not exist in the month."
)
metadata = {
"task": CalendarTask.RECURRING_EVENT_CALCULATIONS.value,
"year": year,
"month": month,
"ordinal": ordinal,
"weekday": str(weekday),
}
return question, str(event_day), metadata
def _count_days(self, rng: random.Random) -> Tuple[str, str, dict]:
"""
task: Ask how many times a given weekday occurs in a specified range.
example:
"How many days are there between March 1, 2024 and March 15, 2024?
Answer with a number."
"""
year = self.config.year
year_start = date(year, 1, 1)
year_end = date(year, 12, 31)
start_date = self._random_date_between(rng, year_start, year_end)
max_delta = timedelta(days=self.config.offset_upper_bound)
end_date = self._random_date_between(rng, start_date, min(year_end, start_date + max_delta))
weekday = Weekday.random(rng)
def count_weekday_between(d1: date, d2: date, weekday: str) -> int:
days = (d2 - d1).days + 1
return sum(1 for i in range(days) if (d1 + timedelta(days=i)).strftime("%A") == weekday)
count = count_weekday_between(start_date, end_date, str(weekday))
templates = [
f"How many {weekday}s are there from {start_date.strftime('%A, %B')} {start_date.day}, {year} to "
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} (inclusive of both dates)? ",
f"Count the occurrences of {weekday} from {start_date.strftime('%A, %B')} {start_date.day} "
f"to {end_date.strftime('%A, %B')} {end_date.day}, {year} (including both start and end dates). ",
f"Between {start_date.strftime('%A, %B')} {start_date.day}, {year} and "
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
f"(counting both dates), how many times does {weekday} occur? ",
]
question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES)
metadata = {
"task": CalendarTask.COUNT_DAYS.value,
"year": year,
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
}
return question, str(count), metadata
def _count_business_days(self, rng: random.Random) -> Tuple[str, str, dict]:
"""
task: Count the number of business days (Monday-Friday) between two dates.
example:
"How many business days (Monday-Friday) are there between March 1, 2024 and March 15, 2024?
Answer with a number."
"""
year = self.config.year
year_start = date(year, 1, 1)
year_end = date(year, 12, 31)
start_date = self._random_date_between(rng, year_start, year_end)
max_delta = timedelta(days=self.config.offset_upper_bound)
end_date = self._random_date_between(rng, start_date, start_date + max_delta)
count = 0
def business_days_between(d1: date, d2: date) -> int:
days = (d2 - d1).days + 1
weeks, remainder = divmod(days, 7)
count = weeks * 5
start_weekday = d1.weekday()
for i in range(remainder):
if (start_weekday + i) % 7 < 5:
count += 1
return count
count = business_days_between(start_date, end_date)
templates = [
f"How many business days (Monday-Friday) are there from "
f"{start_date.strftime('%A, %B')} {start_date.day}, {year} to "
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
f"(inclusive of both dates)? ",
f"Count the weekdays (excluding weekends) from "
f"{start_date.strftime('%A, %B')} {start_date.day} to "
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
f"(including both start and end dates). ",
f"Between {start_date.strftime('%A, %B')} {start_date.day}, {year} and "
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
f"(counting both dates), what's the total count of business days "
f"(Monday through Friday)? ",
]
question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES)
metadata = {
"task": CalendarTask.COUNT_BUSINESS_DAYS.value,
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
}
return question, str(count), metadata
def _is_leap_year(self, rng: random.Random) -> Tuple[str, str, dict]:
"""
task: Given a year, determine whether it is a leap year.
example:
"Is 2024 a leap year? Answer with Yes or No."
"""
semirange = self.config.leap_year_range // 2
year = rng.randint(self.config.year - semirange, self.config.year + semirange)
is_leap = calendar.isleap(year)
answer = "Yes" if is_leap else "No"
templates = [
f"Determine if the year {year} is a leap year. ",
f"Is {year} a leap year? ",
f"Tell me whether {year} is a leap year. ",
]
question = rng.choice(templates) + "Answer with Yes or No."
metadata = {
"task": CalendarTask.IS_LEAP_YEAR.value,
"year": year,
"is_leap": is_leap,
}
return question, answer, metadata
def _random_date_for_year(self, rng: random.Random, year: int) -> date:
"""Return a random date within the given year."""
month = rng.randint(1, 12)
_, last_day = calendar.monthrange(year, month)
day = rng.randint(1, last_day)
return date(year, month, day)
def _random_date_between(self, rng: random.Random, start_date: date, end_date: date) -> date:
"""
Return a random date between start_date and end_date (inclusive).
Assumes start_date <= end_date.
"""
if start_date > end_date:
raise ValueError("start_date must be <= end_date")
delta = (end_date - start_date).days
random_days = rng.randint(0, delta)
return start_date + timedelta(days=random_days)
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
# we suppose the answer is the last occurence of the expected answer type
if answer is None:
return 0.0
oracle_answer = entry["answer"]
task = entry["metadata"]["task"]
if task in {
CalendarTask.WEEKDAY_OFFSET.value,
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
CalendarTask.WEEKDAY_OF_DATE.value,
}:
if not answer:
return 0.0
answer = answer.strip()
oracle_answer = oracle_answer
weekdays = {d.name.title() for d in Weekday}
if answer == oracle_answer:
return 1.0
if answer in weekdays:
return 0.1
if answer.title() in weekdays:
return 0.05
if answer.title() not in weekdays:
return 0.0
return 0.0
# denser reward for numerical tasks
elif task in {
CalendarTask.COUNT_BUSINESS_DAYS.value,
CalendarTask.COUNT_DAYS.value,
CalendarTask.RECURRING_EVENT_CALCULATIONS.value,
}:
try:
ans_num = int(answer.strip())
oracle_num = int(oracle_answer.strip())
if oracle_num == 0:
return 1.0 if ans_num == 0 else 0.0
relative_error = abs(ans_num - oracle_num) / oracle_num
return max(0.0, math.exp(-5 * relative_error))
except (ValueError, AttributeError):
return 0.0
elif task == CalendarTask.IS_LEAP_YEAR.value:
if answer.strip().lower() == oracle_answer.lower():
return 1.0
return 0.0
return 0.0
register_dataset("calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig)

View file

@ -0,0 +1,323 @@
import random
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from typing import List, Optional
import pytz
from dateutil import parser
from ..factory import ProceduralDataset, register_dataset
@dataclass
class TimeIntervalsConfig:
"""Configuration for time interval calculation tasks"""
min_time: time = time.min
max_time: time = time.max
max_time_difference_seconds: int = 24 * 60 * 60
min_date: date = date(1900, 1, 1)
max_date: date = date(3000, 1, 1)
max_date_difference_days: int = 100
task_types: List[str] = field(
default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
)
seed: Optional[int] = None
size: int = 500
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.size > 0, "size must be positive"
assert self.max_time_difference_seconds > 0, "max_time_difference_seconds must be positive"
assert self.max_date_difference_days > 0, "max_date_difference_days must be positive"
assert self.min_date < self.max_date, "min_date must be before max_date"
class TimeIntervalsDataset(ProceduralDataset):
"""Generates time interval calculation tasks with various formats and complexities"""
TEMPLATES = [
"What is the duration between {start} and {end}? Please answer in {format}.",
"Calculate the time difference between {start} and {end}. Express the result in {format}.",
"How much time elapsed from {start} to {end}? Give your answer in {format}.",
"A meeting started at {start} and ended at {end}. How long was the meeting? Answer in {format}.",
"A system operation started at {start} and completed at {end}. What was the operation duration? Answer in {format}.",
"A database query started at {start} and ended at {end}. How long did the query take? Answer in {format}.",
"A flight departed at {start} and arrived at {end}. How long was the flight? Answer in {format}.",
"A video call started at {start} and ended at {end}. How long was the call? Answer in {format}.",
"A system backup started at {start} and completed at {end}. What was the total backup duration? Answer in {format}.",
"A conference call began at {start} and ended at {end}. How long was the conference? Answer in {format}.",
]
TIME_FORMATS = [
"%H:%M",
"%H:%M:%S",
"%H:%M:%S.%f",
]
DATE_FORMATS = [
"%Y-%m-%d",
"%B %d, %Y",
"%m/%d/%Y",
"%A, %B %d, %Y", # e.g. Monday, January 15, 2024
"%a %b %d %Y", # e.g. Mon Jan 15 2024
"%d %B %Y", # e.g. 15 January 2024
"%Y-%m-%d (%A)", # e.g. 2024-01-15 (Monday)
]
DATETIME_FORMATS = [
"%Y-%m-%d %H:%M",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M %z", # For UTC offset format
"%Y-%m-%d %H:%M:%S %z", # For UTC offset with seconds
"%A, %B %d, %Y at %H:%M", # e.g. Monday, January 15, 2024 at 14:30
"%a %b %d %Y %H:%M:%S", # e.g. Mon Jan 15 2024 14:30:45
"%d %B %Y, %H:%M", # e.g. 15 January 2024, 14:30
"%d %B %Y, %H:%M %z", # e.g. 15 January 2024, 14:30 +0000
"%Y-%m-%d (%A) %H:%M:%S %z", # e.g. 2024-01-15 (Monday) 14:30:45 +0000
]
def __init__(self, config: TimeIntervalsConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single time interval calculation task"""
item_rng = random.Random(self.seed + idx)
# Randomly choose task type from config
task_type = item_rng.choice(self.config.task_types)
start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type)
template = item_rng.choice(self.TEMPLATES)
question = template.format(start=start_time, end=end_time, format=expected_format)
# Calculate the actual difference
if isinstance(start_time, str):
# Handle datetime strings with weekday names in parentheses
start_time = start_time.split(" (")[0] # Remove (Weekday) if present
end_time = end_time.split(" (")[0]
# Parse with UTC offset handling
start_dt = parser.parse(start_time)
end_dt = parser.parse(end_time)
else:
start_dt = start_time
end_dt = end_time
difference = end_dt - start_dt
# Format the answer according to expected_format
if expected_format == "HH:MM":
total_seconds = difference.total_seconds()
answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}"
elif expected_format == "HH:MM:SS":
total_seconds = difference.total_seconds()
answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}:{int(total_seconds % 60):02d}"
elif expected_format == "HH:MM:SS.mmm":
total_seconds = difference.total_seconds()
ms = int((total_seconds % 1) * 1000)
answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}:{int(total_seconds % 60):02d}.{ms:03d}"
elif expected_format == "D days":
answer = f"{difference.days} days"
else: # "D days, HH:MM" or "D days, HH:MM:SS"
days = difference.days
hours = difference.seconds // 3600
minutes = (difference.seconds % 3600) // 60
seconds = difference.seconds % 60
if expected_format == "D days, HH:MM:SS":
answer = f"{days} days, {hours:02d}:{minutes:02d}:{seconds:02d}"
else: # "D days, HH:MM"
answer = f"{days} days, {hours:02d}:{minutes:02d}"
return {
"question": question,
"answer": answer,
"metadata": {
"task_type": task_type,
"start_time": start_dt,
"end_time": end_dt,
"format": format_str,
"expected_format": expected_format,
},
}
def _generate_times(self, rng: random.Random, task_type: str):
"""Generate start and end times based on task type"""
if task_type.startswith("time"):
if task_type == "time_ms":
format_str = self.TIME_FORMATS[2] # Get milliseconds format
expected_format = "HH:MM:SS.mmm"
else:
format_str = next(f for f in self.TIME_FORMATS if f.count(":") == (2 if "seconds" in task_type else 1))
expected_format = "HH:MM:SS" if "seconds" in task_type else "HH:MM"
# Generate random start time
start_hour = rng.randint(0, 23)
start_minute = rng.randint(0, 59)
start_second = rng.randint(0, 59)
base = datetime.combine(date.today(), time(start_hour, start_minute, start_second))
# Calculate seconds remaining until midnight
seconds_until_midnight = ((24 - start_hour) * 3600) - (start_minute * 60) - start_second
# Use the minimum of config max and seconds until midnight
max_seconds = min(self.config.max_time_difference_seconds, seconds_until_midnight)
diff_seconds = rng.randint(1, max_seconds) if max_seconds > 0 else 0
if task_type == "time_ms":
# Add microseconds for millisecond precision
base = base.replace(microsecond=rng.randint(0, 999) * 1000)
end_time = base + timedelta(seconds=diff_seconds, microseconds=rng.randint(0, 999) * 1000)
# Format with exactly 3 decimal places for milliseconds
start_time = base.strftime(format_str)[:-3] # Remove extra microsecond digits
end_time = end_time.strftime(format_str)[:-3] # Remove extra microsecond digits
else:
start_time = base.strftime(format_str)
end_time = (base + timedelta(seconds=diff_seconds)).strftime(format_str)
elif task_type == "date":
format_str = rng.choice(self.DATE_FORMATS)
expected_format = "D days" # Always return number of days for date tasks
# Generate random start date within configured range, leaving room for end date
max_date_difference_days = min(
self.config.max_date_difference_days, (self.config.max_date - self.config.min_date).days
)
max_start_days = (self.config.max_date - self.config.min_date).days - max_date_difference_days
start_days = rng.randint(0, max_start_days - 1)
start_date = self.config.min_date + timedelta(days=start_days)
# Ensure positive difference between dates
diff_days = rng.randint(0, max_date_difference_days)
end_date = start_date + timedelta(days=diff_days)
start_time = start_date.strftime(format_str)
end_time = end_date.strftime(format_str)
else: # datetime or datetime_tz
format_str = rng.choice(self.DATETIME_FORMATS)
# Choose between HH:MM and HH:MM:SS format for datetime answers
expected_format = rng.choice(["D days, HH:MM", "D days, HH:MM:SS"])
# Generate random start datetime
days_range = (self.config.max_date - self.config.min_date).days
start_days = rng.randint(0, days_range)
start_hour = rng.randint(0, 23)
start_minute = rng.randint(0, 59)
start_second = rng.randint(0, 59)
# Generate random time differences first
diff_days = rng.randint(0, self.config.max_date_difference_days)
diff_seconds = rng.randint(1, self.config.max_time_difference_seconds)
if "%z" in format_str:
# Use simpler timezone format with offset
base = datetime.combine(
self.config.min_date + timedelta(days=start_days), time(start_hour, start_minute, start_second)
)
# Generate timezone offsets
start_offset = rng.randint(-12, 12)
end_offset = rng.randint(-12, 12)
# Apply start timezone
base = base.replace(tzinfo=pytz.FixedOffset(start_offset * 60))
start_format = format_str.replace("%z", "%+05d" % (start_offset * 100))
# Calculate end time and convert to end timezone
end_dt = base + timedelta(days=diff_days, seconds=diff_seconds)
end_dt = end_dt.replace(tzinfo=pytz.FixedOffset(end_offset * 60))
end_format = format_str.replace("%z", "%+05d" % (end_offset * 100))
# Format times with their respective timezone offsets
start_time = base.strftime(start_format).rstrip()
end_time = end_dt.strftime(end_format).rstrip()
else:
base = datetime.combine(
self.config.min_date + timedelta(days=start_days), time(start_hour, start_minute, start_second)
)
# For non-timezone aware times, both use same format
start_time = base.strftime(format_str).rstrip()
end_time = (base + timedelta(days=diff_days, seconds=diff_seconds)).strftime(format_str).rstrip()
return start_time, end_time, format_str, expected_format
def score_answer(self, answer: Optional[str], entry: dict) -> float:
"""Score an answer based on how close it is to the expected duration
Returns a score between 0 and 1, with partial credit for answers that are
close to correct in the appropriate units/format
"""
if not answer:
return 0.0
expected = entry["answer"]
task_type = entry["metadata"]["task_type"]
try:
if task_type == "date":
# Parse "X days" format
try:
actual = int(answer.strip().split()[0]) # Get number before "days"
expected = int(expected.strip().split()[0])
if actual == expected:
return 1.0
# Partial credit based on how close the day count is
max_diff = self.config.max_date_difference_days
diff = abs(actual - expected)
return max(0.0, 1.0 - (diff / max_diff))
except (ValueError, IndexError):
return 0.0
elif task_type.startswith("time"):
# Parse times into total seconds for comparison
def parse_time(t):
parts = t.strip().split(":")
seconds = int(parts[0]) * 3600 + int(parts[1]) * 60
if len(parts) > 2:
if "." in parts[2]: # Has milliseconds
s, ms = parts[2].split(".")
seconds += int(s) + int(ms) / 1000
else:
seconds += int(parts[2])
return seconds
actual_seconds = parse_time(answer)
expected_seconds = parse_time(expected)
if actual_seconds == expected_seconds:
return 1.0
# Partial credit based on how close the times are
max_diff = self.config.max_time_difference_seconds
diff = abs(actual_seconds - expected_seconds)
return max(0.0, 1.0 - (diff / max_diff))
else: # datetime or datetime_tz
# Parse the complex format "X days, HH:MM" or "X days, HH:MM:SS"
def parse_datetime(t):
days = int(t.split(" days,")[0])
time_part = t.split(",")[1].strip()
parts = time_part.split(":")
seconds = int(parts[0]) * 3600 + int(parts[1]) * 60
if len(parts) > 2:
seconds += int(parts[2])
return days * 86400 + seconds
actual_seconds = parse_datetime(answer)
expected_seconds = parse_datetime(expected)
if actual_seconds == expected_seconds:
return 1.0
# Partial credit based on total time difference
max_diff = self.config.max_date_difference_days * 86400
diff = abs(actual_seconds - expected_seconds)
return max(0.0, 1.0 - (diff / max_diff))
except (ValueError, IndexError):
return 0.0 # Invalid format
return 0.0
# Register the dataset
register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig)