Revert "Restructure {reasoning_gym, tests}/{core, exercises, curricula}"

This reverts commit 10dbb374b0.
This commit is contained in:
EduardDurech 2025-02-07 11:27:21 +00:00
parent b756f26c09
commit 4c3ae0aebf
109 changed files with 0 additions and 0 deletions

View file

@ -0,0 +1,30 @@
"""
Arithmetic tasks for training reasoning capabilities:
- Basic arithmetic
- Chain sums
- Word problems
- Leg counting
- Time intervals
"""
# from .basic_arithmetic import BasicArithmeticDataset
# from .calendar_arithmetic import CalendarArithmeticDataset
from .chain_sum import ChainSumDataset
# from .fraction_simplification import FractionSimplificationDataset
# from .gcd import GcdDataset
# from .lcm import LcmDataset
# from .leg_counting import LegCountingDataset
# from .prime_factorization import PrimeFactorizationDataset
# from .time_intervals import TimeIntervalsDataset
# Names exported by `from <package> import *`. Only ChainSumDataset is exported;
# the commented-out entries correspond to the commented-out imports above.
__all__ = [
# "BasicArithmeticDataset",
# "CalendarArithmeticDataset",
"ChainSumDataset",
# "FractionSimplificationDataset",
# "GcdDataset",
# "LcmDataset",
# "LegCountingDataset",
# "PrimeFactorizationDataset",
# "TimeIntervalsDataset",
]

View file

@ -0,0 +1,235 @@
from dataclasses import dataclass
from random import Random
from typing import Any, Literal, Optional
from ..factory import ProceduralDataset, register_dataset
@dataclass
class BasicArithmeticDatasetConfig:
    """Settings that control how arithmetic expressions are generated."""

    min_terms: int = 2
    max_terms: int = 6
    min_digits: int = 1
    max_digits: int = 4
    operators: list[str] = ("+", "-", "*", "/")
    allow_parentheses: bool = True
    allow_negation: bool = True
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size
    format_style: Literal["simple", "natural"] = "simple"
    whitespace: Literal["no_space", "single", "random"] = "single"  # Whitespace style between terms

    def validate(self) -> None:
        """Check the settings; an AssertionError names the first violated constraint."""
        supported = ["+", "-", "*", "/"]

        # Term-count bounds
        assert self.min_terms > 0, "min_terms must be positive"
        assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"

        # Digit-count bounds
        assert self.min_digits > 0, "min_digits must be positive"
        assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"

        # Operator set: non-empty and limited to the four supported symbols
        assert len(self.operators) > 0, "must provide at least one operator"
        for symbol in self.operators:
            assert symbol in supported, f"unsupported operator: {symbol}"
def find_common_divisors(a: int, b: int) -> list[int]:
    """Return every positive divisor shared by ``a`` and ``b``.

    The divisors are those of gcd(|a|, |b|), found by trial division. Each small
    divisor is immediately followed by its cofactor, e.g. ``[1, 6, 2, 3]`` for 12
    and 18. The list is empty when both inputs are 0.
    """
    # Euclidean algorithm on the absolute values
    common, rest = abs(a), abs(b)
    while rest:
        common, rest = rest, common % rest

    found = []
    candidate = 1
    # Trial division only needs to run up to sqrt(common)
    while candidate * candidate <= common:
        cofactor, remainder = divmod(common, candidate)
        if remainder == 0:
            found.append(candidate)
            # A perfect-square root is its own cofactor; list it once
            if cofactor != candidate:
                found.append(cofactor)
        candidate += 1
    return found
def eval_floordiv(exp: str) -> int:
    """Evaluate an arithmetic expression string with every "/" treated as floor division."""
    floor_expression = exp.replace("/", "//")
    # NOTE: eval runs arbitrary Python; only pass expressions assembled by this module.
    return eval(floor_expression)
class BasicArithmeticDataset(ProceduralDataset):
    """Dataset that generates basic arithmetic tasks with configurable complexity"""

    def __init__(self, config: BasicArithmeticDatasetConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Generate a single arithmetic task

        Args:
            idx: Index of the item to generate

        Returns:
            dict with keys:
                - question: str, the formatted arithmetic expression
                - answer: str, the ground truth result
                - metadata: dict with generation parameters
        """
        # Create deterministic RNG from base seed and idx
        item_rng = Random(self.seed + idx)

        num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
        num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)

        if self.config.allow_parentheses:
            expression, result = self._generate_complex_task(item_rng, num_terms, num_digits)
        else:
            expression, result = self._generate_simple_task(item_rng, num_terms, num_digits)

        question = self._format_question(item_rng, expression)

        return {
            "question": question,
            "answer": str(result),
            "metadata": {
                "num_terms": num_terms,
                "num_digits": num_digits,
                "expression": expression,
            },
        }

    def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
        """Generate a complex arithmetic task with possible parentheses.

        Args:
            rng: Per-item random generator.
            num_terms: Number of numeric terms requested for the expression.
            num_digits: Upper bound on the number of digits of each random constant.

        Returns:
            (expression, result) where result is the expression evaluated with
            "/" interpreted as floor division.
        """

        def pick_divisor(dividend: int) -> int:
            # A zero dividend has no common divisors (find_common_divisors returns []),
            # and rng.choice([]) raises IndexError; any positive divisor works for 0.
            if dividend != 0:
                return rng.choice(find_common_divisors(dividend, 0))
            return rng.randint(1, 10**num_digits - 1)

        def add_terms(remaining: int) -> list[str]:
            # split terms randomly into left and right
            num_left = rng.randint(1, remaining)
            num_right = remaining - num_left

            left_parts = []
            if num_left > 1 and rng.random() > 0.5 and self.config.allow_parentheses:
                # Wrap the left side in a (possibly negated) parenthesised group
                if rng.random() > 0.5 and self.config.allow_negation:
                    left_parts.append("-(")
                else:
                    left_parts.append("(")
                left_parts.extend(add_terms(num_left))
                left_parts.append(")")
            else:
                # Flat run of constants; division is never used inside the run
                for i in range(num_left):
                    c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1)
                    left_parts.append(str(c))
                    if i + 1 < num_left:
                        left_parts.append(rng.choice([o for o in self.config.operators if o != "/"]))

            if num_right == 0:
                return left_parts

            op = rng.choice(self.config.operators)
            if op != "/":
                left_parts.append(op)
                left_parts.extend(add_terms(num_right))
            else:
                # left part has parentheses or no division: the dividend is either the
                # whole parenthesised group or just the last constant
                dividend = eval_floordiv("".join(left_parts) if left_parts[-1] == ")" else left_parts[-1])
                left_parts.append(op)

                if num_right > 1:
                    right_parts = add_terms(num_right - 1)
                    if right_parts[-1] == ")":
                        # Append a correction term inside the trailing group so the
                        # right-hand value becomes a divisor of the dividend
                        right_value = eval_floordiv("".join(right_parts))
                        if right_value == 0:
                            correction = 1
                        else:
                            target = rng.choice(find_common_divisors(dividend, right_value))
                            correction = target - right_value

                        right_parts.pop()
                        right_parts.append("+")
                        right_parts.append(str(correction))
                        right_parts.append(")")
                    else:
                        # Divide by an explicit divisor, then add the remaining terms
                        left_parts.append(str(pick_divisor(dividend)))
                        left_parts.append("+")

                    left_parts.extend(right_parts)
                else:
                    left_parts.append(str(pick_divisor(dividend)))

            return left_parts

        parts = add_terms(num_terms)

        # Add whitespace according to config
        if self.config.whitespace == "no_space":
            expression = "".join(parts)
        elif self.config.whitespace == "single":
            expression = " ".join(parts)
        else:  # random
            space_parts = []
            for p in parts:
                if rng.random() < 0.15:
                    space_parts.append(" ")
                space_parts.append(p)
            expression = "".join(space_parts).strip()

        result = eval_floordiv(expression)  # Note: eval is safe here as we control the input
        return expression, result

    def _generate_simple_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
        """Generate a simple linear arithmetic task without parentheses.

        The result is accumulated strictly left to right (no operator precedence).
        For "/", the drawn operand is replaced by a divisor of the running result;
        if none exists the step becomes a multiplication. The expression records the
        operator and operand that were actually applied.

        Returns:
            (expression, result)

        Raises:
            RuntimeError: if the configuration contains an unsupported operator.
        """
        constants = [rng.randint(0, 10**num_digits) for _ in range(num_terms)]
        operators = [rng.choice(self.config.operators) for _ in range(num_terms - 1)]

        # Build expression and compute result
        result = constants[0]
        expression_parts = [str(result)]

        for op, c in zip(operators, constants[1:]):
            if op == "+":
                result += c
            elif op == "-":
                result -= c
            elif op == "*":
                result *= c
            elif op == "/":
                # Find a number that divides result evenly
                divisors = [d for d in range(2, min(abs(result), 10**num_digits)) if result % d == 0]
                if divisors:
                    c = rng.choice(divisors)
                    result //= c
                else:
                    # Fallback to multiplication if no clean division possible
                    op = "*"
                    c = rng.randint(1, 10**num_digits - 1)
                    result *= c
            else:
                raise RuntimeError(f"Unsupported operator: {op}")

            # Append only after op/c are final so the text matches the computed result
            expression_parts.append(op)
            expression_parts.append(str(c))

        expression = " ".join(expression_parts)
        return expression, result

    def _format_question(self, rng: Random, expression: str) -> str:
        """Format the expression according to config style"""
        if self.config.format_style == "simple":
            return f"{expression} ="

        templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"]
        return rng.choice(templates).format(expression)
# Register the dataset class and its config class under the name "basic_arithmetic"
# (register_dataset comes from ..factory; its exact behavior is not visible here)
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)

View file

@ -0,0 +1,490 @@
import calendar
import math
import random
from dataclasses import dataclass
from datetime import date, timedelta
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
class Weekday(Enum):
    """Days of the week; ``auto()`` numbers them 1 (Monday) through 7 (Sunday)."""

    MONDAY = auto()
    TUESDAY = auto()
    WEDNESDAY = auto()
    THURSDAY = auto()
    FRIDAY = auto()
    SATURDAY = auto()
    SUNDAY = auto()

    @classmethod
    def from_date(cls, d: date) -> "Weekday":
        """Return the member for the weekday of ``d`` (``date.weekday()`` is 0 for Monday)."""
        return cls(d.weekday() + 1)

    @classmethod
    def random(cls, rng: random.Random) -> "Weekday":
        """Draw one member uniformly using ``rng``."""
        members = list(cls)
        position = rng.randint(0, 6)
        return members[position]

    @classmethod
    def __getitem__(cls, idx) -> "Weekday":
        """Return the member at zero-based position ``idx`` in definition order."""
        members = list(cls)
        return members[idx]

    @property
    def index(self) -> int:
        """Zero-based position of this day (Monday is 0)."""
        return self.value - 1

    def __str__(self) -> str:
        """Capitalised day name, e.g. "Monday"."""
        return self.name.capitalize()
class CalendarTask(Enum):
    """Identifiers of the calendar task types; each value is the task's string name."""

    # Weekday questions
    WEEKDAY_OFFSET = "weekday_offset"
    WEEKDAY_OF_DATE = "weekday_of_date"
    WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_day"

    # Numeric questions
    RECURRING_EVENT_CALCULATIONS = "recurring_event_day"
    COUNT_DAYS = "count_days"
    COUNT_BUSINESS_DAYS = "count_business_days"

    # Yes/No question
    IS_LEAP_YEAR = "is_leap_year"
@dataclass
class CalendarArithmeticConfig:
    """Configuration for calendar arithmetic task generation."""

    year: int = 2022  # Calendar year the date questions are drawn from
    tasks: Optional[List[str]] = None  # CalendarTask values to generate; None selects all of them
    offset_upper_bound: int = 100  # Largest day offset / date span used by the tasks
    leap_year_range: int = 200  # Width of the year window (centred on `year`) for leap-year questions
    seed: Optional[int] = 42
    size: int = 500

    def __post_init__(self):
        """Normalise ``tasks`` to lower case and reject unknown task names.

        Raises:
            ValueError: if any entry of ``tasks`` is not a CalendarTask value.
        """
        if self.tasks is None:
            self.tasks = [task.value for task in CalendarTask]
        else:
            self.tasks = [task.lower() for task in self.tasks]
            valid_tasks = {task.value for task in CalendarTask}
            invalid_tasks = set(self.tasks) - valid_tasks
            if invalid_tasks:
                valid_task_list = ", ".join(sorted(valid_tasks))
                raise ValueError(
                    f"Invalid tasks: {', '.join(sorted(invalid_tasks))}. " f"Valid tasks are: {valid_task_list}"
                )

    def validate(self) -> None:
        """Validate the configuration parameters.

        Raises:
            ValueError: if any parameter has an unusable type or value.
        """
        if not isinstance(self.year, int) or self.year <= 0:
            raise ValueError(f"year must be a positive integer, got {self.year}")

        if not self.tasks:
            # An empty task list would make the per-item task choice fail
            raise ValueError("tasks must contain at least one task")

        if not isinstance(self.offset_upper_bound, int) or self.offset_upper_bound <= 0:
            # Offsets are drawn from 1..offset_upper_bound, and date spans need at least one day
            raise ValueError(f"offset_upper_bound must be a positive integer, got {self.offset_upper_bound}")

        if not isinstance(self.leap_year_range, int) or self.leap_year_range < 0:
            # A negative range would produce an empty interval of candidate years
            raise ValueError(f"leap_year_range must be a non-negative integer, got {self.leap_year_range}")

        if self.seed is not None and not isinstance(self.seed, int):
            raise ValueError(f"seed must be an integer or None, got {type(self.seed)}")

        if not isinstance(self.size, int) or self.size <= 0:
            raise ValueError(f"size must be a positive integer, got {self.size}")
class CalendarArithmeticDataset(ProceduralDataset):
    """Procedural dataset of calendar questions (weekdays, day counts, leap years).

    Every item is produced from its own ``random.Random(self.seed + idx)`` by one of
    the task handlers selected through ``config.tasks``.
    """

    DAY_QUESTION_TEMPLATES = [
        "Answer with the weekday's name (e.g., Monday, Tuesday, etc.).",
        "Provide the full name of the weekday.",
        "State the weekday (Monday through Sunday).",
        "Give the weekday name in full.",
        "Reply with just the weekday name.",
        "Write out the full weekday name.",
        "Respond with the weekday (Monday-Sunday).",
        "Answer using the complete weekday name.",
        "Name the day of the week in full.",
    ]

    COUNT_QUESTION_TEMPLATES = [
        "Answer with a number.",
        "Provide the count as a number.",
        "Respond with just the number.",
        "Write the total number.",
        "Give the count numerically.",
        "State the amount as a number.",
        "Reply with the numerical value.",
        "Express your answer as a number.",
    ]

    def __init__(self, config: CalendarArithmeticConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Map each task name to the method that generates it
        self.task_handlers = {
            CalendarTask.WEEKDAY_OFFSET.value: self._weekday_offset,
            CalendarTask.WEEKDAY_OF_DATE.value: self._weekday_of_date,
            CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_day,
            CalendarTask.RECURRING_EVENT_CALCULATIONS.value: self._recurring_event_day,
            CalendarTask.COUNT_DAYS.value: self._count_days,
            CalendarTask.COUNT_BUSINESS_DAYS.value: self._count_business_days,
            CalendarTask.IS_LEAP_YEAR.value: self._is_leap_year,
        }

        self.tasks = [self.task_handlers[task] for task in self.config.tasks]

    def __getitem__(self, idx: int) -> dict:
        """Generate item ``idx``: a dict with "question", "answer" (str) and "metadata"."""
        item_rng = random.Random(self.seed + idx)
        task = item_rng.choice(self.tasks)
        question, answer, metadata = task(item_rng)
        return {
            "question": question,
            "answer": str(answer),
            "metadata": metadata,
        }

    def _weekday_offset(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        Task: Given a starting date and a day offset (which may be positive or negative),
        ask what day of the week it will be.
        Examples:
        - "If today is Wednesday, March 13, 2024, what day of the week will it be in 10 days? Answer with the weekday's name."
        - "If today is Wednesday, March 13, 2024, what day of the week was it 10 days ago? Answer with the weekday's name."
        """
        year = self.config.year
        start_date = self._random_date_for_year(rng, year)
        offset = rng.randint(1, self.config.offset_upper_bound)
        sign = rng.choice([-1, 1])
        offset_days = sign * offset
        target_date = start_date + timedelta(days=offset_days)
        target_weekday = target_date.strftime("%A")

        date_str = f"{start_date.strftime('%A')}, {start_date.strftime('%B')} {start_date.day}, {start_date.year}"
        if offset_days >= 0:
            templates = [
                f"If today is {date_str}, what day of the week will it be in {offset_days} days? ",
                f"Starting from {date_str}, which weekday falls after a {offset_days}-day jump? ",
                f"Count forward {offset_days} days from {date_str} - what's the weekday? ",
            ]
        else:
            templates = [
                f"If today is {date_str}, what day of the week was it {abs(offset_days)} days ago? ",
                f"Starting from {date_str}, which weekday was it {abs(offset_days)} days before? ",
                f"Count backward {abs(offset_days)} days from {date_str} - what's the weekday? ",
            ]

        question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES)
        metadata = {
            "task": CalendarTask.WEEKDAY_OFFSET.value,
            "start_date": start_date.isoformat(),
            "offset_days": offset_days,
            "target_date": target_date.isoformat(),
        }
        return question, target_weekday, metadata

    def _weekday_of_date(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Ask what day of the week a given date was.
        example:
        "What day of the week was January 15, 2024?
        Answer with the weekday's name."
        """
        year = self.config.year
        target_date = self._random_date_for_year(rng, year)
        answer_weekday = target_date.strftime("%A")
        templates = [
            f"What day of the week was {target_date.strftime('%B')} {target_date.day}, {year}?",
            f"On which weekday did {target_date.strftime('%B')} {target_date.day}, {year} fall?",
            f"Name the day of the week for {target_date.strftime('%m/%d/%Y')}.",
        ]

        question = f"{rng.choice(templates)} {rng.choice(self.DAY_QUESTION_TEMPLATES)}"
        metadata = {
            "task": CalendarTask.WEEKDAY_OF_DATE.value,
            "target_date": target_date.isoformat(),
        }
        return question, answer_weekday, metadata

    def _weekday_of_date_from_first_day(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Given a hypothetical weekday for January 1, ask what weekday a later date in the year falls on.
        example:
        "If the first day of the year was a Monday, what day of the week will December 31 be?
        Answer with the weekday's name."

        Raises:
            ValueError: if ``offset_upper_bound`` leaves no date after January 1 to pick.
        """
        year = self.config.year
        first_day = Weekday.random(rng)
        first_day_index = first_day.index

        year_start = date(year, 1, 1)
        year_end = date(year, 12, 31)
        max_delta = timedelta(days=self.config.offset_upper_bound)
        max_date = min(year_start + max_delta, year_end)
        if max_date <= year_start:
            # Without this guard the rejection loop below could never terminate
            raise ValueError("offset_upper_bound must be at least 1 to pick a date after January 1")

        # Ensure target date is not January 1.
        while True:
            target_date = self._random_date_between(rng, year_start, max_date)
            if target_date != year_start:
                break

        delta_days = (target_date - year_start).days
        answer_index = (first_day_index + delta_days) % 7
        answer_weekday = Weekday(answer_index + 1)

        target_str = f"{target_date.strftime('%B')} {target_date.day}"
        templates = [
            f"If the first day of the year was a {first_day}, what day of the week will {target_str} be? ",
            f"Given that January 1 fell on a {first_day}, which weekday occurs on {target_str}? ",
            f"In a year where {first_day} is January 1st, name the weekday of {target_str}. ",
        ]

        question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES)
        metadata = {
            "task": CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
            "year": year,
            "first_day": str(first_day),
            "target_date": target_date.isoformat(),
            "delta_days": delta_days,
        }
        # Return the name as a string, matching the declared return type
        return question, str(answer_weekday), metadata

    def _recurring_event_day(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: For a recurring event defined by an ordinal weekday pattern in a month,
        ask on which day of the month the event occurs.
        example:
        "If a meeting is scheduled on the second Tuesday of May 2024, on which day does it fall?
        Answer with a number."
        """
        year = self.config.year
        month = rng.randint(1, 12)
        ordinals = ["first", "second", "third", "fourth", "last"]
        ordinal = rng.choice(ordinals)
        weekday = Weekday.random(rng)
        month_name = calendar.month_name[month]

        _, last_day = calendar.monthrange(year, month)

        # Days of the month falling on the requested weekday, in ascending order.
        # Compared by weekday index rather than strftime("%A"), which depends on the locale.
        matching_days = [
            day for day in range(1, last_day + 1) if date(year, month, day).weekday() == weekday.index
        ]

        if ordinal == "last":
            event_day = matching_days[-1] if matching_days else -1
        else:
            ordinal_number = {"first": 1, "second": 2, "third": 3, "fourth": 4}[ordinal]
            # -1 signals that the requested occurrence does not exist in this month
            event_day = matching_days[ordinal_number - 1] if len(matching_days) >= ordinal_number else -1

        templates = [
            f"If a meeting is scheduled on the {ordinal} {weekday} of {month_name} {year}, on which day of the month does it occur? ",
            f"In {month_name} {year}, if an event recurs on the {ordinal} {weekday}, what is the date (day of the month) of the event? ",
            f"Determine the day of the month for the {ordinal} {weekday} in {month_name} {year}. ",
        ]
        question = (
            rng.choice(templates)
            + rng.choice(self.COUNT_QUESTION_TEMPLATES)
            + " Answer with -1 if the ordinal does not exist in the month."
        )
        metadata = {
            "task": CalendarTask.RECURRING_EVENT_CALCULATIONS.value,
            "year": year,
            "month": month,
            "ordinal": ordinal,
            "weekday": str(weekday),
        }
        return question, str(event_day), metadata

    def _count_days(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Ask how many times a given weekday occurs in a date range (both ends inclusive).
        example:
        "How many Mondays are there from Friday, March 1, 2024 to Friday, March 15, 2024
        (inclusive of both dates)? Answer with a number."
        """
        year = self.config.year
        year_start = date(year, 1, 1)
        year_end = date(year, 12, 31)
        start_date = self._random_date_between(rng, year_start, year_end)
        max_delta = timedelta(days=self.config.offset_upper_bound)
        end_date = self._random_date_between(rng, start_date, min(year_end, start_date + max_delta))
        weekday = Weekday.random(rng)

        # Inclusive count, compared by weekday index to stay independent of the locale
        span_days = (end_date - start_date).days + 1
        count = sum(
            1 for i in range(span_days) if (start_date + timedelta(days=i)).weekday() == weekday.index
        )

        start_str = f"{start_date.strftime('%A, %B')} {start_date.day}"
        end_str = f"{end_date.strftime('%A, %B')} {end_date.day}"
        templates = [
            f"How many {weekday}s are there from {start_str}, {year} to "
            f"{end_str}, {year} (inclusive of both dates)? ",
            f"Count the occurrences of {weekday} from {start_str} "
            f"to {end_str}, {year} (including both start and end dates). ",
            f"Between {start_str}, {year} and "
            f"{end_str}, {year} "
            f"(counting both dates), how many times does {weekday} occur? ",
        ]

        question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES)
        metadata = {
            "task": CalendarTask.COUNT_DAYS.value,
            "year": year,
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
        }
        return question, str(count), metadata

    def _count_business_days(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Count the number of business days (Monday-Friday) between two dates.
        example:
        "How many business days (Monday-Friday) are there between March 1, 2024 and March 15, 2024?
        Answer with a number."
        """
        year = self.config.year
        year_start = date(year, 1, 1)
        year_end = date(year, 12, 31)
        start_date = self._random_date_between(rng, year_start, year_end)
        max_delta = timedelta(days=self.config.offset_upper_bound)
        # Clamp to the end of the year: the question text prints `year` for both dates,
        # so an end date in the following year would be rendered with the wrong year.
        end_date = self._random_date_between(rng, start_date, min(year_end, start_date + max_delta))

        # Inclusive span: whole weeks contribute 5 business days each, the remaining
        # days are checked individually starting from the start date's weekday.
        span_days = (end_date - start_date).days + 1
        weeks, remainder = divmod(span_days, 7)
        start_weekday = start_date.weekday()
        count = weeks * 5 + sum(1 for i in range(remainder) if (start_weekday + i) % 7 < 5)

        start_str = f"{start_date.strftime('%A, %B')} {start_date.day}"
        end_str = f"{end_date.strftime('%A, %B')} {end_date.day}"
        templates = [
            f"How many business days (Monday-Friday) are there from "
            f"{start_str}, {year} to "
            f"{end_str}, {year} "
            f"(inclusive of both dates)? ",
            f"Count the weekdays (excluding weekends) from "
            f"{start_str} to "
            f"{end_str}, {year} "
            f"(including both start and end dates). ",
            f"Between {start_str}, {year} and "
            f"{end_str}, {year} "
            f"(counting both dates), what's the total count of business days "
            f"(Monday through Friday)? ",
        ]

        question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES)
        metadata = {
            "task": CalendarTask.COUNT_BUSINESS_DAYS.value,
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
        }
        return question, str(count), metadata

    def _is_leap_year(self, rng: random.Random) -> Tuple[str, str, dict]:
        """
        task: Given a year, determine whether it is a leap year.
        example:
        "Is 2024 a leap year? Answer with Yes or No."
        """
        semirange = self.config.leap_year_range // 2
        year = rng.randint(self.config.year - semirange, self.config.year + semirange)
        is_leap = calendar.isleap(year)
        answer = "Yes" if is_leap else "No"
        templates = [
            f"Determine if the year {year} is a leap year. ",
            f"Is {year} a leap year? ",
            f"Tell me whether {year} is a leap year. ",
        ]
        question = rng.choice(templates) + "Answer with Yes or No."
        metadata = {
            "task": CalendarTask.IS_LEAP_YEAR.value,
            "year": year,
            "is_leap": is_leap,
        }
        return question, answer, metadata

    def _random_date_for_year(self, rng: random.Random, year: int) -> date:
        """Return a random date within the given year (random month, then random day of it)."""
        month = rng.randint(1, 12)
        _, last_day = calendar.monthrange(year, month)
        day = rng.randint(1, last_day)
        return date(year, month, day)

    def _random_date_between(self, rng: random.Random, start_date: date, end_date: date) -> date:
        """
        Return a random date between start_date and end_date (inclusive).

        Raises:
            ValueError: if start_date is after end_date.
        """
        if start_date > end_date:
            raise ValueError("start_date must be <= end_date")
        delta = (end_date - start_date).days
        random_days = rng.randint(0, delta)
        return start_date + timedelta(days=random_days)

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Score ``answer`` against ``entry["answer"]`` according to the entry's task type.

        Args:
            answer: The candidate answer text (surrounding whitespace is ignored), or None.
            entry: A dataset item with "answer" and "metadata" -> "task".

        Returns:
            A score in [0, 1]:
            - weekday tasks: 1.0 for an exact match, 0.1 for a different correctly
              capitalised weekday, 0.05 for a weekday in another letter case, else 0.0;
            - numeric tasks: exp(-5 * relative error), 0.0 if the answer is not an integer;
            - leap-year task: 1.0 for a case-insensitive match, else 0.0.
        """
        if answer is None:
            return 0.0

        oracle_answer = entry["answer"]
        task = entry["metadata"]["task"]

        if task in {
            CalendarTask.WEEKDAY_OFFSET.value,
            CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
            CalendarTask.WEEKDAY_OF_DATE.value,
        }:
            if not answer:
                return 0.0

            answer = answer.strip()
            weekdays = {d.name.title() for d in Weekday}

            if answer == oracle_answer:
                return 1.0
            if answer in weekdays:
                return 0.1
            if answer.title() in weekdays:
                return 0.05
            return 0.0

        # denser reward for numerical tasks
        elif task in {
            CalendarTask.COUNT_BUSINESS_DAYS.value,
            CalendarTask.COUNT_DAYS.value,
            CalendarTask.RECURRING_EVENT_CALCULATIONS.value,
        }:
            try:
                ans_num = int(answer.strip())
                oracle_num = int(oracle_answer.strip())
            except (ValueError, AttributeError):
                return 0.0

            if oracle_num == 0:
                return 1.0 if ans_num == 0 else 0.0

            # Normalise by the magnitude: a negative oracle (e.g. -1) would otherwise
            # give a negative "error" and a score above 1
            relative_error = abs(ans_num - oracle_num) / abs(oracle_num)
            return math.exp(-5 * relative_error)

        elif task == CalendarTask.IS_LEAP_YEAR.value:
            if answer.strip().lower() == oracle_answer.lower():
                return 1.0
            return 0.0

        return 0.0
# Register the dataset class and its config class under the name "calendar_arithmetic"
# (register_dataset comes from ..factory; its exact behavior is not visible here)
register_dataset("calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig)

View file

@ -0,0 +1,103 @@
from dataclasses import dataclass
from typing import Dict, Any
import operator
import numpy as np
from reasoning_gym.core.base_curriculum import BaseCurriculum
@dataclass
class ChainSumDataset:
    """Dataset generator for chain arithmetic problems.

    NOTE(review): the class declares no dataclass fields and defines its own
    ``__init__``, so ``@dataclass`` mainly contributes generated ``__repr__``/``__eq__``;
    confirm whether the decorator is intended.
    """

    def __init__(self):
        # Operator symbol -> (function, precedence); a higher precedence binds tighter
        self.pedmas = {
            '**': (operator.pow, 3),
            '*': (operator.mul, 2),
            '/': (operator.truediv, 2),
            '+': (operator.add, 1),
            '-': (operator.sub, 1)
        }
        self.curriculum = None

    def generate(self, curriculum: "BaseCurriculum") -> Dict[str, Any]:
        """Generate a problem using the curriculum's template system.

        A template is drawn and evaluated up to 10 times; attempts that fail with an
        "Invalid operation" ValueError are retried.

        Args:
            curriculum: Object providing ``get_template(rng)`` and an ``rng`` attribute.

        Returns:
            Whatever ``template.eval(self, curriculum.rng)`` returns.

        Raises:
            ValueError: re-raised unchanged if it is not an "Invalid operation" error,
                or raised when every attempt produced an invalid operation.
        """
        self.curriculum = curriculum
        max_attempts = 10
        last_error = None
        for _ in range(max_attempts):
            try:
                template = curriculum.get_template(curriculum.rng)
                return template.eval(self, curriculum.rng)
            except ValueError as e:
                if "Invalid operation" not in str(e):
                    raise
                last_error = e

        # Fail loudly instead of implicitly returning None
        raise ValueError(
            f"chain_sum.py: failed to generate a valid problem after {max_attempts} attempts"
        ) from last_error

    def _parse_expression(self, executed_parts: Dict[str, str]) -> tuple[list, list]:
        """Extract values and operators from executed parts.

        Terms are read from keys ``term_0``, ``term_1``, ... until one is missing.
        Decimal, ``0b`` (binary) and ``0x`` (hexadecimal) literals with an optional
        sign become floats; a term that cannot be parsed is kept as its string.
        Operators are read from ``op_i`` for each gap between consecutive terms.

        Returns:
            (values, operators)
        """
        values = []
        operators = []

        i = 0
        while f"term_{i}" in executed_parts:
            val = executed_parts[f"term_{i}"].lstrip('+')
            try:
                num = val.lstrip('-')
                if num.startswith(('0b', '0x')):
                    sign = -1 if val.startswith('-') else 1
                    base = 2 if num.startswith('0b') else 16
                    values.append(sign * float(int(num[2:], base)))
                else:
                    values.append(float(val))
            except ValueError:
                # Not a numeric literal: keep the raw text
                values.append(val)
            i += 1

        # Extract operators
        for i in range(len(values) - 1):
            if f"op_{i}" in executed_parts:
                operators.append(executed_parts[f"op_{i}"])

        return values, operators

    def _evaluate_expression(self, values: list, operators: list) -> float:
        """Evaluate expression respecting operator precedence.

        Operators of equal precedence are applied left to right.
        NOTE(review): this includes ``**``, which Python itself evaluates right to
        left; confirm this matches how the templates render the expression.

        Args:
            values: Operands, one more than there are operators.
            operators: Operator symbols that are keys of ``self.pedmas``.

        Returns:
            The value of the expression; the first value (or 0 if there are none)
            when no operators are given.

        Raises:
            ValueError: with an "Invalid operation" message for division by zero,
                an undefined power, or an overflowing result.
        """
        if not operators:
            return values[0] if values else 0

        # Work on copies so the caller's lists are left untouched
        vals, ops = list(values), list(operators)

        def handle_edge(op, a, b):
            # Handle division first
            if op == '/':
                if np.isclose(b, 0):
                    raise ValueError("chain_sum.py: Invalid operation, division by zero")
            # Handle exponentiation edge cases
            if op == '**':
                if np.isclose(a, 0) and b < 0:
                    raise ValueError("chain_sum.py: Invalid operation, zero with negative exponent")
                if a < 0 and not isinstance(b, int) and not b.is_integer():
                    raise ValueError("chain_sum.py: Invalid operation, fractional exponent of negative base")
            # Handle potential overflows
            try:
                result = self.pedmas[op][0](a, b)
                if abs(result) > np.finfo(float).max:
                    raise OverflowError
                return result
            except OverflowError:
                raise ValueError("chain_sum.py: Invalid operation, overflow in calculation")

        # One pass per precedence level present, highest first
        for precedence in sorted({self.pedmas[op][1] for op in ops}, reverse=True):
            i = 0
            while i < len(ops):
                if self.pedmas[ops[i]][1] != precedence:
                    i += 1
                    continue

                # Collapse "a op b" into its result; i stays put so the next
                # operator, now at the same index, is examined on the next pass
                vals[i] = handle_edge(ops[i], vals[i], vals[i + 1])
                del vals[i + 1]
                del ops[i]

        return vals[0]

View file

@ -0,0 +1,123 @@
"""Fraction simplification task generator"""
from dataclasses import dataclass
from math import gcd
from random import Random
from typing import Optional, Sequence, Tuple
from ..factory import ProceduralDataset, register_dataset
@dataclass
class FractionSimplificationConfig:
    """Settings for generating fraction simplification tasks."""

    min_value: int = 1  # Minimum value for numerator/denominator
    max_value: int = 1000  # Maximum value for numerator/denominator
    min_factor: int = 1  # Minimum multiplication factor
    max_factor: int = 100  # Maximum multiplication factor
    styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac")  # Allowed fraction formatting styles
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check the settings; an AssertionError names the first violated constraint."""
        # Value range
        assert self.min_value > 0, "min_value must be positive"
        assert self.max_value > self.min_value, "max_value must be > min_value"

        # Factor range
        assert self.min_factor >= 1, "min_factor must be at least 1"
        assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"

        # Every configured style has to be a known one
        valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
        for style in self.styles:
            assert style in valid_styles, f"Invalid style: {style}. Must be one of {valid_styles}"
class FractionSimplificationDataset(ProceduralDataset):
    """Generates fraction simplification tasks"""

    def __init__(self, config: FractionSimplificationConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]:
        """Generate a random fraction and its simplified form.

        Returns:
            (numerator, denominator, simplified_num, simplified_den), where the
            simplified pair is coprime with simplified_num <= simplified_den, and the
            unsimplified pair is the simplified one multiplied by a random factor.
        """
        low, high = self.config.min_value, self.config.max_value

        # Try to generate valid fractions until we get one that meets our criteria
        for _ in range(10):  # Limit attempts to avoid infinite loop
            # Generate the simplified fraction first
            simplified_num = rng.randint(low, high)
            simplified_den = rng.randint(low, high)

            # Make sure they're coprime by dividing by their GCD
            common = gcd(simplified_num, simplified_den)
            simplified_num //= common
            simplified_den //= common

            # Reducing may push a value below min_value; retry in that case
            if low <= simplified_num <= high and low <= simplified_den <= high:
                break
        else:
            # Every attempt left the bounds after reduction. Two consecutive integers
            # are always coprime, so this pair is both reduced and within bounds.
            if high > low:
                simplified_num = rng.randint(low, high - 1)
                simplified_den = simplified_num + 1
            else:
                # Degenerate range (rejected by validate()): the only fraction is n/n,
                # whose lowest terms are 1/1.
                simplified_num = simplified_den = 1

        # Ensure numerator is smaller than denominator
        if simplified_num > simplified_den:
            simplified_num, simplified_den = simplified_den, simplified_num

        # Multiply both by a random factor to create the unsimplified version
        factor = rng.randint(self.config.min_factor, self.config.max_factor)
        return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)

    def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
        """Format a fraction in one of the supported styles.

        Raises:
            ValueError: if ``style`` is not a known style.
        """
        if style == "plain":
            return f"{num}/{den}"
        elif style == "latex_inline":
            return f"${num}/{den}$"
        elif style == "latex_frac":
            return f"$\\frac{{{num}}}{{{den}}}$"
        elif style == "latex_dfrac":
            return f"$\\dfrac{{{num}}}{{{den}}}$"
        else:
            raise ValueError(f"Unknown fraction style: {style}")

    def __getitem__(self, idx: int) -> dict:
        """Generate a single fraction simplification task"""
        rng = Random(self.seed + idx)

        num, den, simple_num, simple_den = self._generate_fraction(rng)

        # Choose a random style from configured styles
        style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]

        # Format both question and answer in the same style
        question_fraction = self._format_fraction(num, den, style)
        answer_fraction = self._format_fraction(simple_num, simple_den, style)

        return {
            "question": f"Simplify the fraction {question_fraction} to its lowest terms",
            "answer": answer_fraction,
            "metadata": {
                "numerator": num,
                "denominator": den,
                "simplified_numerator": simple_num,
                "simplified_denominator": simple_den,
                "reduction_factor": num // simple_num,  # Will be same as den // simple_den
                "style": style,
            },
        }
# Register the dataset class and its config class under the name "fraction_simplification"
# (register_dataset comes from ..factory; its exact behavior is not visible here)
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)

View file

@ -0,0 +1,66 @@
"""Greatest Common Divisor (GCD) task generator"""
from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random
from typing import List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
@dataclass
class GCDConfig:
    """Settings for generating greatest-common-divisor tasks."""

    min_numbers: int = 2  # Minimum numbers to find GCD of
    max_numbers: int = 2  # Maximum numbers to find GCD of
    min_value: int = 1  # Minimum value for each number
    max_value: int = 1000  # Maximum value for each number
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check the settings; an AssertionError names the first violated constraint."""
        # How many numbers each task contains
        assert self.min_numbers >= 2, "min_numbers must be at least 2"
        assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"

        # Range each number is drawn from
        assert self.min_value >= 1, "min_value must be positive"
        assert self.max_value > self.min_value, "max_value must be > min_value"
class GCDDataset(ProceduralDataset):
    """Procedural dataset of "find the GCD of these numbers" tasks."""

    def __init__(self, config: GCDConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
        """Draw a list of operands and compute their GCD.

        Up to three draws are made and the first one whose GCD exceeds 1 is
        used. If none qualifies, the third draw is returned anyway.
        """
        cfg = self.config
        attempts_left = 3
        while True:
            count = rng.randint(cfg.min_numbers, cfg.max_numbers)
            numbers = [rng.randint(cfg.min_value, cfg.max_value) for _ in range(count)]
            divisor = reduce(gcd, numbers)
            attempts_left -= 1
            # Stop on a non-trivial GCD, or give up once the retries are used.
            if divisor > 1 or attempts_left == 0:
                return numbers, divisor

    def __getitem__(self, idx: int) -> dict:
        """Build the GCD task at position ``idx`` (deterministic per index)."""
        numbers, result = self._generate_numbers(Random(self.seed + idx))
        listing = ", ".join(map(str, numbers))
        question = f"Find the Greatest Common Divisor (GCD) of these numbers: {listing}"
        return {
            "question": question,
            "answer": str(result),
            "metadata": {"numbers": numbers, "result": result},
        }


# Register the dataset class and its config class with the factory as "gcd".
register_dataset("gcd", GCDDataset, GCDConfig)

View file

@ -0,0 +1,69 @@
"""Least Common Multiple (LCM) task generator"""
from dataclasses import dataclass
from functools import reduce
from math import lcm
from random import Random
from typing import List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
@dataclass
class LCMConfig:
    """Settings that control how LCM tasks are generated."""

    min_numbers: int = 2  # Fewest operands in a single task
    max_numbers: int = 2  # Most operands in a single task
    min_value: int = 1  # Smallest operand that may be drawn
    max_value: int = 100  # Largest operand; smaller than the GCD default because LCMs grow fast
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check that the operand-count and value ranges are usable.

        Raises:
            AssertionError: on the first constraint that does not hold.
        """
        assert 2 <= self.min_numbers, "min_numbers must be at least 2"
        assert self.min_numbers <= self.max_numbers, "max_numbers must be >= min_numbers"
        assert 1 <= self.min_value, "min_value must be positive"
        assert self.min_value < self.max_value, "max_value must be > min_value"
class LCMDataset(ProceduralDataset):
    """Procedural dataset of "find the LCM of these numbers" tasks."""

    def __init__(self, config: LCMConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
        """Draw a list of operands and compute their LCM.

        Up to three draws are made and the first one whose LCM is smaller than
        the plain product of the operands is used. If none qualifies, the
        third draw is returned anyway.
        """
        cfg = self.config
        for _attempt in range(3):
            count = rng.randint(cfg.min_numbers, cfg.max_numbers)
            numbers = [rng.randint(cfg.min_value, cfg.max_value) for _ in range(count)]
            multiple = reduce(lcm, numbers)

            product = 1
            for value in numbers:
                product *= value

            if multiple < product:
                return numbers, multiple
        # Retries exhausted: fall back to the last draw.
        return numbers, multiple

    def __getitem__(self, idx: int) -> dict:
        """Build the LCM task at position ``idx`` (deterministic per index)."""
        numbers, result = self._generate_numbers(Random(self.seed + idx))
        listing = ", ".join(map(str, numbers))
        question = f"Find the Least Common Multiple (LCM) of these numbers: {listing}"
        return {
            "question": question,
            "answer": str(result),
            "metadata": {"numbers": numbers, "result": result},
        }


# Register the dataset class and its config class with the factory as "lcm".
register_dataset("lcm", LCMDataset, LCMConfig)

View file

@ -0,0 +1,118 @@
"""Leg counting task generator"""
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
# Number of legs per animal name; the keys are the pool that tasks sample from.
ANIMALS = {
    # Animals with 0 legs
    "snake": 0,
    "sea slug": 0,
    "jellyfish": 0,
    "flatworm": 0,
    "leech": 0,
    # Animals with 2 legs
    "chicken": 2,
    "bird": 2,
    "human": 2,
    "duck": 2,
    # Animals with 4 legs
    "dog": 4,
    "cat": 4,
    "cow": 4,
    "horse": 4,
    "lion": 4,
    "elephant": 4,
    "giraffe": 4,
    "tiger": 4,
    "deer": 4,
    "sheep": 4,
    # Animals with 5 legs
    "starfish": 5,
    # Animals with 6 legs
    "insect": 6,
    "ant": 6,
    "butterfly": 6,
    "beetle": 6,
    "bee": 6,
    "wasp": 6,
    "grasshopper": 6,
    "cricket": 6,
    "cockroach": 6,
    "praying mantis": 6,
    "firefly": 6,
    # Animals with 8 legs
    "spider": 8,
    "scorpion": 8,
    # Animals with 10 legs
    "crab": 10,
    "lobster": 10,
    "shrimp": 10,
    # Animals with 14 legs
    "woodlouse": 14,
}


@dataclass
class LegCountingConfig:
    """Configuration for leg counting task generation"""

    min_animals: int = 2  # Minimum number of distinct animals in a problem
    max_animals: int = 5  # Maximum number of distinct animals in a problem
    max_instances: int = 3  # Maximum instances of each animal
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if a count is not positive, the animal range is
                inverted, or more distinct animals are requested than
                ``ANIMALS`` contains.
        """
        assert self.min_animals > 0, "min_animals must be positive"
        assert self.max_animals >= self.min_animals, "max_animals must be >= min_animals"
        # Distinct animals are sampled without replacement from ANIMALS, so
        # asking for more than it holds would fail only at generation time.
        assert self.max_animals <= len(ANIMALS), f"max_animals must be <= {len(ANIMALS)} (number of known animals)"
        assert self.max_instances > 0, "max_instances must be positive"
class LegCountingDataset(ProceduralDataset):
    """Generates leg counting arithmetic tasks"""

    # Plural forms that the suffix rules in _pluralize would get wrong.
    _IRREGULAR_PLURALS = {
        "sheep": "sheep",
        "deer": "deer",
        "jellyfish": "jellyfish",
        "starfish": "starfish",
        "woodlouse": "woodlice",
    }

    def __init__(self, config: LegCountingConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    @staticmethod
    def _pluralize(animal: str) -> str:
        """Return the English plural of an animal name.

        Appending "s" unconditionally produced question text such as
        "2 sheeps", "3 praying mantiss" or "2 butterflys".
        """
        irregular = LegCountingDataset._IRREGULAR_PLURALS.get(animal)
        if irregular is not None:
            return irregular
        if animal.endswith(("s", "x", "z", "ch", "sh")):
            return animal + "es"
        # consonant + "y" -> "ies" (butterfly -> butterflies)
        if animal.endswith("y") and animal[-2:-1] not in "aeiou":
            return animal[:-1] + "ies"
        return animal + "s"

    def _generate_animals(self, rng: Random) -> Dict[str, int]:
        """Generate a random set of animals and their counts.

        Returns a dict mapping each selected animal name to how many of it
        appear in the task (between 1 and ``max_instances``).
        """
        num_types = rng.randint(self.config.min_animals, self.config.max_animals)
        # Sample distinct animals; raises ValueError if num_types exceeds the pool.
        selected_animals = rng.sample(list(ANIMALS.keys()), num_types)
        return {animal: rng.randint(1, self.config.max_instances) for animal in selected_animals}

    def __getitem__(self, idx: int) -> dict:
        """Generate a single leg counting task (deterministic per index)."""
        rng = Random(self.seed + idx)
        animals = self._generate_animals(rng)
        total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())

        # "1 cat" stays singular; larger counts use the proper plural form.
        animal_list = [
            f"{count} {self._pluralize(animal) if count > 1 else animal}" for animal, count in animals.items()
        ]
        question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
        return {
            "question": question,
            "answer": str(total_legs),
            "metadata": {"animals": animals, "total_legs": total_legs},
        }


# Register the dataset class and its config class with the factory as "leg_counting".
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)

View file

@ -0,0 +1,69 @@
"""Prime factorization task generator"""
from dataclasses import dataclass
from random import Random
from typing import List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
@dataclass
class PrimeFactorizationConfig:
    """Settings that control how prime factorization tasks are generated."""

    min_value: int = 2  # Smallest number that may be chosen for factorization
    max_value: int = 1000  # Largest number that may be chosen for factorization
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Check that the value range is usable.

        Raises:
            AssertionError: on the first constraint that does not hold.
        """
        assert 2 <= self.min_value, "min_value must be >= 2"
        assert self.min_value <= self.max_value, "max_value must be >= min_value"
class PrimeFactorizationDataset(ProceduralDataset):
    """Procedural dataset of "find the prime factorization" tasks."""

    def __init__(self, config: PrimeFactorizationConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _prime_factors(self, n: int) -> List[int]:
        """Return the prime factors of ``n`` in ascending order, with repeats.

        Uses trial division; values of ``n`` below 2 give an empty list.
        """
        found = []
        candidate = 2
        while n > 1:
            # Divide out every power of the current candidate.
            while True:
                quotient, remainder = divmod(n, candidate)
                if remainder:
                    break
                found.append(candidate)
                n = quotient
            candidate += 1
            # No divisor <= sqrt(n) is left, so what remains is 1 or a prime.
            if candidate * candidate > n:
                break
        if n > 1:
            found.append(n)
        return found

    def __getitem__(self, idx: int) -> dict:
        """Build the factorization task at position ``idx`` (deterministic per index)."""
        rng = Random(self.seed + idx)
        number = rng.randint(self.config.min_value, self.config.max_value)
        factors = self._prime_factors(number)
        question = (
            f"Find the prime factorization of {number}. Write the factors separated by × "
            f"(Example: for 12 the answer would be: 2 × 2 × 3)"
        )
        return {
            "question": question,
            "answer": " × ".join(str(factor) for factor in factors),
            "metadata": {"number": number, "factors": factors},
        }


# Register the dataset class and its config class with the factory as "prime_factorization".
register_dataset("prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig)

View file

@ -0,0 +1,323 @@
import random
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from typing import List, Optional
import pytz
from dateutil import parser
from ..factory import ProceduralDataset, register_dataset
@dataclass
class TimeIntervalsConfig:
    """Configuration for time interval calculation tasks"""

    # NOTE(review): min_time / max_time are not read by the generator code
    # visible in this file - confirm whether they are still needed.
    min_time: time = time.min
    max_time: time = time.max
    max_time_difference_seconds: int = 24 * 60 * 60  # Upper bound for the drawn time-of-day difference
    min_date: date = date(1900, 1, 1)  # Earliest start date
    max_date: date = date(3000, 1, 1)  # End of the date range used when drawing start dates
    max_date_difference_days: int = 100  # Upper bound for the drawn day difference
    # Task flavours to choose from; one is picked at random for every item.
    task_types: List[str] = field(
        default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
    )
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if a size or difference bound is not positive, the
                date range is empty or inverted, or ``task_types`` is empty.
        """
        assert self.size > 0, "size must be positive"
        assert self.max_time_difference_seconds > 0, "max_time_difference_seconds must be positive"
        assert self.max_date_difference_days > 0, "max_date_difference_days must be positive"
        assert self.min_date < self.max_date, "min_date must be before max_date"
        # A task type is chosen at random per item; an empty list would only
        # fail later, during generation, with an unrelated IndexError.
        assert len(self.task_types) > 0, "task_types must not be empty"
class TimeIntervalsDataset(ProceduralDataset):
"""Generates time interval calculation tasks with various formats and complexities"""
# Question wordings. One is chosen per item and filled in with the formatted
# start value ({start}), end value ({end}) and the expected answer layout ({format}).
TEMPLATES = [
"What is the duration between {start} and {end}? Please answer in {format}.",
"Calculate the time difference between {start} and {end}. Express the result in {format}.",
"How much time elapsed from {start} to {end}? Give your answer in {format}.",
"A meeting started at {start} and ended at {end}. How long was the meeting? Answer in {format}.",
"A system operation started at {start} and completed at {end}. What was the operation duration? Answer in {format}.",
"A database query started at {start} and ended at {end}. How long did the query take? Answer in {format}.",
"A flight departed at {start} and arrived at {end}. How long was the flight? Answer in {format}.",
"A video call started at {start} and ended at {end}. How long was the call? Answer in {format}.",
"A system backup started at {start} and completed at {end}. What was the total backup duration? Answer in {format}.",
"A conference call began at {start} and ended at {end}. How long was the conference? Answer in {format}.",
]
# strftime patterns for time-of-day tasks. Order matters: _generate_times takes
# index 2 for millisecond tasks and otherwise the first entry with the wanted
# number of ":" separators.
TIME_FORMATS = [
"%H:%M",
"%H:%M:%S",
"%H:%M:%S.%f",
]
# strftime patterns for date-only tasks; one is chosen at random per item.
DATE_FORMATS = [
"%Y-%m-%d",
"%B %d, %Y",
"%m/%d/%Y",
"%A, %B %d, %Y", # e.g. Monday, January 15, 2024
"%a %b %d %Y", # e.g. Mon Jan 15 2024
"%d %B %Y", # e.g. 15 January 2024
"%Y-%m-%d (%A)", # e.g. 2024-01-15 (Monday)
]
# strftime patterns for datetime tasks; one is chosen at random per item.
# For patterns containing %z, _generate_times substitutes a literal UTC offset
# such as +0500 for %z before formatting.
DATETIME_FORMATS = [
"%Y-%m-%d %H:%M",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M %z", # For UTC offset format
"%Y-%m-%d %H:%M:%S %z", # For UTC offset with seconds
"%A, %B %d, %Y at %H:%M", # e.g. Monday, January 15, 2024 at 14:30
"%a %b %d %Y %H:%M:%S", # e.g. Mon Jan 15 2024 14:30:45
"%d %B %Y, %H:%M", # e.g. 15 January 2024, 14:30
"%d %B %Y, %H:%M %z", # e.g. 15 January 2024, 14:30 +0000
"%Y-%m-%d (%A) %H:%M:%S %z", # e.g. 2024-01-15 (Monday) 14:30:45 +0000
]
def __init__(self, config: TimeIntervalsConfig):
# Hand the config plus its seed and size to the ProceduralDataset base class.
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
    """Generate a single time interval calculation task.

    All randomness comes from a ``random.Random`` seeded with
    ``self.seed + idx``, so the same index always yields the same task.

    Returns:
        A dict with the ``question`` text, the expected ``answer`` (formatted
        according to ``expected_format``) and a ``metadata`` dict holding the
        task type, the parsed start/end values and both format descriptors.
    """
    item_rng = random.Random(self.seed + idx)
    task_type = item_rng.choice(self.config.task_types)
    start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type)
    template = item_rng.choice(self.TEMPLATES)
    question = template.format(start=start_time, end=end_time, format=expected_format)

    def _without_weekday(text: str) -> str:
        # Remove only the "(Weekday)" annotation. Cutting the string at " ("
        # would also discard a time and UTC offset that follow it, e.g. in
        # "2024-01-15 (Monday) 14:30:45 +0000".
        head, opener, tail = text.partition(" (")
        if not opener:
            return text
        return head + tail.partition(")")[2]

    # Recover datetime values from the formatted strings so the answer always
    # matches exactly what the question shows.
    if isinstance(start_time, str):
        start_dt = parser.parse(_without_weekday(start_time))
        end_dt = parser.parse(_without_weekday(end_time))
    else:
        start_dt = start_time
        end_dt = end_time
    difference = end_dt - start_dt

    if expected_format in ("HH:MM", "HH:MM:SS", "HH:MM:SS.mmm"):
        # Integer arithmetic on the timedelta fields avoids float truncation
        # (for example 1.009 s being rendered as ".008").
        whole_seconds = difference.days * 86400 + difference.seconds
        hours, remainder = divmod(whole_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        if expected_format == "HH:MM":
            answer = f"{hours:02d}:{minutes:02d}"
        elif expected_format == "HH:MM:SS":
            answer = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
        else:
            millis = difference.microseconds // 1000
            answer = f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"
    elif expected_format == "D days":
        answer = f"{difference.days} days"
    else:  # "D days, HH:MM" or "D days, HH:MM:SS"
        days = difference.days
        hours, remainder = divmod(difference.seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        if expected_format == "D days, HH:MM:SS":
            answer = f"{days} days, {hours:02d}:{minutes:02d}:{seconds:02d}"
        else:
            answer = f"{days} days, {hours:02d}:{minutes:02d}"

    return {
        "question": question,
        "answer": answer,
        "metadata": {
            "task_type": task_type,
            "start_time": start_dt,
            "end_time": end_dt,
            "format": format_str,
            "expected_format": expected_format,
        },
    }
def _generate_times(self, rng: random.Random, task_type: str):
    """Generate the start/end values and format information for one task.

    Args:
        rng: Random source; every random choice for the item is drawn from it.
        task_type: Task flavour. Anything starting with ``"time"`` yields two
            clock times on the same day, ``"date"`` yields two calendar dates
            and every other value is treated as a datetime task.

    Returns:
        ``(start_time, end_time, format_str, expected_format)``: the two
        formatted strings, the ``strftime`` pattern chosen for them and the
        name of the answer layout.
    """
    if task_type.startswith("time"):
        is_ms_task = task_type == "time_ms"
        if is_ms_task:
            format_str = self.TIME_FORMATS[2]  # Millisecond pattern
            expected_format = "HH:MM:SS.mmm"
        else:
            wants_seconds = "seconds" in task_type
            format_str = next(f for f in self.TIME_FORMATS if f.count(":") == (2 if wants_seconds else 1))
            expected_format = "HH:MM:SS" if wants_seconds else "HH:MM"

        start_hour = rng.randint(0, 23)
        start_minute = rng.randint(0, 59)
        start_second = rng.randint(0, 59)
        base = datetime.combine(date.today(), time(start_hour, start_minute, start_second))

        seconds_until_midnight = ((24 - start_hour) * 3600) - (start_minute * 60) - start_second
        max_seconds = min(self.config.max_time_difference_seconds, seconds_until_midnight)
        diff_seconds = rng.randint(1, max_seconds) if max_seconds > 0 else 0

        if is_ms_task:
            # Whole milliseconds only, so the value survives 3-digit formatting.
            base = base.replace(microsecond=rng.randint(0, 999) * 1000)
            end_dt = base + timedelta(seconds=diff_seconds, microseconds=rng.randint(0, 999) * 1000)
            last_instant = datetime.combine(base.date(), time(23, 59, 59, 999000))
        else:
            end_dt = base + timedelta(seconds=diff_seconds)
            last_instant = datetime.combine(base.date(), time(23, 59, 59))

        # Only the clock time is shown, so an end on or after midnight would
        # read as earlier than the start and give a negative duration.
        # Clamping after the draws keeps the random stream unchanged.
        end_dt = min(end_dt, last_instant)

        start_time = base.strftime(format_str)
        end_time = end_dt.strftime(format_str)
        if is_ms_task:
            # %f prints six digits; keep the first three (milliseconds).
            start_time = start_time[:-3]
            end_time = end_time[:-3]

    elif task_type == "date":
        format_str = rng.choice(self.DATE_FORMATS)
        expected_format = "D days"  # Date tasks are always answered in whole days

        span_days = (self.config.max_date - self.config.min_date).days
        max_date_difference_days = min(self.config.max_date_difference_days, span_days)
        # Leave room after the start date for the largest possible difference.
        max_start_days = span_days - max_date_difference_days
        # max(0, ...) keeps randint valid when the configured date range is
        # no wider than the maximum difference (previously randint(0, -1)).
        start_days = rng.randint(0, max(0, max_start_days - 1))
        start_date = self.config.min_date + timedelta(days=start_days)

        diff_days = rng.randint(0, max_date_difference_days)
        end_date = start_date + timedelta(days=diff_days)

        start_time = start_date.strftime(format_str)
        end_time = end_date.strftime(format_str)

    else:  # datetime or datetime_tz
        format_str = rng.choice(self.DATETIME_FORMATS)
        expected_format = rng.choice(["D days, HH:MM", "D days, HH:MM:SS"])

        days_range = (self.config.max_date - self.config.min_date).days
        start_days = rng.randint(0, days_range)
        start_hour = rng.randint(0, 23)
        start_minute = rng.randint(0, 59)
        start_second = rng.randint(0, 59)

        diff_days = rng.randint(0, self.config.max_date_difference_days)
        diff_seconds = rng.randint(1, self.config.max_time_difference_seconds)

        base = datetime.combine(
            self.config.min_date + timedelta(days=start_days), time(start_hour, start_minute, start_second)
        )
        elapsed = timedelta(days=diff_days, seconds=diff_seconds)

        if "%z" in format_str:
            # Whole-hour UTC offsets, written literally into the pattern.
            start_offset = rng.randint(-12, 12)
            end_offset = rng.randint(-12, 12)

            base = base.replace(tzinfo=pytz.FixedOffset(start_offset * 60))
            start_format = format_str.replace("%z", "%+05d" % (start_offset * 100))

            # Convert (not relabel) to the end zone: astimezone keeps the
            # instant, so the true elapsed time stays positive. Swapping
            # tzinfo with replace() shifted the instant by the offset
            # difference and could put the end before the start.
            end_dt = (base + elapsed).astimezone(pytz.FixedOffset(end_offset * 60))
            end_format = format_str.replace("%z", "%+05d" % (end_offset * 100))

            start_time = base.strftime(start_format).rstrip()
            end_time = end_dt.strftime(end_format).rstrip()
        else:
            # Naive datetimes: both values use the same pattern.
            start_time = base.strftime(format_str).rstrip()
            end_time = (base + elapsed).strftime(format_str).rstrip()

    return start_time, end_time, format_str, expected_format
def score_answer(self, answer: Optional[str], entry: dict) -> float:
    """Score an answer by how close it is to the expected duration.

    Args:
        answer: The submitted answer, in the layout the question asked for.
        entry: The dataset item the answer belongs to; its ``"answer"`` and
            ``metadata["task_type"]`` are read.

    Returns:
        ``1.0`` for an exact match, ``0.0`` for a missing or unparseable
        answer, and otherwise partial credit that falls off linearly with the
        error, scaled by the configured maximum difference for the task type.
    """
    if not answer:
        return 0.0

    expected = entry["answer"]
    task_type = entry["metadata"]["task_type"]

    def clock_to_seconds(text):
        # "HH:MM", "HH:MM:SS" or "HH:MM:SS.fff" -> seconds
        parts = text.strip().split(":")
        total = int(parts[0]) * 3600 + int(parts[1]) * 60
        if len(parts) > 2:
            whole, dot, fraction = parts[2].partition(".")
            if dot:
                # Scale by the number of digits given, so ".5" means half a
                # second instead of 5 ms.
                total += int(whole) + int(fraction) / 10 ** len(fraction)
            else:
                total += int(whole)
        return total

    def span_to_seconds(text):
        # "D days, HH:MM" or "D days, HH:MM:SS" -> seconds
        days = int(text.split(" days,")[0])
        clock = text.split(",")[1].strip().split(":")
        total = int(clock[0]) * 3600 + int(clock[1]) * 60
        if len(clock) > 2:
            total += int(clock[2])
        return days * 86400 + total

    try:
        if task_type == "date":
            # "X days": compare the leading day counts.
            actual_value = int(answer.strip().split()[0])
            expected_value = int(expected.strip().split()[0])
            scale = self.config.max_date_difference_days
        elif task_type.startswith("time"):
            actual_value = clock_to_seconds(answer)
            expected_value = clock_to_seconds(expected)
            scale = self.config.max_time_difference_seconds
        else:  # datetime or datetime_tz
            actual_value = span_to_seconds(answer)
            expected_value = span_to_seconds(expected)
            scale = self.config.max_date_difference_days * 86400
    except (ValueError, IndexError):
        return 0.0  # Answer is not in the expected layout

    if actual_value == expected_value:
        return 1.0
    return max(0.0, 1.0 - (abs(actual_value - expected_value) / scale))
# Register the dataset class and its config class with the factory under the
# name "time_intervals".
register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig)