mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
formatting
This commit is contained in:
parent
98988c8481
commit
20069b2a7d
37 changed files with 504 additions and 666 deletions
|
|
@ -8,7 +8,11 @@ Arithmetic tasks for training reasoning capabilities:
|
|||
|
||||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
|
||||
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
|
||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset
|
||||
from .fraction_simplification import (
|
||||
FractionSimplificationConfig,
|
||||
FractionSimplificationDataset,
|
||||
fraction_simplification_dataset,
|
||||
)
|
||||
from .gcd import GCDConfig, GCDDataset, gcd_dataset
|
||||
from .lcm import LCMConfig, LCMDataset, lcm_dataset
|
||||
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
|
||||
|
|
@ -25,7 +29,7 @@ __all__ = [
|
|||
"FractionSimplificationDataset",
|
||||
"fraction_simplification_dataset",
|
||||
"GCDConfig",
|
||||
"GCDDataset",
|
||||
"GCDDataset",
|
||||
"gcd_dataset",
|
||||
"LCMConfig",
|
||||
"LCMDataset",
|
||||
|
|
@ -35,5 +39,5 @@ __all__ = [
|
|||
"leg_counting_dataset",
|
||||
"PrimeFactorizationConfig",
|
||||
"PrimeFactorizationDataset",
|
||||
"prime_factorization_dataset"
|
||||
"prime_factorization_dataset",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
|
||||
|
||||
|
|
@ -145,7 +146,6 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
expression = " ".join(expression_parts)
|
||||
return expression, result
|
||||
|
||||
|
||||
def _format_question(self, rng: Random, expression: str) -> str:
|
||||
"""Format the expression according to config style"""
|
||||
if self.config.format_style == "simple":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
|
||||
|
||||
|
|
@ -70,7 +71,6 @@ class ChainSum(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
|
||||
"""Generate a chain sum task
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,24 @@
|
|||
"""Fraction simplification task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional, Tuple, Sequence
|
||||
from ..dataset import ProceduralDataset
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class FractionSimplificationConfig:
|
||||
"""Configuration for fraction simplification task generation"""
|
||||
min_value: int = 1 # Minimum value for numerator/denominator
|
||||
max_value: int = 1000 # Maximum value for numerator/denominator
|
||||
min_factor: int = 1 # Minimum multiplication factor
|
||||
max_factor: int = 100 # Maximum multiplication factor
|
||||
|
||||
min_value: int = 1 # Minimum value for numerator/denominator
|
||||
max_value: int = 1000 # Maximum value for numerator/denominator
|
||||
min_factor: int = 1 # Minimum multiplication factor
|
||||
max_factor: int = 100 # Maximum multiplication factor
|
||||
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
|
|
@ -23,7 +26,7 @@ class FractionSimplificationConfig:
|
|||
assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||
assert self.min_factor >= 1, "min_factor must be at least 1"
|
||||
assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"
|
||||
|
||||
|
||||
# Validate styles
|
||||
valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
|
||||
for style in self.styles:
|
||||
|
|
@ -46,37 +49,38 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
# Generate the simplified fraction first
|
||||
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
|
||||
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
|
||||
# Make sure they're coprime by dividing by their GCD
|
||||
common = gcd(simplified_num, simplified_den)
|
||||
simplified_num //= common
|
||||
simplified_den //= common
|
||||
|
||||
|
||||
# Check if simplified fraction is within bounds
|
||||
if (self.config.min_value <= simplified_num <= self.config.max_value and
|
||||
self.config.min_value <= simplified_den <= self.config.max_value):
|
||||
if (
|
||||
self.config.min_value <= simplified_num <= self.config.max_value
|
||||
and self.config.min_value <= simplified_den <= self.config.max_value
|
||||
):
|
||||
# Ensure numerator is smaller than denominator
|
||||
if simplified_num > simplified_den:
|
||||
simplified_num, simplified_den = simplified_den, simplified_num
|
||||
|
||||
|
||||
# Multiply both by a random factor to create the unsimplified version
|
||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||
numerator = simplified_num * factor
|
||||
denominator = simplified_den * factor
|
||||
return numerator, denominator, simplified_num, simplified_den
|
||||
|
||||
|
||||
# If we failed to find a good fraction after max attempts,
|
||||
# generate one that's guaranteed to be within bounds
|
||||
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
|
||||
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
|
||||
# Ensure numerator is smaller than denominator
|
||||
if simplified_num > simplified_den:
|
||||
simplified_num, simplified_den = simplified_den, simplified_num
|
||||
|
||||
|
||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||
return (simplified_num * factor, simplified_den * factor,
|
||||
simplified_num, simplified_den)
|
||||
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
|
||||
|
||||
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
||||
"""Format a fraction in various styles"""
|
||||
|
|
@ -95,16 +99,16 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single fraction simplification task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
|
||||
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
||||
|
||||
|
||||
# Choose a random style from configured styles
|
||||
style = self.config.styles[rng.randint(0, len(self.config.styles)-1)]
|
||||
|
||||
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
|
||||
|
||||
# Format both question and answer in the same style
|
||||
question_fraction = self._format_fraction(num, den, style)
|
||||
answer_fraction = self._format_fraction(simple_num, simple_den, style)
|
||||
|
||||
|
||||
return {
|
||||
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
|
||||
"answer": answer_fraction,
|
||||
|
|
@ -114,8 +118,8 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
"simplified_numerator": simple_num,
|
||||
"simplified_denominator": simple_den,
|
||||
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
||||
"style": style
|
||||
}
|
||||
"style": style,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,24 @@
|
|||
"""Greatest Common Divisor (GCD) task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from math import gcd
|
||||
from functools import reduce
|
||||
|
||||
|
||||
@dataclass
|
||||
class GCDConfig:
|
||||
"""Configuration for GCD task generation"""
|
||||
min_numbers: int = 2 # Minimum numbers to find GCD of
|
||||
max_numbers: int = 2 # Maximum numbers to find GCD of
|
||||
min_value: int = 1 # Minimum value for each number
|
||||
max_value: int = 1000 # Maximum value for each number
|
||||
|
||||
min_numbers: int = 2 # Minimum numbers to find GCD of
|
||||
max_numbers: int = 2 # Maximum numbers to find GCD of
|
||||
min_value: int = 1 # Minimum value for each number
|
||||
max_value: int = 1000 # Maximum value for each number
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
|
|
@ -38,33 +41,28 @@ class GCDDataset(ProceduralDataset):
|
|||
Will try up to 3 times to find numbers with GCD > 1."""
|
||||
for _ in range(3): # Try up to 3 times to get GCD > 1
|
||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
||||
for _ in range(num_count)]
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||
result = reduce(gcd, numbers)
|
||||
if result > 1:
|
||||
return numbers, result
|
||||
|
||||
|
||||
# If we failed to find GCD > 1 after 3 tries, generate one final set
|
||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
||||
for _ in range(num_count)]
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||
result = reduce(gcd, numbers)
|
||||
return numbers, result
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single GCD task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
|
||||
numbers, result = self._generate_numbers(rng)
|
||||
numbers_str = ", ".join(str(n) for n in numbers)
|
||||
|
||||
|
||||
return {
|
||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"numbers": numbers,
|
||||
"result": result
|
||||
}
|
||||
"metadata": {"numbers": numbers, "result": result},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,24 @@
|
|||
"""Least Common Multiple (LCM) task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from math import lcm
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from math import lcm
|
||||
from functools import reduce
|
||||
|
||||
|
||||
@dataclass
|
||||
class LCMConfig:
|
||||
"""Configuration for LCM task generation"""
|
||||
min_numbers: int = 2 # Minimum numbers to find LCM of
|
||||
max_numbers: int = 2 # Maximum numbers to find LCM of
|
||||
min_value: int = 1 # Minimum value for each number
|
||||
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
|
||||
|
||||
min_numbers: int = 2 # Minimum numbers to find LCM of
|
||||
max_numbers: int = 2 # Maximum numbers to find LCM of
|
||||
min_value: int = 1 # Minimum value for each number
|
||||
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
|
|
@ -36,38 +39,34 @@ class LCMDataset(ProceduralDataset):
|
|||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||
"""Generate a list of random positive integers and their LCM.
|
||||
Will try up to 3 times to find numbers with LCM < product."""
|
||||
|
||||
def calculate_product(nums: List[int]) -> int:
|
||||
return reduce(lambda x, y: x * y, nums)
|
||||
|
||||
|
||||
for _ in range(3): # Try up to 3 times to get LCM < product
|
||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
||||
for _ in range(num_count)]
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||
result = reduce(lcm, numbers)
|
||||
if result < calculate_product(numbers):
|
||||
return numbers, result
|
||||
|
||||
|
||||
# If we failed to find LCM < product after 3 tries, generate one final set
|
||||
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value)
|
||||
for _ in range(num_count)]
|
||||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
|
||||
result = reduce(lcm, numbers)
|
||||
return numbers, result
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single LCM task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
|
||||
numbers, result = self._generate_numbers(rng)
|
||||
numbers_str = ", ".join(str(n) for n in numbers)
|
||||
|
||||
|
||||
return {
|
||||
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"numbers": numbers,
|
||||
"result": result
|
||||
}
|
||||
"metadata": {"numbers": numbers, "result": result},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
"""Leg counting task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
|
||||
ANIMALS = {
|
||||
|
|
@ -52,14 +54,16 @@ ANIMALS = {
|
|||
"woodlouse": 14,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegCountingConfig:
|
||||
"""Configuration for leg counting task generation"""
|
||||
min_animals: int = 2 # Minimum number of animals in problem
|
||||
max_animals: int = 5 # Maximum number of animals
|
||||
max_instances: int = 3 # Maximum instances of each animal
|
||||
|
||||
min_animals: int = 2 # Minimum number of animals in problem
|
||||
max_animals: int = 5 # Maximum number of animals
|
||||
max_instances: int = 3 # Maximum instances of each animal
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
|
|
@ -80,39 +84,36 @@ class LegCountingDataset(ProceduralDataset):
|
|||
"""Generate a random set of animals and their counts"""
|
||||
num_types = rng.randint(self.config.min_animals, self.config.max_animals)
|
||||
animals = {}
|
||||
|
||||
|
||||
# Select random animals
|
||||
selected_animals = rng.sample(list(ANIMALS.keys()), num_types)
|
||||
for animal in selected_animals:
|
||||
count = rng.randint(1, self.config.max_instances)
|
||||
animals[animal] = count
|
||||
|
||||
|
||||
return animals
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single leg counting task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
|
||||
# Generate random animals and their counts
|
||||
animals = self._generate_animals(rng)
|
||||
|
||||
|
||||
# Calculate total legs
|
||||
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
|
||||
|
||||
|
||||
# Format animal counts for question
|
||||
animal_list = []
|
||||
for animal, count in animals.items():
|
||||
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
|
||||
|
||||
|
||||
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
|
||||
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": str(total_legs),
|
||||
"metadata": {
|
||||
"animals": animals,
|
||||
"total_legs": total_legs
|
||||
}
|
||||
"metadata": {"animals": animals, "total_legs": total_legs},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,20 @@
|
|||
"""Prime factorization task generator"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrimeFactorizationConfig:
|
||||
"""Configuration for prime factorization task generation"""
|
||||
min_value: int = 2 # Minimum number to factorize
|
||||
max_value: int = 1000 # Maximum number to factorize
|
||||
|
||||
min_value: int = 2 # Minimum number to factorize
|
||||
max_value: int = 1000 # Maximum number to factorize
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
|
|
@ -44,24 +48,23 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
|||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single prime factorization task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
|
||||
# Generate random number to factorize
|
||||
number = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
|
||||
# Calculate prime factors
|
||||
factors = self._prime_factors(number)
|
||||
|
||||
|
||||
# Format answer as multiplication of prime factors
|
||||
answer = " × ".join(map(str, factors))
|
||||
|
||||
|
||||
return {
|
||||
"question": (f"Find the prime factorization of {number}. Write the factors separated by × "
|
||||
f"(Example: for 12 the answer would be: 2 × 2 × 3)"),
|
||||
"question": (
|
||||
f"Find the prime factorization of {number}. Write the factors separated by × "
|
||||
f"(Example: for 12 the answer would be: 2 × 2 × 3)"
|
||||
),
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"number": number,
|
||||
"factors": factors
|
||||
}
|
||||
"metadata": {"number": number, "factors": factors},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue