mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
use native types List->list, Dict->dict, Set->set, Tuple->tuple
This commit is contained in:
parent
5d02064b5a
commit
3e7ff3b084
95 changed files with 754 additions and 760 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from reasoning_gym import utils
|
||||
|
||||
|
|
@ -234,7 +234,7 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
template = rng.choice(templates)
|
||||
return template.format(expression)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
oracle_answer = entry["answer"].strip()
|
||||
return utils.compute_reward(answer, oracle_answer, allow_commas=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import random
|
|||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
from enum import Enum, StrEnum, auto
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ class CalendarTask(StrEnum):
|
|||
@dataclass
|
||||
class CalendarArithmeticConfig:
|
||||
year: int = 2022
|
||||
tasks: Optional[List[str]] = None
|
||||
tasks: Optional[list[str]] = None
|
||||
offset_upper_bound: int = 100
|
||||
leap_year_range: int = 200
|
||||
seed: Optional[int] = 42
|
||||
|
|
@ -131,7 +131,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
"metadata": metadata,
|
||||
}
|
||||
|
||||
def _weekday_offset(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
def _weekday_offset(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
"""
|
||||
Task: Given a starting date and a day offset (which may be positive or negative),
|
||||
ask what day of the week it will be.
|
||||
|
|
@ -170,7 +170,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
}
|
||||
return question, target_weekday, metadata
|
||||
|
||||
def _weekday_of_date(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
def _weekday_of_date(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
"""
|
||||
task: Ask what day of the week a given date was.
|
||||
example:
|
||||
|
|
@ -193,7 +193,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
}
|
||||
return question, answer_weekday, metadata
|
||||
|
||||
def _weekday_of_date_from_first_day(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
def _weekday_of_date_from_first_day(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
"""
|
||||
task: Given an hypothetical weekday for January 1, ask what weekday a later date in the year falls on.
|
||||
example:
|
||||
|
|
@ -235,7 +235,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
}
|
||||
return question, answer_weekday, metadata
|
||||
|
||||
def _recurring_event_day(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
def _recurring_event_day(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
"""
|
||||
task: For a recurring event defined by an ordinal weekday pattern in a month,
|
||||
ask on which day of the month the event occurs.
|
||||
|
|
@ -294,7 +294,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
}
|
||||
return question, str(event_day), metadata
|
||||
|
||||
def _count_days(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
def _count_days(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
"""
|
||||
task: Ask how many times a given weekday occurs in a specified range.
|
||||
example:
|
||||
|
|
@ -334,7 +334,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
}
|
||||
return question, str(count), metadata
|
||||
|
||||
def _count_business_days(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
def _count_business_days(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
"""
|
||||
task: Count the number of business days (Monday-Friday) between two dates.
|
||||
example:
|
||||
|
|
@ -385,7 +385,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
}
|
||||
return question, str(count), metadata
|
||||
|
||||
def _is_leap_year(self, rng: random.Random) -> Tuple[str, str, dict]:
|
||||
def _is_leap_year(self, rng: random.Random) -> tuple[str, str, dict]:
|
||||
"""
|
||||
task: Given a year, determine whether it is a leap year.
|
||||
example:
|
||||
|
|
@ -426,7 +426,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
random_days = rng.randint(0, delta)
|
||||
return start_date + timedelta(days=random_days)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
# we suppose the answer is the last occurence of the expected answer type
|
||||
if answer is None:
|
||||
return 0.0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from reasoning_gym import utils
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ class ChainSumDataset(ProceduralDataset):
|
|||
expression = " ".join(expression_parts)
|
||||
return expression, result
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
oracle_answer = entry["answer"].strip()
|
||||
return utils.compute_reward(answer, oracle_answer)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import ast
|
|||
from dataclasses import dataclass
|
||||
from decimal import ROUND_HALF_UP, Decimal, getcontext
|
||||
from random import Random
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -25,7 +25,7 @@ class DecimalArithmeticConfig:
|
|||
), "precision must be 2 or more higher than max_num_decimal_places"
|
||||
|
||||
|
||||
def build_grouped_expression(operands: List[str], operators: List[str], rng: Random) -> str:
|
||||
def build_grouped_expression(operands: list[str], operators: list[str], rng: Random) -> str:
|
||||
"""
|
||||
Recursively build an arithmetic expression string from operands and operators,
|
||||
inserting parentheses at random.
|
||||
|
|
@ -53,7 +53,7 @@ def generate_arithmetic_problem(
|
|||
min_num_decimal_places: int,
|
||||
max_num_decimal_places: int,
|
||||
terms: int = 2,
|
||||
operations: Optional[List[str]] = None,
|
||||
operations: Optional[list[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generates a simple arithmetic problem with decimal numbers (as a string) formatted
|
||||
|
|
@ -72,8 +72,8 @@ def generate_arithmetic_problem(
|
|||
if operations is None:
|
||||
operations = ["+", "-", "*", "/"]
|
||||
|
||||
operands: List[str] = []
|
||||
operators: List[str] = []
|
||||
operands: list[str] = []
|
||||
operators: list[str] = []
|
||||
|
||||
for i in range(terms):
|
||||
# Choose a random number of decimal places for this term.
|
||||
|
|
@ -149,7 +149,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
def __init__(self, config: DecimalArithmeticConfig) -> None:
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a single arithmetic task.
|
||||
|
||||
|
|
@ -180,7 +180,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
|
||||
return {"question": problem_str, "answer": answer, "metadata": {}}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""
|
||||
Compares the user's answer (converted to Decimal) with the correct answer.
|
||||
Instead of requiring exact equality, we allow an error up to one unit in the
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
|
||||
return expression, result
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score the answer by comparing decimal values instead of strings.
|
||||
Args:
|
||||
answer: The answer to score
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||
from functools import reduce
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -125,14 +125,14 @@ class DiceDataset(ProceduralDataset):
|
|||
"metadata": {},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves the Dice task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||
entry (dict[str, Any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import re
|
|||
from dataclasses import dataclass
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
def __init__(self, config: FractionSimplificationConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]:
|
||||
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int]:
|
||||
"""Generate a random fraction and its simplified form.
|
||||
Returns (numerator, denominator, simplified_num, simplified_den)"""
|
||||
# Try to generate valid fractions until we get one that meets our criteria
|
||||
|
|
@ -134,7 +134,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
except:
|
||||
return None
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]):
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]):
|
||||
reward = 0.0
|
||||
metadata = entry["metadata"]
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from functools import reduce
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ class GCDDataset(ProceduralDataset):
|
|||
def __init__(self, config: GCDConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||
def _generate_numbers(self, rng: Random) -> tuple[list[int], int]:
|
||||
"""Generate a list of random positive integers and their GCD.
|
||||
Will try up to 3 times to find numbers with GCD > 1."""
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from functools import reduce
|
||||
from math import lcm
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -34,11 +34,11 @@ class LCMDataset(ProceduralDataset):
|
|||
def __init__(self, config: LCMConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||
def _generate_numbers(self, rng: Random) -> tuple[list[int], int]:
|
||||
"""Generate a list of random positive integers and their LCM.
|
||||
Will try up to 3 times to find numbers with LCM < product."""
|
||||
|
||||
def calculate_product(nums: List[int]) -> int:
|
||||
def calculate_product(nums: list[int]) -> int:
|
||||
return reduce(lambda x, y: x * y, nums)
|
||||
|
||||
# Try up to 3 times to get LCM < product
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class LegCountingDataset(ProceduralDataset):
|
|||
def __init__(self, config: LegCountingConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_animals(self, rng: Random) -> Dict[str, int]:
|
||||
def _generate_animals(self, rng: Random) -> dict[str, int]:
|
||||
"""Generate a random set of animals and their counts"""
|
||||
num_types = rng.randint(self.config.min_animals, self.config.max_animals)
|
||||
animals = {}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ class NumberFormatDataset(ProceduralDataset):
|
|||
output.append(f"{candidate:.15e}")
|
||||
return output
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["metadata"]["solution"]
|
||||
if answer is not None and len(answer) > 0:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
from dataclasses import dataclass
|
||||
from math import pow
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ class PowerFunctionDataset(ProceduralDataset):
|
|||
def __init__(self, config: PowerFunctionConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import math
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
|||
def __init__(self, config: PrimeFactorizationConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _prime_factors(self, n: int) -> List[int]:
|
||||
def _prime_factors(self, n: int) -> list[int]:
|
||||
"""Compute prime factors of a number"""
|
||||
factors = []
|
||||
d = 2
|
||||
|
|
@ -44,11 +44,11 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
|||
break
|
||||
return factors
|
||||
|
||||
def _normalize_answer(self, answer: str) -> List[int]:
|
||||
def _normalize_answer(self, answer: str) -> list[int]:
|
||||
"""Parse and sort factors from a string"""
|
||||
return sorted([int(factor.strip()) for factor in answer.split("×")])
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
oracle_answer = entry["answer"]
|
||||
reward = 0.0
|
||||
if answer is not None:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from reasoning_gym import utils
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ class ProductsDataset(ProceduralDataset):
|
|||
expression = " ".join(expression_parts)
|
||||
return expression, result
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
oracle_answer = entry["answer"].strip()
|
||||
return utils.compute_reward(answer, oracle_answer)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import pytz
|
||||
from dateutil import parser
|
||||
|
|
@ -19,7 +19,7 @@ class TimeIntervalsConfig:
|
|||
min_date: date = date(1900, 1, 1)
|
||||
max_date: date = date(3000, 1, 1)
|
||||
max_date_difference_days: int = 100
|
||||
task_types: List[str] = field(
|
||||
task_types: list[str] = field(
|
||||
default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
|
||||
)
|
||||
seed: Optional[int] = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue