mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
Revert "Restructure {reasoning_gym, tests}/{core, exercises, curricula}"
This reverts commit 10dbb374b0.
This commit is contained in:
parent
b756f26c09
commit
4c3ae0aebf
109 changed files with 0 additions and 0 deletions
30
reasoning_gym/arithmetic/__init__.py
Normal file
30
reasoning_gym/arithmetic/__init__.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
"""
|
||||
Arithmetic tasks for training reasoning capabilities:
|
||||
- Basic arithmetic
|
||||
- Chain sums
|
||||
- Word problems
|
||||
- Leg counting
|
||||
- Time intervals
|
||||
"""
|
||||
|
||||
# from .basic_arithmetic import BasicArithmeticDataset
|
||||
# from .calendar_arithmetic import CalendarArithmeticDataset
|
||||
from .chain_sum import ChainSumDataset
|
||||
# from .fraction_simplification import FractionSimplificationDataset
|
||||
# from .gcd import GcdDataset
|
||||
# from .lcm import LcmDataset
|
||||
# from .leg_counting import LegCountingDataset
|
||||
# from .prime_factorization import PrimeFactorizationDataset
|
||||
# from .time_intervals import TimeIntervalsDataset
|
||||
|
||||
__all__ = [
|
||||
# "BasicArithmeticDataset",
|
||||
# "CalendarArithmeticDataset",
|
||||
"ChainSumDataset",
|
||||
# "FractionSimplificationDataset",
|
||||
# "GcdDataset",
|
||||
# "LcmDataset",
|
||||
# "LegCountingDataset",
|
||||
# "PrimeFactorizationDataset",
|
||||
# "TimeIntervalsDataset",
|
||||
]
|
||||
235
reasoning_gym/arithmetic/basic_arithmetic.py
Normal file
235
reasoning_gym/arithmetic/basic_arithmetic.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicArithmeticDatasetConfig:
|
||||
"""Configuration for arithmetic dataset generation"""
|
||||
|
||||
min_terms: int = 2
|
||||
max_terms: int = 6
|
||||
min_digits: int = 1
|
||||
max_digits: int = 4
|
||||
operators: list[str] = ("+", "-", "*", "/")
|
||||
allow_parentheses: bool = True
|
||||
allow_negation: bool = True
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
format_style: Literal["simple", "natural"] = "simple"
|
||||
whitespace: Literal["no_space", "single", "random"] = "single" # Whitespace style between terms
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_terms > 0, "min_terms must be positive"
|
||||
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
|
||||
assert self.min_digits > 0, "min_digits must be positive"
|
||||
assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"
|
||||
assert len(self.operators) > 0, "must provide at least one operator"
|
||||
for op in self.operators:
|
||||
assert op in ["+", "-", "*", "/"], f"unsupported operator: {op}"
|
||||
|
||||
|
||||
def find_common_divisors(a: int, b: int) -> list[int]:
|
||||
# Helper function to find GCD using Euclidean algorithm
|
||||
def gcd(x, y):
|
||||
while y:
|
||||
x, y = y, x % y
|
||||
return x
|
||||
|
||||
# Get the GCD of the two numbers
|
||||
gcd_value = gcd(abs(a), abs(b))
|
||||
# Find all divisors of the GCD
|
||||
divisors = []
|
||||
i = 1
|
||||
# We only need to check up to sqrt(gcd_value)
|
||||
while i * i <= gcd_value:
|
||||
if gcd_value % i == 0:
|
||||
divisors.append(i)
|
||||
# Don't add the same number twice for perfect squares
|
||||
if i * i != gcd_value:
|
||||
divisors.append(gcd_value // i)
|
||||
i += 1
|
||||
return divisors
|
||||
|
||||
|
||||
def eval_floordiv(exp: str) -> int:
|
||||
return eval(exp.replace("/", "//"))
|
||||
|
||||
|
||||
class BasicArithmeticDataset(ProceduralDataset):
|
||||
"""Dataset that generates basic arithmetic tasks with configurable complexity"""
|
||||
|
||||
def __init__(self, config: BasicArithmeticDatasetConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
"""Generate a single arithmetic task
|
||||
|
||||
Args:
|
||||
idx: Index of the item to generate
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- question: str, the formatted arithmetic expression
|
||||
- answer: str, the ground truth result
|
||||
- metadata: dict with generation parameters
|
||||
"""
|
||||
# Create deterministic RNG from base seed and idx
|
||||
item_rng = Random(self.seed + idx)
|
||||
|
||||
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
|
||||
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
|
||||
|
||||
if self.config.allow_parentheses:
|
||||
expression, result = self._generate_complex_task(item_rng, num_terms, num_digits)
|
||||
else:
|
||||
expression, result = self._generate_simple_task(item_rng, num_terms, num_digits)
|
||||
|
||||
question = self._format_question(item_rng, expression)
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"num_terms": num_terms,
|
||||
"num_digits": num_digits,
|
||||
"expression": expression,
|
||||
},
|
||||
}
|
||||
|
||||
def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
|
||||
"""Generate a complex arithmetic task with possible parentheses"""
|
||||
|
||||
def add_terms(remaining: int) -> list[str]:
|
||||
# split terms randomly into left and right
|
||||
num_left = rng.randint(1, remaining)
|
||||
num_right = remaining - num_left
|
||||
|
||||
left_parts = []
|
||||
if num_left > 1 and rng.random() > 0.5 and self.config.allow_parentheses:
|
||||
if rng.random() > 0.5 and self.config.allow_negation:
|
||||
left_parts.append("-(")
|
||||
else:
|
||||
left_parts.append("(")
|
||||
left_parts.extend(add_terms(num_left))
|
||||
left_parts.append(")")
|
||||
else:
|
||||
for i in range(num_left):
|
||||
c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1)
|
||||
left_parts.append(str(c))
|
||||
if i + 1 < num_left:
|
||||
left_parts.append(rng.choice([o for o in self.config.operators if o != "/"]))
|
||||
|
||||
if num_right == 0:
|
||||
return left_parts
|
||||
|
||||
op = rng.choice(self.config.operators)
|
||||
if op != "/":
|
||||
left_parts.append(op)
|
||||
left_parts.extend(add_terms(num_right))
|
||||
else:
|
||||
# left part has parantheses or no division
|
||||
dividend = eval_floordiv("".join(left_parts) if left_parts[-1] == ")" else left_parts[-1])
|
||||
left_parts.append(op)
|
||||
|
||||
if num_right > 1:
|
||||
right_parts = add_terms(num_right - 1)
|
||||
if right_parts[-1] == ")":
|
||||
right_value = eval_floordiv("".join(right_parts))
|
||||
|
||||
if right_value == 0:
|
||||
correction = 1
|
||||
else:
|
||||
target = rng.choice(find_common_divisors(dividend, right_value))
|
||||
correction = target - right_value
|
||||
|
||||
right_parts.pop()
|
||||
right_parts.append("+")
|
||||
right_parts.append(str(correction))
|
||||
right_parts.append(")")
|
||||
|
||||
else:
|
||||
divisor = rng.choice(find_common_divisors(dividend, 0))
|
||||
left_parts.append(str(divisor))
|
||||
left_parts.append("+")
|
||||
|
||||
left_parts.extend(right_parts)
|
||||
else:
|
||||
if dividend != 0:
|
||||
divisor = rng.choice(find_common_divisors(dividend, 0))
|
||||
else:
|
||||
divisor = rng.randint(1, 10**num_digits - 1)
|
||||
left_parts.append(str(divisor))
|
||||
|
||||
return left_parts
|
||||
|
||||
parts = add_terms(num_terms)
|
||||
|
||||
# Add whitespace according to config
|
||||
if self.config.whitespace == "no_space":
|
||||
expression = "".join(parts)
|
||||
elif self.config.whitespace == "single":
|
||||
expression = " ".join(parts)
|
||||
else: # random
|
||||
space_parts = []
|
||||
for p in parts:
|
||||
if rng.random() < 0.15:
|
||||
space_parts.append(" ")
|
||||
space_parts.append(p)
|
||||
expression = "".join(space_parts).strip()
|
||||
result = eval_floordiv(expression) # Note: eval is safe here as we control the input
|
||||
|
||||
return expression, result
|
||||
|
||||
def _generate_simple_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
|
||||
"""Generate a simple linear arithmetic task without parentheses"""
|
||||
constants = [rng.randint(0, 10**num_digits) for _ in range(num_terms)]
|
||||
operators = [rng.choice(self.config.operators) for _ in range(num_terms - 1)]
|
||||
|
||||
# Build expression and compute result
|
||||
expression_parts = []
|
||||
result = constants[0]
|
||||
|
||||
expression_parts.append(str(constants[0]))
|
||||
for i, op in enumerate(operators):
|
||||
c = constants[i + 1]
|
||||
expression_parts.append(op)
|
||||
expression_parts.append(str(c))
|
||||
|
||||
if op == "+":
|
||||
result += c
|
||||
elif op == "-":
|
||||
result -= c
|
||||
elif op == "*":
|
||||
result *= c
|
||||
elif op == "/":
|
||||
# Find a number that divides result evenly
|
||||
divisors = [d for d in range(2, min(abs(result), 10**num_digits)) if result % d == 0]
|
||||
if divisors:
|
||||
c = rng.choice(divisors)
|
||||
result //= c
|
||||
else:
|
||||
# Fallback to multiplication if no clean division possible
|
||||
op = "*"
|
||||
c = rng.randint(1, 10**num_digits - 1)
|
||||
result *= c
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported operator: {op}")
|
||||
|
||||
expression = " ".join(expression_parts)
|
||||
return expression, result
|
||||
|
||||
def _format_question(self, rng: Random, expression: str) -> str:
|
||||
"""Format the expression according to config style"""
|
||||
if self.config.format_style == "simple":
|
||||
return f"{expression} ="
|
||||
else:
|
||||
templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"]
|
||||
return rng.choice(templates).format(expression)
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)
|
||||
490
reasoning_gym/arithmetic/calendar_arithmetic.py
Normal file
490
reasoning_gym/arithmetic/calendar_arithmetic.py
Normal file
|
|
@ -0,0 +1,490 @@
|
|||
import calendar
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
class Weekday(Enum):
|
||||
MONDAY = auto()
|
||||
TUESDAY = auto()
|
||||
WEDNESDAY = auto()
|
||||
THURSDAY = auto()
|
||||
FRIDAY = auto()
|
||||
SATURDAY = auto()
|
||||
SUNDAY = auto()
|
||||
|
||||
@classmethod
|
||||
def from_date(cls, d: date) -> "Weekday":
|
||||
return list(cls)[d.weekday()]
|
||||
|
||||
@classmethod
|
||||
def random(cls, rng: random.Random) -> "Weekday":
|
||||
return list(cls)[rng.randint(0, 6)]
|
||||
|
||||
@classmethod
|
||||
def __getitem__(cls, idx) -> "Weekday":
|
||||
return list(cls)[idx]
|
||||
|
||||
@property
|
||||
def index(self) -> int:
|
||||
return self.value - 1
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name.capitalize()
|
||||
|
||||
|
||||
class CalendarTask(Enum):
|
||||
WEEKDAY_OFFSET = "weekday_offset"
|
||||
WEEKDAY_OF_DATE = "weekday_of_date"
|
||||
WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_day"
|
||||
RECURRING_EVENT_CALCULATIONS = "recurring_event_day"
|
||||
COUNT_DAYS = "count_days"
|
||||
COUNT_BUSINESS_DAYS = "count_business_days"
|
||||
IS_LEAP_YEAR = "is_leap_year"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CalendarArithmeticConfig:
|
||||
year: int = 2022
|
||||
tasks: Optional[List[str]] = None
|
||||
offset_upper_bound: int = 100
|
||||
leap_year_range: int = 200
|
||||
seed: Optional[int] = 42
|
||||
size: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tasks is None:
|
||||
self.tasks = [task.value for task in CalendarTask]
|
||||
else:
|
||||
self.tasks = [task.lower() for task in self.tasks]
|
||||
valid_tasks = {task.value for task in CalendarTask}
|
||||
invalid_tasks = set(self.tasks) - valid_tasks
|
||||
if invalid_tasks:
|
||||
valid_task_list = ", ".join(sorted(valid_tasks))
|
||||
raise ValueError(
|
||||
f"Invalid tasks: {', '.join(sorted(invalid_tasks))}. " f"Valid tasks are: {valid_task_list}"
|
||||
)
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate the configuration parameters."""
|
||||
if not isinstance(self.year, int) or self.year <= 0:
|
||||
raise ValueError(f"year must be a positive integer, got {self.year}")
|
||||
|
||||
if self.seed is not None and not isinstance(self.seed, int):
|
||||
raise ValueError(f"seed must be an integer or None, got {type(self.seed)}")
|
||||
|
||||
if not isinstance(self.size, int) or self.size <= 0:
|
||||
raise ValueError(f"size must be a positive integer, got {self.size}")
|
||||
|
||||
|
||||
class CalendarArithmeticDataset(ProceduralDataset):
|
||||
DAY_QUESTION_TEMPLATES = [
|
||||
"Answer with the weekday's name (e.g., Monday, Tuesday, etc.).",
|
||||
"Provide the full name of the weekday.",
|
||||
"State the weekday (Monday through Sunday).",
|
||||
"Give the weekday name in full.",
|
||||
"Reply with just the weekday name.",
|
||||
"Write out the full weekday name.",
|
||||
"Respond with the weekday (Monday-Sunday).",
|
||||
"Answer using the complete weekday name.",
|
||||
"Name the day of the week in full.",
|
||||
]
|
||||
|
||||
COUNT_QUESTION_TEMPLATES = [
|
||||
"Answer with a number.",
|
||||
"Provide the count as a number.",
|
||||
"Respond with just the number.",
|
||||
"Write the total number.",
|
||||
"Give the count numerically.",
|
||||
"State the amount as a number.",
|
||||
"Reply with the numerical value.",
|
||||
"Express your answer as a number.",
|
||||
]
|
||||
|
||||
def __init__(self, config: CalendarArithmeticConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
self.task_handlers = {
|
||||
CalendarTask.WEEKDAY_OFFSET.value: self._weekday_offset,
|
||||
CalendarTask.WEEKDAY_OF_DATE.value: self._weekday_of_date,
|
||||
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_day,
|
||||
CalendarTask.RECURRING_EVENT_CALCULATIONS.value: self._recurring_event_day,
|
||||
CalendarTask.COUNT_DAYS.value: self._count_days,
|
||||
CalendarTask.COUNT_BUSINESS_DAYS.value: self._count_business_days,
|
||||
CalendarTask.IS_LEAP_YEAR.value: self._is_leap_year,
|
||||
}
|
||||
|
||||
self.tasks = [self.task_handlers[task] for task in self.config.tasks]
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
item_rng = random.Random(self.seed + idx)
|
||||
task = item_rng.choice(self.tasks)
|
||||
question, answer, metadata = task(item_rng)
|
||||
return {
|
||||
"question": question,
|
||||
"answer": str(answer),
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
def _weekday_offset(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
"""
|
||||
Task: Given a starting date and a day offset (which may be positive or negative),
|
||||
ask what day of the week it will be.
|
||||
Examples:
|
||||
- "If today is Wednesday, March 13, 2024, what day of the week will it be in 10 days? Answer with the weekday's name."
|
||||
- "If today is Wednesday, March 13, 2024, what day of the week was it 10 days ago? Answer with the weekday's name."
|
||||
"""
|
||||
year = self.config.year
|
||||
start_date = self._random_date_for_year(rng, year)
|
||||
offset = rng.randint(1, self.config.offset_upper_bound)
|
||||
sign = rng.choice([-1, 1])
|
||||
offset_days = sign * offset
|
||||
target_date = start_date + timedelta(days=offset_days)
|
||||
target_weekday = target_date.strftime("%A")
|
||||
|
||||
date_str = f"{start_date.strftime('%A')}, {start_date.strftime('%B')} {start_date.day}, {start_date.year}"
|
||||
if offset_days >= 0:
|
||||
templates = [
|
||||
f"If today is {date_str}, what day of the week will it be in {offset_days} days? ",
|
||||
f"Starting from {date_str}, which weekday falls after a {offset_days}-day jump? ",
|
||||
f"Count forward {offset_days} days from {date_str} - what's the weekday? ",
|
||||
]
|
||||
else:
|
||||
templates = [
|
||||
f"If today is {date_str}, what day of the week was it {abs(offset_days)} days ago? ",
|
||||
f"Starting from {date_str}, which weekday was it {abs(offset_days)} days before? ",
|
||||
f"Count backward {abs(offset_days)} days from {date_str} - what's the weekday? ",
|
||||
]
|
||||
|
||||
question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES)
|
||||
metadata = {
|
||||
"task": CalendarTask.WEEKDAY_OFFSET.value,
|
||||
"start_date": start_date.isoformat(),
|
||||
"offset_days": offset_days,
|
||||
"target_date": target_date.isoformat(),
|
||||
}
|
||||
return question, target_weekday, metadata
|
||||
|
||||
def _weekday_of_date(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
"""
|
||||
task: Ask what day of the week a given date was.
|
||||
example:
|
||||
"What day of the week was January 15, 2024?
|
||||
Answer with the weekday's name."
|
||||
"""
|
||||
year = self.config.year
|
||||
target_date = self._random_date_for_year(rng, year)
|
||||
answer_weekday = target_date.strftime("%A")
|
||||
templates = [
|
||||
f"What day of the week was {target_date.strftime('%B')} {target_date.day}, {year}?",
|
||||
f"On which weekday did {target_date.strftime('%B')} {target_date.day}, {year} fall?",
|
||||
f"Name the day of the week for {target_date.strftime('%m/%d/%Y')}.",
|
||||
]
|
||||
|
||||
question = f"{rng.choice(templates)} {rng.choice(self.DAY_QUESTION_TEMPLATES)}"
|
||||
metadata = {
|
||||
"task": CalendarTask.WEEKDAY_OF_DATE.value,
|
||||
"target_date": target_date.isoformat(),
|
||||
}
|
||||
return question, answer_weekday, metadata
|
||||
|
||||
def _weekday_of_date_from_first_day(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
"""
|
||||
task: Given an hypothetical weekday for January 1, ask what weekday a later date in the year falls on.
|
||||
example:
|
||||
"If the first day of the year was a Monday, what day of the week will December 31 be?
|
||||
Answer with the weekday's name."
|
||||
"""
|
||||
year = self.config.year
|
||||
first_day = Weekday.random(rng)
|
||||
first_day_index = first_day.index
|
||||
# Ensure target date is not January 1.
|
||||
year_start = date(year, 1, 1)
|
||||
year_end = date(year, 12, 31)
|
||||
max_delta = timedelta(days=self.config.offset_upper_bound)
|
||||
max_date = min(year_start + max_delta, year_end)
|
||||
while True:
|
||||
target_date = self._random_date_between(rng, year_start, max_date)
|
||||
if target_date != date(year, 1, 1):
|
||||
break
|
||||
delta_days = (target_date - date(year, 1, 1)).days
|
||||
answer_index = (first_day_index + delta_days) % 7
|
||||
answer_weekday = Weekday(answer_index + 1)
|
||||
|
||||
templates = [
|
||||
f"If the first day of the year was a {first_day}, what day of the week will "
|
||||
f"{target_date.strftime('%B')} {target_date.day} be? ",
|
||||
f"Given that January 1 fell on a {first_day}, which weekday occurs on "
|
||||
f"{target_date.strftime('%B')} {target_date.day}? ",
|
||||
f"In a year where {first_day} is January 1st, name the weekday of "
|
||||
f"{target_date.strftime('%B')} {target_date.day}. ",
|
||||
]
|
||||
|
||||
question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES)
|
||||
metadata = {
|
||||
"task": CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
|
||||
"year": year,
|
||||
"first_day": str(first_day),
|
||||
"target_date": target_date.isoformat(),
|
||||
"delta_days": delta_days,
|
||||
}
|
||||
return question, answer_weekday, metadata
|
||||
|
||||
def _recurring_event_day(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
"""
|
||||
task: For a recurring event defined by an ordinal weekday pattern in a month,
|
||||
ask on which day of the month the event occurs.
|
||||
example:
|
||||
"If a meeting is scheduled on the second Tuesday of May 2024, on which day does it fall?
|
||||
Answer with a number."
|
||||
"""
|
||||
year = self.config.year
|
||||
month = rng.randint(1, 12)
|
||||
ordinals = ["first", "second", "third", "fourth", "last"]
|
||||
ordinal = rng.choice(ordinals)
|
||||
weekday = Weekday.random(rng)
|
||||
month_name = calendar.month_name[month]
|
||||
_, last_day = calendar.monthrange(year, month)
|
||||
|
||||
if ordinal != "last":
|
||||
ordinal_number = {"first": 1, "second": 2, "third": 3, "fourth": 4}[ordinal]
|
||||
count = 0
|
||||
event_day = None
|
||||
for day in range(1, last_day + 1):
|
||||
d = date(year, month, day)
|
||||
if d.strftime("%A") == str(weekday):
|
||||
count += 1
|
||||
if count == ordinal_number:
|
||||
event_day = day
|
||||
break
|
||||
if event_day is None:
|
||||
# This should rarely happen but in some months the ordinal may not exist.
|
||||
event_day = -1
|
||||
else:
|
||||
event_day = None
|
||||
for day in range(last_day, 0, -1):
|
||||
d = date(year, month, day)
|
||||
if d.strftime("%A") == str(weekday):
|
||||
event_day = day
|
||||
break
|
||||
if event_day is None:
|
||||
event_day = -1
|
||||
|
||||
templates = [
|
||||
f"If a meeting is scheduled on the {ordinal} {weekday} of {month_name} {year}, on which day of the month does it occur? ",
|
||||
f"In {month_name} {year}, if an event recurs on the {ordinal} {weekday}, what is the date (day of the month) of the event? ",
|
||||
f"Determine the day of the month for the {ordinal} {weekday} in {month_name} {year}. ",
|
||||
]
|
||||
question = (
|
||||
rng.choice(templates)
|
||||
+ rng.choice(self.COUNT_QUESTION_TEMPLATES)
|
||||
+ " Answer with -1 if the ordinal does not exist in the month."
|
||||
)
|
||||
metadata = {
|
||||
"task": CalendarTask.RECURRING_EVENT_CALCULATIONS.value,
|
||||
"year": year,
|
||||
"month": month,
|
||||
"ordinal": ordinal,
|
||||
"weekday": str(weekday),
|
||||
}
|
||||
return question, str(event_day), metadata
|
||||
|
||||
def _count_days(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
"""
|
||||
task: Ask how many times a given weekday occurs in a specified range.
|
||||
example:
|
||||
"How many days are there between March 1, 2024 and March 15, 2024?
|
||||
Answer with a number."
|
||||
"""
|
||||
year = self.config.year
|
||||
year_start = date(year, 1, 1)
|
||||
year_end = date(year, 12, 31)
|
||||
start_date = self._random_date_between(rng, year_start, year_end)
|
||||
max_delta = timedelta(days=self.config.offset_upper_bound)
|
||||
end_date = self._random_date_between(rng, start_date, min(year_end, start_date + max_delta))
|
||||
weekday = Weekday.random(rng)
|
||||
|
||||
def count_weekday_between(d1: date, d2: date, weekday: str) -> int:
|
||||
days = (d2 - d1).days + 1
|
||||
return sum(1 for i in range(days) if (d1 + timedelta(days=i)).strftime("%A") == weekday)
|
||||
|
||||
count = count_weekday_between(start_date, end_date, str(weekday))
|
||||
|
||||
templates = [
|
||||
f"How many {weekday}s are there from {start_date.strftime('%A, %B')} {start_date.day}, {year} to "
|
||||
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} (inclusive of both dates)? ",
|
||||
f"Count the occurrences of {weekday} from {start_date.strftime('%A, %B')} {start_date.day} "
|
||||
f"to {end_date.strftime('%A, %B')} {end_date.day}, {year} (including both start and end dates). ",
|
||||
f"Between {start_date.strftime('%A, %B')} {start_date.day}, {year} and "
|
||||
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
|
||||
f"(counting both dates), how many times does {weekday} occur? ",
|
||||
]
|
||||
|
||||
question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES)
|
||||
metadata = {
|
||||
"task": CalendarTask.COUNT_DAYS.value,
|
||||
"year": year,
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
}
|
||||
return question, str(count), metadata
|
||||
|
||||
def _count_business_days(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
"""
|
||||
task: Count the number of business days (Monday-Friday) between two dates.
|
||||
example:
|
||||
"How many business days (Monday-Friday) are there between March 1, 2024 and March 15, 2024?
|
||||
Answer with a number."
|
||||
"""
|
||||
year = self.config.year
|
||||
year_start = date(year, 1, 1)
|
||||
year_end = date(year, 12, 31)
|
||||
start_date = self._random_date_between(rng, year_start, year_end)
|
||||
max_delta = timedelta(days=self.config.offset_upper_bound)
|
||||
end_date = self._random_date_between(rng, start_date, start_date + max_delta)
|
||||
|
||||
count = 0
|
||||
|
||||
def business_days_between(d1: date, d2: date) -> int:
|
||||
days = (d2 - d1).days + 1
|
||||
weeks, remainder = divmod(days, 7)
|
||||
count = weeks * 5
|
||||
start_weekday = d1.weekday()
|
||||
for i in range(remainder):
|
||||
if (start_weekday + i) % 7 < 5:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
count = business_days_between(start_date, end_date)
|
||||
|
||||
templates = [
|
||||
f"How many business days (Monday-Friday) are there from "
|
||||
f"{start_date.strftime('%A, %B')} {start_date.day}, {year} to "
|
||||
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
|
||||
f"(inclusive of both dates)? ",
|
||||
f"Count the weekdays (excluding weekends) from "
|
||||
f"{start_date.strftime('%A, %B')} {start_date.day} to "
|
||||
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
|
||||
f"(including both start and end dates). ",
|
||||
f"Between {start_date.strftime('%A, %B')} {start_date.day}, {year} and "
|
||||
f"{end_date.strftime('%A, %B')} {end_date.day}, {year} "
|
||||
f"(counting both dates), what's the total count of business days "
|
||||
f"(Monday through Friday)? ",
|
||||
]
|
||||
|
||||
question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES)
|
||||
metadata = {
|
||||
"task": CalendarTask.COUNT_BUSINESS_DAYS.value,
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
}
|
||||
return question, str(count), metadata
|
||||
|
||||
def _is_leap_year(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
"""
|
||||
task: Given a year, determine whether it is a leap year.
|
||||
example:
|
||||
"Is 2024 a leap year? Answer with Yes or No."
|
||||
"""
|
||||
semirange = self.config.leap_year_range // 2
|
||||
year = rng.randint(self.config.year - semirange, self.config.year + semirange)
|
||||
is_leap = calendar.isleap(year)
|
||||
answer = "Yes" if is_leap else "No"
|
||||
templates = [
|
||||
f"Determine if the year {year} is a leap year. ",
|
||||
f"Is {year} a leap year? ",
|
||||
f"Tell me whether {year} is a leap year. ",
|
||||
]
|
||||
question = rng.choice(templates) + "Answer with Yes or No."
|
||||
metadata = {
|
||||
"task": CalendarTask.IS_LEAP_YEAR.value,
|
||||
"year": year,
|
||||
"is_leap": is_leap,
|
||||
}
|
||||
return question, answer, metadata
|
||||
|
||||
def _random_date_for_year(self, rng: random.Random, year: int) -> date:
|
||||
"""Return a random date within the given year."""
|
||||
month = rng.randint(1, 12)
|
||||
_, last_day = calendar.monthrange(year, month)
|
||||
day = rng.randint(1, last_day)
|
||||
return date(year, month, day)
|
||||
|
||||
def _random_date_between(self, rng: random.Random, start_date: date, end_date: date) -> date:
|
||||
"""
|
||||
Return a random date between start_date and end_date (inclusive).
|
||||
Assumes start_date <= end_date.
|
||||
"""
|
||||
if start_date > end_date:
|
||||
raise ValueError("start_date must be <= end_date")
|
||||
delta = (end_date - start_date).days
|
||||
random_days = rng.randint(0, delta)
|
||||
return start_date + timedelta(days=random_days)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
# we suppose the answer is the last occurence of the expected answer type
|
||||
if answer is None:
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"]
|
||||
task = entry["metadata"]["task"]
|
||||
|
||||
if task in {
|
||||
CalendarTask.WEEKDAY_OFFSET.value,
|
||||
CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value,
|
||||
CalendarTask.WEEKDAY_OF_DATE.value,
|
||||
}:
|
||||
if not answer:
|
||||
return 0.0
|
||||
|
||||
answer = answer.strip()
|
||||
oracle_answer = oracle_answer
|
||||
weekdays = {d.name.title() for d in Weekday}
|
||||
|
||||
if answer == oracle_answer:
|
||||
return 1.0
|
||||
|
||||
if answer in weekdays:
|
||||
return 0.1
|
||||
|
||||
if answer.title() in weekdays:
|
||||
return 0.05
|
||||
|
||||
if answer.title() not in weekdays:
|
||||
return 0.0
|
||||
|
||||
return 0.0
|
||||
|
||||
# denser reward for numerical tasks
|
||||
elif task in {
|
||||
CalendarTask.COUNT_BUSINESS_DAYS.value,
|
||||
CalendarTask.COUNT_DAYS.value,
|
||||
CalendarTask.RECURRING_EVENT_CALCULATIONS.value,
|
||||
}:
|
||||
try:
|
||||
ans_num = int(answer.strip())
|
||||
oracle_num = int(oracle_answer.strip())
|
||||
|
||||
if oracle_num == 0:
|
||||
return 1.0 if ans_num == 0 else 0.0
|
||||
|
||||
relative_error = abs(ans_num - oracle_num) / oracle_num
|
||||
return max(0.0, math.exp(-5 * relative_error))
|
||||
|
||||
except (ValueError, AttributeError):
|
||||
return 0.0
|
||||
|
||||
elif task == CalendarTask.IS_LEAP_YEAR.value:
|
||||
if answer.strip().lower() == oracle_answer.lower():
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig)
|
||||
103
reasoning_gym/arithmetic/chain_sum.py
Normal file
103
reasoning_gym/arithmetic/chain_sum.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
import operator
|
||||
import numpy as np
|
||||
from reasoning_gym.core.base_curriculum import BaseCurriculum
|
||||
|
||||
@dataclass
|
||||
class ChainSumDataset:
|
||||
"""Dataset generator for chain arithmetic problems."""
|
||||
def __init__(self):
|
||||
# Define operator mappings
|
||||
self.pedmas = {
|
||||
'**': (operator.pow, 3), # (function, precedence)
|
||||
'*': (operator.mul, 2),
|
||||
'/': (operator.truediv, 2),
|
||||
'+': (operator.add, 1),
|
||||
'-': (operator.sub, 1)
|
||||
}
|
||||
self.curriculum = None
|
||||
|
||||
def generate(self, curriculum: BaseCurriculum) -> Dict[str, Any]:
|
||||
"""Generate a problem using the curriculum's template system"""
|
||||
self.curriculum = curriculum
|
||||
max_attempts = 10
|
||||
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
template = curriculum.get_template(curriculum.rng)
|
||||
return template.eval(self, curriculum.rng)
|
||||
except ValueError as e:
|
||||
if "Invalid operation" in str(e):
|
||||
continue
|
||||
raise
|
||||
|
||||
def _parse_expression(self, executed_parts: Dict[str, str]) -> tuple[list, list]:
|
||||
"""Extract values and operators from executed parts"""
|
||||
values = []
|
||||
operators = []
|
||||
|
||||
i = 0
|
||||
while f"term_{i}" in executed_parts:
|
||||
val = executed_parts[f"term_{i}"].lstrip('+')
|
||||
try:
|
||||
num = val.lstrip('-')
|
||||
if num.startswith(('0b', '0x')):
|
||||
sign = -1 if val.startswith('-') else 1
|
||||
base = 2 if num.startswith('0b') else 16 if num.startswith('0x') else 10
|
||||
values.append(sign * float(int(num[2:], base)))
|
||||
else:
|
||||
values.append(float(val))
|
||||
except ValueError:
|
||||
values.append(val)
|
||||
i += 1
|
||||
|
||||
# Extract operators
|
||||
for i in range(len(values) - 1):
|
||||
if f"op_{i}" in executed_parts:
|
||||
operators.append(executed_parts[f"op_{i}"])
|
||||
|
||||
return values, operators
|
||||
|
||||
def _evaluate_expression(self, values: list, operators: list) -> float:
|
||||
"""Evaluate expression respecting operator precedence"""
|
||||
if not operators:
|
||||
return values[0] if values else 0
|
||||
|
||||
vals, ops = list(values), list(operators)
|
||||
|
||||
def handle_edge(op, a, b):
|
||||
# Handle division first
|
||||
if op == '/':
|
||||
if np.isclose(b, 0):
|
||||
raise ValueError("chain_sum.py: Invalid operation, division by zero")
|
||||
# Handle exponentiation edge cases
|
||||
if op == '**':
|
||||
if np.isclose(a, 0) and b < 0:
|
||||
raise ValueError("chain_sum.py: Invalid operation, zero with negative exponent")
|
||||
if a < 0 and not isinstance(b, int) and not b.is_integer():
|
||||
raise ValueError("chain_sum.py: Invalid operation, fractional exponent of negative base")
|
||||
|
||||
# Handle potential overflows
|
||||
try:
|
||||
result = self.pedmas[op][0](a, b)
|
||||
if abs(result) > np.finfo(float).max:
|
||||
raise OverflowError
|
||||
return result
|
||||
except OverflowError:
|
||||
raise ValueError("chain_sum.py: Invalid operation, overflow in calculation")
|
||||
|
||||
for precedence in sorted({self.pedmas[op][1] for op in ops}, reverse=True):
|
||||
i = 0
|
||||
while i < len(ops):
|
||||
if self.pedmas[ops[i]][1] != precedence:
|
||||
i += 1
|
||||
continue
|
||||
op = ops[i]
|
||||
a, b = vals[i], vals[i + 1]
|
||||
result = handle_edge(op, a, b)
|
||||
vals[i] = result # Replace first value with result
|
||||
del vals[i + 1] # Remove second value
|
||||
del ops[i] # Remove used operator
|
||||
|
||||
return vals[0]
|
||||
123
reasoning_gym/arithmetic/fraction_simplification.py
Normal file
123
reasoning_gym/arithmetic/fraction_simplification.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Fraction simplification task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class FractionSimplificationConfig:
|
||||
"""Configuration for fraction simplification task generation"""
|
||||
|
||||
min_value: int = 1 # Minimum value for numerator/denominator
|
||||
max_value: int = 1000 # Maximum value for numerator/denominator
|
||||
min_factor: int = 1 # Minimum multiplication factor
|
||||
max_factor: int = 100 # Maximum multiplication factor
|
||||
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_value > 0, "min_value must be positive"
|
||||
assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||
assert self.min_factor >= 1, "min_factor must be at least 1"
|
||||
assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"
|
||||
|
||||
# Validate styles
|
||||
valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
|
||||
for style in self.styles:
|
||||
assert style in valid_styles, f"Invalid style: {style}. Must be one of {valid_styles}"
|
||||
|
||||
|
||||
class FractionSimplificationDataset(ProceduralDataset):
|
||||
"""Generates fraction simplification tasks"""
|
||||
|
||||
def __init__(self, config: FractionSimplificationConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]:
|
||||
"""Generate a random fraction and its simplified form.
|
||||
Returns (numerator, denominator, simplified_num, simplified_den)"""
|
||||
# Try to generate valid fractions until we get one that meets our criteria
|
||||
for _ in range(10): # Limit attempts to avoid infinite loop
|
||||
# Generate the simplified fraction first
|
||||
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
|
||||
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
# Make sure they're coprime by dividing by their GCD
|
||||
common = gcd(simplified_num, simplified_den)
|
||||
simplified_num //= common
|
||||
simplified_den //= common
|
||||
|
||||
# Check if simplified fraction is within bounds
|
||||
if (
|
||||
self.config.min_value <= simplified_num <= self.config.max_value
|
||||
and self.config.min_value <= simplified_den <= self.config.max_value
|
||||
):
|
||||
# Ensure numerator is smaller than denominator
|
||||
if simplified_num > simplified_den:
|
||||
simplified_num, simplified_den = simplified_den, simplified_num
|
||||
|
||||
# Multiply both by a random factor to create the unsimplified version
|
||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||
numerator = simplified_num * factor
|
||||
denominator = simplified_den * factor
|
||||
return numerator, denominator, simplified_num, simplified_den
|
||||
|
||||
# If we failed to find a good fraction after max attempts,
|
||||
# generate one that's guaranteed to be within bounds
|
||||
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
|
||||
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
# Ensure numerator is smaller than denominator
|
||||
if simplified_num > simplified_den:
|
||||
simplified_num, simplified_den = simplified_den, simplified_num
|
||||
|
||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
|
||||
|
||||
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
||||
"""Format a fraction in various styles"""
|
||||
if style == "plain":
|
||||
return f"{num}/{den}"
|
||||
elif style == "latex_inline":
|
||||
return f"${num}/{den}$"
|
||||
elif style == "latex_frac":
|
||||
return f"$\\frac{{{num}}}{{{den}}}$"
|
||||
elif style == "latex_dfrac":
|
||||
return f"$\\dfrac{{{num}}}{{{den}}}$"
|
||||
else:
|
||||
raise ValueError(f"Unknown fraction style: {style}")
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single fraction simplification task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
||||
|
||||
# Choose a random style from configured styles
|
||||
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
|
||||
|
||||
# Format both question and answer in the same style
|
||||
question_fraction = self._format_fraction(num, den, style)
|
||||
answer_fraction = self._format_fraction(simple_num, simple_den, style)
|
||||
|
||||
return {
|
||||
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
|
||||
"answer": answer_fraction,
|
||||
"metadata": {
|
||||
"numerator": num,
|
||||
"denominator": den,
|
||||
"simplified_numerator": simple_num,
|
||||
"simplified_denominator": simple_den,
|
||||
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
||||
"style": style,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
||||
66
reasoning_gym/arithmetic/gcd.py
Normal file
66
reasoning_gym/arithmetic/gcd.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Greatest Common Divisor (GCD) task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class GCDConfig:
|
||||
"""Configuration for GCD task generation"""
|
||||
|
||||
min_numbers: int = 2 # Minimum numbers to find GCD of
|
||||
max_numbers: int = 2 # Maximum numbers to find GCD of
|
||||
min_value: int = 1 # Minimum value for each number
|
||||
max_value: int = 1000 # Maximum value for each number
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_numbers >= 2, "min_numbers must be at least 2"
|
||||
assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
|
||||
assert self.min_value >= 1, "min_value must be positive"
|
||||
assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||
|
||||
|
||||
class GCDDataset(ProceduralDataset):
|
||||
"""Generates Greatest Common Divisor (GCD) tasks"""
|
||||
|
||||
def __init__(self, config: GCDConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||
"""Generate a list of random positive integers and their GCD.
|
||||
Will try up to 3 times to find numbers with GCD > 1."""
|
||||
|
||||
# Try up to 3 times to get GCD > 1
|
||||
for _ in range(3):
|
||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||
result = reduce(gcd, numbers)
|
||||
if result > 1:
|
||||
break
|
||||
|
||||
# Return the last generated numbers, whether they met the criteria or not
|
||||
return numbers, result
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single GCD task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
numbers, result = self._generate_numbers(rng)
|
||||
numbers_str = ", ".join(str(n) for n in numbers)
|
||||
|
||||
return {
|
||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
||||
"answer": str(result),
|
||||
"metadata": {"numbers": numbers, "result": result},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("gcd", GCDDataset, GCDConfig)
|
||||
69
reasoning_gym/arithmetic/lcm.py
Normal file
69
reasoning_gym/arithmetic/lcm.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Least Common Multiple (LCM) task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from math import lcm
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class LCMConfig:
|
||||
"""Configuration for LCM task generation"""
|
||||
|
||||
min_numbers: int = 2 # Minimum numbers to find LCM of
|
||||
max_numbers: int = 2 # Maximum numbers to find LCM of
|
||||
min_value: int = 1 # Minimum value for each number
|
||||
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_numbers >= 2, "min_numbers must be at least 2"
|
||||
assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
|
||||
assert self.min_value >= 1, "min_value must be positive"
|
||||
assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||
|
||||
|
||||
class LCMDataset(ProceduralDataset):
|
||||
"""Generates Least Common Multiple (LCM) tasks"""
|
||||
|
||||
def __init__(self, config: LCMConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||
"""Generate a list of random positive integers and their LCM.
|
||||
Will try up to 3 times to find numbers with LCM < product."""
|
||||
|
||||
def calculate_product(nums: List[int]) -> int:
|
||||
return reduce(lambda x, y: x * y, nums)
|
||||
|
||||
# Try up to 3 times to get LCM < product
|
||||
for _ in range(3):
|
||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||
result = reduce(lcm, numbers)
|
||||
if result < calculate_product(numbers):
|
||||
break
|
||||
|
||||
# Return the last generated numbers, whether they met the criteria or not
|
||||
return numbers, result
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single LCM task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
numbers, result = self._generate_numbers(rng)
|
||||
numbers_str = ", ".join(str(n) for n in numbers)
|
||||
|
||||
return {
|
||||
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
|
||||
"answer": str(result),
|
||||
"metadata": {"numbers": numbers, "result": result},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("lcm", LCMDataset, LCMConfig)
|
||||
118
reasoning_gym/arithmetic/leg_counting.py
Normal file
118
reasoning_gym/arithmetic/leg_counting.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""Leg counting task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
ANIMALS = {
|
||||
# Animals with 0 legs
|
||||
"snake": 0,
|
||||
"sea slug": 0,
|
||||
"jellyfish": 0,
|
||||
"flatworm": 0,
|
||||
"leech": 0,
|
||||
# Animals with 2 legs
|
||||
"chicken": 2,
|
||||
"bird": 2,
|
||||
"human": 2,
|
||||
"duck": 2,
|
||||
# Animals with 4 legs
|
||||
"dog": 4,
|
||||
"cat": 4,
|
||||
"cow": 4,
|
||||
"horse": 4,
|
||||
"lion": 4,
|
||||
"elephant": 4,
|
||||
"giraffe": 4,
|
||||
"tiger": 4,
|
||||
"deer": 4,
|
||||
"sheep": 4,
|
||||
# Animals with 5 legs
|
||||
"starfish": 5,
|
||||
# Animals with 6 legs
|
||||
"insect": 6,
|
||||
"ant": 6,
|
||||
"butterfly": 6,
|
||||
"beetle": 6,
|
||||
"bee": 6,
|
||||
"wasp": 6,
|
||||
"grasshopper": 6,
|
||||
"cricket": 6,
|
||||
"cockroach": 6,
|
||||
"praying mantis": 6,
|
||||
"firefly": 6,
|
||||
# Animals with 8 legs
|
||||
"spider": 8,
|
||||
"scorpion": 8,
|
||||
# Animals with 10 legs
|
||||
"crab": 10,
|
||||
"lobster": 10,
|
||||
"shrimp": 10,
|
||||
# Animals with 14 legs
|
||||
"woodlouse": 14,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegCountingConfig:
|
||||
"""Configuration for leg counting task generation"""
|
||||
|
||||
min_animals: int = 2 # Minimum number of animals in problem
|
||||
max_animals: int = 5 # Maximum number of animals
|
||||
max_instances: int = 3 # Maximum instances of each animal
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_animals > 0, "min_animals must be positive"
|
||||
assert self.max_animals >= self.min_animals, "max_animals must be >= min_animals"
|
||||
assert self.max_instances > 0, "max_instances must be positive"
|
||||
|
||||
|
||||
class LegCountingDataset(ProceduralDataset):
|
||||
"""Generates leg counting arithmetic tasks"""
|
||||
|
||||
def __init__(self, config: LegCountingConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_animals(self, rng: Random) -> Dict[str, int]:
|
||||
"""Generate a random set of animals and their counts"""
|
||||
num_types = rng.randint(self.config.min_animals, self.config.max_animals)
|
||||
animals = {}
|
||||
|
||||
# Select random animals
|
||||
selected_animals = rng.sample(list(ANIMALS.keys()), num_types)
|
||||
for animal in selected_animals:
|
||||
count = rng.randint(1, self.config.max_instances)
|
||||
animals[animal] = count
|
||||
|
||||
return animals
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single leg counting task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
# Generate random animals and their counts
|
||||
animals = self._generate_animals(rng)
|
||||
|
||||
# Calculate total legs
|
||||
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
|
||||
|
||||
# Format animal counts for question
|
||||
animal_list = []
|
||||
for animal, count in animals.items():
|
||||
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
|
||||
|
||||
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": str(total_legs),
|
||||
"metadata": {"animals": animals, "total_legs": total_legs},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)
|
||||
69
reasoning_gym/arithmetic/prime_factorization.py
Normal file
69
reasoning_gym/arithmetic/prime_factorization.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Prime factorization task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrimeFactorizationConfig:
|
||||
"""Configuration for prime factorization task generation"""
|
||||
|
||||
min_value: int = 2 # Minimum number to factorize
|
||||
max_value: int = 1000 # Maximum number to factorize
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_value >= 2, "min_value must be >= 2"
|
||||
assert self.max_value >= self.min_value, "max_value must be >= min_value"
|
||||
|
||||
|
||||
class PrimeFactorizationDataset(ProceduralDataset):
|
||||
"""Generates prime factorization tasks"""
|
||||
|
||||
def __init__(self, config: PrimeFactorizationConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _prime_factors(self, n: int) -> List[int]:
|
||||
"""Compute prime factors of a number"""
|
||||
factors = []
|
||||
d = 2
|
||||
while n > 1:
|
||||
while n % d == 0:
|
||||
factors.append(d)
|
||||
n //= d
|
||||
d += 1
|
||||
if d * d > n:
|
||||
if n > 1:
|
||||
factors.append(n)
|
||||
break
|
||||
return factors
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single prime factorization task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
# Generate random number to factorize
|
||||
number = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
# Calculate prime factors
|
||||
factors = self._prime_factors(number)
|
||||
|
||||
# Format answer as multiplication of prime factors
|
||||
answer = " × ".join(map(str, factors))
|
||||
|
||||
return {
|
||||
"question": (
|
||||
f"Find the prime factorization of {number}. Write the factors separated by × "
|
||||
f"(Example: for 12 the answer would be: 2 × 2 × 3)"
|
||||
),
|
||||
"answer": answer,
|
||||
"metadata": {"number": number, "factors": factors},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig)
|
||||
323
reasoning_gym/arithmetic/time_intervals.py
Normal file
323
reasoning_gym/arithmetic/time_intervals.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
import pytz
|
||||
from dateutil import parser
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeIntervalsConfig:
|
||||
"""Configuration for time interval calculation tasks"""
|
||||
|
||||
min_time: time = time.min
|
||||
max_time: time = time.max
|
||||
max_time_difference_seconds: int = 24 * 60 * 60
|
||||
min_date: date = date(1900, 1, 1)
|
||||
max_date: date = date(3000, 1, 1)
|
||||
max_date_difference_days: int = 100
|
||||
task_types: List[str] = field(
|
||||
default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
|
||||
)
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.size > 0, "size must be positive"
|
||||
assert self.max_time_difference_seconds > 0, "max_time_difference_seconds must be positive"
|
||||
assert self.max_date_difference_days > 0, "max_date_difference_days must be positive"
|
||||
assert self.min_date < self.max_date, "min_date must be before max_date"
|
||||
|
||||
|
||||
class TimeIntervalsDataset(ProceduralDataset):
|
||||
"""Generates time interval calculation tasks with various formats and complexities"""
|
||||
|
||||
TEMPLATES = [
|
||||
"What is the duration between {start} and {end}? Please answer in {format}.",
|
||||
"Calculate the time difference between {start} and {end}. Express the result in {format}.",
|
||||
"How much time elapsed from {start} to {end}? Give your answer in {format}.",
|
||||
"A meeting started at {start} and ended at {end}. How long was the meeting? Answer in {format}.",
|
||||
"A system operation started at {start} and completed at {end}. What was the operation duration? Answer in {format}.",
|
||||
"A database query started at {start} and ended at {end}. How long did the query take? Answer in {format}.",
|
||||
"A flight departed at {start} and arrived at {end}. How long was the flight? Answer in {format}.",
|
||||
"A video call started at {start} and ended at {end}. How long was the call? Answer in {format}.",
|
||||
"A system backup started at {start} and completed at {end}. What was the total backup duration? Answer in {format}.",
|
||||
"A conference call began at {start} and ended at {end}. How long was the conference? Answer in {format}.",
|
||||
]
|
||||
|
||||
TIME_FORMATS = [
|
||||
"%H:%M",
|
||||
"%H:%M:%S",
|
||||
"%H:%M:%S.%f",
|
||||
]
|
||||
|
||||
DATE_FORMATS = [
|
||||
"%Y-%m-%d",
|
||||
"%B %d, %Y",
|
||||
"%m/%d/%Y",
|
||||
"%A, %B %d, %Y", # e.g. Monday, January 15, 2024
|
||||
"%a %b %d %Y", # e.g. Mon Jan 15 2024
|
||||
"%d %B %Y", # e.g. 15 January 2024
|
||||
"%Y-%m-%d (%A)", # e.g. 2024-01-15 (Monday)
|
||||
]
|
||||
|
||||
DATETIME_FORMATS = [
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M %z", # For UTC offset format
|
||||
"%Y-%m-%d %H:%M:%S %z", # For UTC offset with seconds
|
||||
"%A, %B %d, %Y at %H:%M", # e.g. Monday, January 15, 2024 at 14:30
|
||||
"%a %b %d %Y %H:%M:%S", # e.g. Mon Jan 15 2024 14:30:45
|
||||
"%d %B %Y, %H:%M", # e.g. 15 January 2024, 14:30
|
||||
"%d %B %Y, %H:%M %z", # e.g. 15 January 2024, 14:30 +0000
|
||||
"%Y-%m-%d (%A) %H:%M:%S %z", # e.g. 2024-01-15 (Monday) 14:30:45 +0000
|
||||
]
|
||||
|
||||
def __init__(self, config: TimeIntervalsConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single time interval calculation task"""
|
||||
item_rng = random.Random(self.seed + idx)
|
||||
|
||||
# Randomly choose task type from config
|
||||
task_type = item_rng.choice(self.config.task_types)
|
||||
|
||||
start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type)
|
||||
|
||||
template = item_rng.choice(self.TEMPLATES)
|
||||
question = template.format(start=start_time, end=end_time, format=expected_format)
|
||||
|
||||
# Calculate the actual difference
|
||||
if isinstance(start_time, str):
|
||||
# Handle datetime strings with weekday names in parentheses
|
||||
start_time = start_time.split(" (")[0] # Remove (Weekday) if present
|
||||
end_time = end_time.split(" (")[0]
|
||||
# Parse with UTC offset handling
|
||||
start_dt = parser.parse(start_time)
|
||||
end_dt = parser.parse(end_time)
|
||||
else:
|
||||
start_dt = start_time
|
||||
end_dt = end_time
|
||||
|
||||
difference = end_dt - start_dt
|
||||
|
||||
# Format the answer according to expected_format
|
||||
if expected_format == "HH:MM":
|
||||
total_seconds = difference.total_seconds()
|
||||
answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}"
|
||||
elif expected_format == "HH:MM:SS":
|
||||
total_seconds = difference.total_seconds()
|
||||
answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}:{int(total_seconds % 60):02d}"
|
||||
elif expected_format == "HH:MM:SS.mmm":
|
||||
total_seconds = difference.total_seconds()
|
||||
ms = int((total_seconds % 1) * 1000)
|
||||
answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}:{int(total_seconds % 60):02d}.{ms:03d}"
|
||||
elif expected_format == "D days":
|
||||
answer = f"{difference.days} days"
|
||||
else: # "D days, HH:MM" or "D days, HH:MM:SS"
|
||||
days = difference.days
|
||||
hours = difference.seconds // 3600
|
||||
minutes = (difference.seconds % 3600) // 60
|
||||
seconds = difference.seconds % 60
|
||||
if expected_format == "D days, HH:MM:SS":
|
||||
answer = f"{days} days, {hours:02d}:{minutes:02d}:{seconds:02d}"
|
||||
else: # "D days, HH:MM"
|
||||
answer = f"{days} days, {hours:02d}:{minutes:02d}"
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"task_type": task_type,
|
||||
"start_time": start_dt,
|
||||
"end_time": end_dt,
|
||||
"format": format_str,
|
||||
"expected_format": expected_format,
|
||||
},
|
||||
}
|
||||
|
||||
def _generate_times(self, rng: random.Random, task_type: str):
|
||||
"""Generate start and end times based on task type"""
|
||||
if task_type.startswith("time"):
|
||||
if task_type == "time_ms":
|
||||
format_str = self.TIME_FORMATS[2] # Get milliseconds format
|
||||
expected_format = "HH:MM:SS.mmm"
|
||||
else:
|
||||
format_str = next(f for f in self.TIME_FORMATS if f.count(":") == (2 if "seconds" in task_type else 1))
|
||||
expected_format = "HH:MM:SS" if "seconds" in task_type else "HH:MM"
|
||||
|
||||
# Generate random start time
|
||||
start_hour = rng.randint(0, 23)
|
||||
start_minute = rng.randint(0, 59)
|
||||
start_second = rng.randint(0, 59)
|
||||
base = datetime.combine(date.today(), time(start_hour, start_minute, start_second))
|
||||
|
||||
# Calculate seconds remaining until midnight
|
||||
seconds_until_midnight = ((24 - start_hour) * 3600) - (start_minute * 60) - start_second
|
||||
# Use the minimum of config max and seconds until midnight
|
||||
max_seconds = min(self.config.max_time_difference_seconds, seconds_until_midnight)
|
||||
diff_seconds = rng.randint(1, max_seconds) if max_seconds > 0 else 0
|
||||
|
||||
if task_type == "time_ms":
|
||||
# Add microseconds for millisecond precision
|
||||
base = base.replace(microsecond=rng.randint(0, 999) * 1000)
|
||||
end_time = base + timedelta(seconds=diff_seconds, microseconds=rng.randint(0, 999) * 1000)
|
||||
# Format with exactly 3 decimal places for milliseconds
|
||||
start_time = base.strftime(format_str)[:-3] # Remove extra microsecond digits
|
||||
end_time = end_time.strftime(format_str)[:-3] # Remove extra microsecond digits
|
||||
else:
|
||||
start_time = base.strftime(format_str)
|
||||
end_time = (base + timedelta(seconds=diff_seconds)).strftime(format_str)
|
||||
|
||||
elif task_type == "date":
|
||||
format_str = rng.choice(self.DATE_FORMATS)
|
||||
expected_format = "D days" # Always return number of days for date tasks
|
||||
|
||||
# Generate random start date within configured range, leaving room for end date
|
||||
max_date_difference_days = min(
|
||||
self.config.max_date_difference_days, (self.config.max_date - self.config.min_date).days
|
||||
)
|
||||
max_start_days = (self.config.max_date - self.config.min_date).days - max_date_difference_days
|
||||
start_days = rng.randint(0, max_start_days - 1)
|
||||
start_date = self.config.min_date + timedelta(days=start_days)
|
||||
|
||||
# Ensure positive difference between dates
|
||||
diff_days = rng.randint(0, max_date_difference_days)
|
||||
end_date = start_date + timedelta(days=diff_days)
|
||||
|
||||
start_time = start_date.strftime(format_str)
|
||||
end_time = end_date.strftime(format_str)
|
||||
|
||||
else: # datetime or datetime_tz
|
||||
format_str = rng.choice(self.DATETIME_FORMATS)
|
||||
# Choose between HH:MM and HH:MM:SS format for datetime answers
|
||||
expected_format = rng.choice(["D days, HH:MM", "D days, HH:MM:SS"])
|
||||
|
||||
# Generate random start datetime
|
||||
days_range = (self.config.max_date - self.config.min_date).days
|
||||
start_days = rng.randint(0, days_range)
|
||||
start_hour = rng.randint(0, 23)
|
||||
start_minute = rng.randint(0, 59)
|
||||
start_second = rng.randint(0, 59)
|
||||
|
||||
# Generate random time differences first
|
||||
diff_days = rng.randint(0, self.config.max_date_difference_days)
|
||||
diff_seconds = rng.randint(1, self.config.max_time_difference_seconds)
|
||||
|
||||
if "%z" in format_str:
|
||||
# Use simpler timezone format with offset
|
||||
base = datetime.combine(
|
||||
self.config.min_date + timedelta(days=start_days), time(start_hour, start_minute, start_second)
|
||||
)
|
||||
# Generate timezone offsets
|
||||
start_offset = rng.randint(-12, 12)
|
||||
end_offset = rng.randint(-12, 12)
|
||||
|
||||
# Apply start timezone
|
||||
base = base.replace(tzinfo=pytz.FixedOffset(start_offset * 60))
|
||||
start_format = format_str.replace("%z", "%+05d" % (start_offset * 100))
|
||||
|
||||
# Calculate end time and convert to end timezone
|
||||
end_dt = base + timedelta(days=diff_days, seconds=diff_seconds)
|
||||
end_dt = end_dt.replace(tzinfo=pytz.FixedOffset(end_offset * 60))
|
||||
end_format = format_str.replace("%z", "%+05d" % (end_offset * 100))
|
||||
|
||||
# Format times with their respective timezone offsets
|
||||
start_time = base.strftime(start_format).rstrip()
|
||||
end_time = end_dt.strftime(end_format).rstrip()
|
||||
else:
|
||||
base = datetime.combine(
|
||||
self.config.min_date + timedelta(days=start_days), time(start_hour, start_minute, start_second)
|
||||
)
|
||||
# For non-timezone aware times, both use same format
|
||||
start_time = base.strftime(format_str).rstrip()
|
||||
end_time = (base + timedelta(days=diff_days, seconds=diff_seconds)).strftime(format_str).rstrip()
|
||||
|
||||
return start_time, end_time, format_str, expected_format
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict) -> float:
|
||||
"""Score an answer based on how close it is to the expected duration
|
||||
|
||||
Returns a score between 0 and 1, with partial credit for answers that are
|
||||
close to correct in the appropriate units/format
|
||||
"""
|
||||
if not answer:
|
||||
return 0.0
|
||||
|
||||
expected = entry["answer"]
|
||||
task_type = entry["metadata"]["task_type"]
|
||||
|
||||
try:
|
||||
if task_type == "date":
|
||||
# Parse "X days" format
|
||||
try:
|
||||
actual = int(answer.strip().split()[0]) # Get number before "days"
|
||||
expected = int(expected.strip().split()[0])
|
||||
if actual == expected:
|
||||
return 1.0
|
||||
# Partial credit based on how close the day count is
|
||||
max_diff = self.config.max_date_difference_days
|
||||
diff = abs(actual - expected)
|
||||
return max(0.0, 1.0 - (diff / max_diff))
|
||||
except (ValueError, IndexError):
|
||||
return 0.0
|
||||
|
||||
elif task_type.startswith("time"):
|
||||
# Parse times into total seconds for comparison
|
||||
def parse_time(t):
|
||||
parts = t.strip().split(":")
|
||||
seconds = int(parts[0]) * 3600 + int(parts[1]) * 60
|
||||
if len(parts) > 2:
|
||||
if "." in parts[2]: # Has milliseconds
|
||||
s, ms = parts[2].split(".")
|
||||
seconds += int(s) + int(ms) / 1000
|
||||
else:
|
||||
seconds += int(parts[2])
|
||||
return seconds
|
||||
|
||||
actual_seconds = parse_time(answer)
|
||||
expected_seconds = parse_time(expected)
|
||||
|
||||
if actual_seconds == expected_seconds:
|
||||
return 1.0
|
||||
|
||||
# Partial credit based on how close the times are
|
||||
max_diff = self.config.max_time_difference_seconds
|
||||
diff = abs(actual_seconds - expected_seconds)
|
||||
return max(0.0, 1.0 - (diff / max_diff))
|
||||
|
||||
else: # datetime or datetime_tz
|
||||
# Parse the complex format "X days, HH:MM" or "X days, HH:MM:SS"
|
||||
def parse_datetime(t):
|
||||
days = int(t.split(" days,")[0])
|
||||
time_part = t.split(",")[1].strip()
|
||||
parts = time_part.split(":")
|
||||
seconds = int(parts[0]) * 3600 + int(parts[1]) * 60
|
||||
if len(parts) > 2:
|
||||
seconds += int(parts[2])
|
||||
return days * 86400 + seconds
|
||||
|
||||
actual_seconds = parse_datetime(answer)
|
||||
expected_seconds = parse_datetime(expected)
|
||||
|
||||
if actual_seconds == expected_seconds:
|
||||
return 1.0
|
||||
|
||||
# Partial credit based on total time difference
|
||||
max_diff = self.config.max_date_difference_days * 86400
|
||||
diff = abs(actual_seconds - expected_seconds)
|
||||
return max(0.0, 1.0 - (diff / max_diff))
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return 0.0 # Invalid format
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig)
|
||||
Loading…
Add table
Add a link
Reference in a new issue