use native types List->list, Dict->dict, Set->set, Tuple->tuple

This commit is contained in:
Andreas Koepf 2025-02-21 15:13:19 +01:00
parent 5d02064b5a
commit 3e7ff3b084
95 changed files with 754 additions and 760 deletions

View file

@ -1,6 +1,6 @@
from dataclasses import dataclass
from random import Random
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal, Optional
from reasoning_gym import utils
@ -234,7 +234,7 @@ class BasicArithmeticDataset(ProceduralDataset):
template = rng.choice(templates)
return template.format(expression)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"].strip()
return utils.compute_reward(answer, oracle_answer, allow_commas=False)

View file

@ -4,7 +4,7 @@ import random
from dataclasses import dataclass
from datetime import date, timedelta
from enum import Enum, StrEnum, auto
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -51,7 +51,7 @@ class CalendarTask(StrEnum):
@dataclass
class CalendarArithmeticConfig:
year: int = 2022
tasks: Optional[List[str]] = None
tasks: Optional[list[str]] = None
offset_upper_bound: int = 100
leap_year_range: int = 200
seed: Optional[int] = 42
@ -131,7 +131,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
"metadata": metadata,
}
def _weekday_offset(self, rng: random.Random) -> Tuple[str, str, dict]:
def _weekday_offset(self, rng: random.Random) -> tuple[str, str, dict]:
"""
Task: Given a starting date and a day offset (which may be positive or negative),
ask what day of the week it will be.
@ -170,7 +170,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
}
return question, target_weekday, metadata
def _weekday_of_date(self, rng: random.Random) -> Tuple[str, str, dict]:
def _weekday_of_date(self, rng: random.Random) -> tuple[str, str, dict]:
"""
task: Ask what day of the week a given date was.
example:
@ -193,7 +193,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
}
return question, answer_weekday, metadata
def _weekday_of_date_from_first_day(self, rng: random.Random) -> Tuple[str, str, dict]:
def _weekday_of_date_from_first_day(self, rng: random.Random) -> tuple[str, str, dict]:
"""
task: Given a hypothetical weekday for January 1, ask what weekday a later date in the year falls on.
example:
@ -235,7 +235,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
}
return question, answer_weekday, metadata
def _recurring_event_day(self, rng: random.Random) -> Tuple[str, str, dict]:
def _recurring_event_day(self, rng: random.Random) -> tuple[str, str, dict]:
"""
task: For a recurring event defined by an ordinal weekday pattern in a month,
ask on which day of the month the event occurs.
@ -294,7 +294,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
}
return question, str(event_day), metadata
def _count_days(self, rng: random.Random) -> Tuple[str, str, dict]:
def _count_days(self, rng: random.Random) -> tuple[str, str, dict]:
"""
task: Ask how many times a given weekday occurs in a specified range.
example:
@ -334,7 +334,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
}
return question, str(count), metadata
def _count_business_days(self, rng: random.Random) -> Tuple[str, str, dict]:
def _count_business_days(self, rng: random.Random) -> tuple[str, str, dict]:
"""
task: Count the number of business days (Monday-Friday) between two dates.
example:
@ -385,7 +385,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
}
return question, str(count), metadata
def _is_leap_year(self, rng: random.Random) -> Tuple[str, str, dict]:
def _is_leap_year(self, rng: random.Random) -> tuple[str, str, dict]:
"""
task: Given a year, determine whether it is a leap year.
example:
@ -426,7 +426,7 @@ class CalendarArithmeticDataset(ProceduralDataset):
random_days = rng.randint(0, delta)
return start_date + timedelta(days=random_days)
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
# we suppose the answer is the last occurrence of the expected answer type
if answer is None:
return 0.0

View file

@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Any, Optional
from reasoning_gym import utils
@ -110,7 +110,7 @@ class ChainSumDataset(ProceduralDataset):
expression = " ".join(expression_parts)
return expression, result
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"].strip()
return utils.compute_reward(answer, oracle_answer)

View file

@ -2,7 +2,7 @@ import ast
from dataclasses import dataclass
from decimal import ROUND_HALF_UP, Decimal, getcontext
from random import Random
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -25,7 +25,7 @@ class DecimalArithmeticConfig:
), "precision must be 2 or more higher than max_num_decimal_places"
def build_grouped_expression(operands: List[str], operators: List[str], rng: Random) -> str:
def build_grouped_expression(operands: list[str], operators: list[str], rng: Random) -> str:
"""
Recursively build an arithmetic expression string from operands and operators,
inserting parentheses at random.
@ -53,7 +53,7 @@ def generate_arithmetic_problem(
min_num_decimal_places: int,
max_num_decimal_places: int,
terms: int = 2,
operations: Optional[List[str]] = None,
operations: Optional[list[str]] = None,
) -> str:
"""
Generates a simple arithmetic problem with decimal numbers (as a string) formatted
@ -72,8 +72,8 @@ def generate_arithmetic_problem(
if operations is None:
operations = ["+", "-", "*", "/"]
operands: List[str] = []
operators: List[str] = []
operands: list[str] = []
operators: list[str] = []
for i in range(terms):
# Choose a random number of decimal places for this term.
@ -149,7 +149,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
def __init__(self, config: DecimalArithmeticConfig) -> None:
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> Dict[str, Any]:
def __getitem__(self, idx: int) -> dict[str, Any]:
"""
Generate a single arithmetic task.
@ -180,7 +180,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
return {"question": problem_str, "answer": answer, "metadata": {}}
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""
Compares the user's answer (converted to Decimal) with the correct answer.
Instead of requiring exact equality, we allow an error up to one unit in the

View file

@ -133,7 +133,7 @@ class DecimalChainSumDataset(ProceduralDataset):
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
return expression, result
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score the answer by comparing decimal values instead of strings.
Args:
answer: The answer to score

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -125,14 +125,14 @@ class DiceDataset(ProceduralDataset):
"metadata": {},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the Dice task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
entry (dict[str, Any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.

View file

@ -4,7 +4,7 @@ import re
from dataclasses import dataclass
from math import gcd
from random import Random
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence
from ..factory import ProceduralDataset, register_dataset
@ -42,7 +42,7 @@ class FractionSimplificationDataset(ProceduralDataset):
def __init__(self, config: FractionSimplificationConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]:
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int]:
"""Generate a random fraction and its simplified form.
Returns (numerator, denominator, simplified_num, simplified_den)"""
# Try to generate valid fractions until we get one that meets our criteria
@ -134,7 +134,7 @@ class FractionSimplificationDataset(ProceduralDataset):
except:
return None
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]):
reward = 0.0
metadata = entry["metadata"]
try:

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random
from typing import List, Optional, Tuple
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
@ -34,7 +34,7 @@ class GCDDataset(ProceduralDataset):
def __init__(self, config: GCDConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
def _generate_numbers(self, rng: Random) -> tuple[list[int], int]:
"""Generate a list of random positive integers and their GCD.
Will try up to 3 times to find numbers with GCD > 1."""

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from functools import reduce
from math import lcm
from random import Random
from typing import List, Optional, Tuple
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
@ -34,11 +34,11 @@ class LCMDataset(ProceduralDataset):
def __init__(self, config: LCMConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
def _generate_numbers(self, rng: Random) -> tuple[list[int], int]:
"""Generate a list of random positive integers and their LCM.
Will try up to 3 times to find numbers with LCM < product."""
def calculate_product(nums: List[int]) -> int:
def calculate_product(nums: list[int]) -> int:
return reduce(lambda x, y: x * y, nums)
# Try up to 3 times to get LCM < product

View file

@ -93,7 +93,7 @@ class LegCountingDataset(ProceduralDataset):
def __init__(self, config: LegCountingConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_animals(self, rng: Random) -> Dict[str, int]:
def _generate_animals(self, rng: Random) -> dict[str, int]:
"""Generate a random set of animals and their counts"""
num_types = rng.randint(self.config.min_animals, self.config.max_animals)
animals = {}

View file

@ -2,7 +2,7 @@
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -67,7 +67,7 @@ class NumberFormatDataset(ProceduralDataset):
output.append(f"{candidate:.15e}")
return output
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["metadata"]["solution"]
if answer is not None and len(answer) > 0:

View file

@ -3,7 +3,7 @@
from dataclasses import dataclass
from math import pow
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -46,7 +46,7 @@ class PowerFunctionDataset(ProceduralDataset):
def __init__(self, config: PowerFunctionConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
if answer is not None:

View file

@ -3,7 +3,7 @@
import math
from dataclasses import dataclass
from random import Random
from typing import Dict, List, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -29,7 +29,7 @@ class PrimeFactorizationDataset(ProceduralDataset):
def __init__(self, config: PrimeFactorizationConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _prime_factors(self, n: int) -> List[int]:
def _prime_factors(self, n: int) -> list[int]:
"""Compute prime factors of a number"""
factors = []
d = 2
@ -44,11 +44,11 @@ class PrimeFactorizationDataset(ProceduralDataset):
break
return factors
def _normalize_answer(self, answer: str) -> List[int]:
def _normalize_answer(self, answer: str) -> list[int]:
"""Parse and sort factors from a string"""
return sorted([int(factor.strip()) for factor in answer.split("×")])
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"]
reward = 0.0
if answer is not None:

View file

@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Any, Optional
from reasoning_gym import utils
@ -102,7 +102,7 @@ class ProductsDataset(ProceduralDataset):
expression = " ".join(expression_parts)
return expression, result
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"].strip()
return utils.compute_reward(answer, oracle_answer)

View file

@ -1,7 +1,7 @@
import random
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from typing import List, Optional
from typing import Optional
import pytz
from dateutil import parser
@ -19,7 +19,7 @@ class TimeIntervalsConfig:
min_date: date = date(1900, 1, 1)
max_date: date = date(3000, 1, 1)
max_date_difference_days: int = 100
task_types: List[str] = field(
task_types: list[str] = field(
default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
)
seed: Optional[int] = None