formatting

This commit is contained in:
Andreas Koepf 2025-01-24 10:34:07 +01:00
parent 98988c8481
commit 20069b2a7d
37 changed files with 504 additions and 666 deletions

View file

@ -8,7 +8,11 @@ Arithmetic tasks for training reasoning capabilities:
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset
from .fraction_simplification import (
FractionSimplificationConfig,
FractionSimplificationDataset,
fraction_simplification_dataset,
)
from .gcd import GCDConfig, GCDDataset, gcd_dataset
from .lcm import LCMConfig, LCMDataset, lcm_dataset
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
@ -25,7 +29,7 @@ __all__ = [
"FractionSimplificationDataset",
"fraction_simplification_dataset",
"GCDConfig",
"GCDDataset",
"GCDDataset",
"gcd_dataset",
"LCMConfig",
"LCMDataset",
@ -35,5 +39,5 @@ __all__ = [
"leg_counting_dataset",
"PrimeFactorizationConfig",
"PrimeFactorizationDataset",
"prime_factorization_dataset"
"prime_factorization_dataset",
]

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass
from random import Random
from typing import Any, Literal, Optional
from ..dataset import ProceduralDataset
@ -145,7 +146,6 @@ class BasicArithmeticDataset(ProceduralDataset):
expression = " ".join(expression_parts)
return expression, result
def _format_question(self, rng: Random, expression: str) -> str:
"""Format the expression according to config style"""
if self.config.format_style == "simple":

View file

@ -1,6 +1,7 @@
import random
from dataclasses import dataclass
from typing import Optional
from ..dataset import ProceduralDataset
@ -70,7 +71,6 @@ class ChainSum(ProceduralDataset):
},
}
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
"""Generate a chain sum task

View file

@ -1,21 +1,24 @@
"""Fraction simplification task generator"""
from dataclasses import dataclass
from random import Random
from typing import Optional, Tuple, Sequence
from ..dataset import ProceduralDataset
from math import gcd
from random import Random
from typing import Optional, Sequence, Tuple
from ..dataset import ProceduralDataset
@dataclass
class FractionSimplificationConfig:
"""Configuration for fraction simplification task generation"""
min_value: int = 1 # Minimum value for numerator/denominator
max_value: int = 1000 # Maximum value for numerator/denominator
min_factor: int = 1 # Minimum multiplication factor
max_factor: int = 100 # Maximum multiplication factor
min_value: int = 1 # Minimum value for numerator/denominator
max_value: int = 1000 # Maximum value for numerator/denominator
min_factor: int = 1 # Minimum multiplication factor
max_factor: int = 100 # Maximum multiplication factor
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
size: int = 500 # Virtual dataset size
def validate(self):
"""Validate configuration parameters"""
@ -23,7 +26,7 @@ class FractionSimplificationConfig:
assert self.max_value > self.min_value, "max_value must be > min_value"
assert self.min_factor >= 1, "min_factor must be at least 1"
assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"
# Validate styles
valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
for style in self.styles:
@ -46,37 +49,38 @@ class FractionSimplificationDataset(ProceduralDataset):
# Generate the simplified fraction first
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
# Make sure they're coprime by dividing by their GCD
common = gcd(simplified_num, simplified_den)
simplified_num //= common
simplified_den //= common
# Check if simplified fraction is within bounds
if (self.config.min_value <= simplified_num <= self.config.max_value and
self.config.min_value <= simplified_den <= self.config.max_value):
if (
self.config.min_value <= simplified_num <= self.config.max_value
and self.config.min_value <= simplified_den <= self.config.max_value
):
# Ensure numerator is smaller than denominator
if simplified_num > simplified_den:
simplified_num, simplified_den = simplified_den, simplified_num
# Multiply both by a random factor to create the unsimplified version
factor = rng.randint(self.config.min_factor, self.config.max_factor)
numerator = simplified_num * factor
denominator = simplified_den * factor
return numerator, denominator, simplified_num, simplified_den
# If we failed to find a good fraction after max attempts,
# generate one that's guaranteed to be within bounds
simplified_num = rng.randint(self.config.min_value, self.config.max_value)
simplified_den = rng.randint(self.config.min_value, self.config.max_value)
# Ensure numerator is smaller than denominator
if simplified_num > simplified_den:
simplified_num, simplified_den = simplified_den, simplified_num
factor = rng.randint(self.config.min_factor, self.config.max_factor)
return (simplified_num * factor, simplified_den * factor,
simplified_num, simplified_den)
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
"""Format a fraction in various styles"""
@ -95,16 +99,16 @@ class FractionSimplificationDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict:
"""Generate a single fraction simplification task"""
rng = Random(self.seed + idx)
num, den, simple_num, simple_den = self._generate_fraction(rng)
# Choose a random style from configured styles
style = self.config.styles[rng.randint(0, len(self.config.styles)-1)]
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
# Format both question and answer in the same style
question_fraction = self._format_fraction(num, den, style)
answer_fraction = self._format_fraction(simple_num, simple_den, style)
return {
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
"answer": answer_fraction,
@ -114,8 +118,8 @@ class FractionSimplificationDataset(ProceduralDataset):
"simplified_numerator": simple_num,
"simplified_denominator": simple_den,
"reduction_factor": num // simple_num, # Will be same as den // simple_den
"style": style
}
"style": style,
},
}

View file

@ -1,21 +1,24 @@
"""Greatest Common Divisor (GCD) task generator"""
from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random
from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset
from math import gcd
from functools import reduce
@dataclass
class GCDConfig:
"""Configuration for GCD task generation"""
min_numbers: int = 2 # Minimum numbers to find GCD of
max_numbers: int = 2 # Maximum numbers to find GCD of
min_value: int = 1 # Minimum value for each number
max_value: int = 1000 # Maximum value for each number
min_numbers: int = 2 # Minimum numbers to find GCD of
max_numbers: int = 2 # Maximum numbers to find GCD of
min_value: int = 1 # Minimum value for each number
max_value: int = 1000 # Maximum value for each number
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
size: int = 500 # Virtual dataset size
def validate(self):
"""Validate configuration parameters"""
@ -38,33 +41,28 @@ class GCDDataset(ProceduralDataset):
Will try up to 3 times to find numbers with GCD > 1."""
for _ in range(3): # Try up to 3 times to get GCD > 1
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value)
for _ in range(num_count)]
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
result = reduce(gcd, numbers)
if result > 1:
return numbers, result
# If we failed to find GCD > 1 after 3 tries, generate one final set
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value)
for _ in range(num_count)]
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
result = reduce(gcd, numbers)
return numbers, result
def __getitem__(self, idx: int) -> dict:
    """Generate a single GCD task.

    Args:
        idx: Index of the item. It is added to ``self.seed`` to seed a
            dedicated RNG, so a given index always produces the same task.

    Returns:
        A dict with the keys ``"question"`` (prompt text listing the numbers),
        ``"answer"`` (the GCD rendered as a string) and ``"metadata"`` (the raw
        ``numbers`` list and the integer ``result``).
    """
    # NOTE(review): ``self.seed`` is not defined in the visible code; presumably
    # it is provided by the ProceduralDataset base class - confirm.
    rng = Random(self.seed + idx)

    numbers, result = self._generate_numbers(rng)
    numbers_str = ", ".join(map(str, numbers))

    return {
        "question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
        "answer": str(result),
        "metadata": {"numbers": numbers, "result": result},
    }

View file

@ -1,21 +1,24 @@
"""Least Common Multiple (LCM) task generator"""
from dataclasses import dataclass
from functools import reduce
from math import lcm
from random import Random
from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset
from math import lcm
from functools import reduce
@dataclass
class LCMConfig:
"""Configuration for LCM task generation"""
min_numbers: int = 2 # Minimum numbers to find LCM of
max_numbers: int = 2 # Maximum numbers to find LCM of
min_value: int = 1 # Minimum value for each number
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
min_numbers: int = 2 # Minimum numbers to find LCM of
max_numbers: int = 2 # Maximum numbers to find LCM of
min_value: int = 1 # Minimum value for each number
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
size: int = 500 # Virtual dataset size
def validate(self):
"""Validate configuration parameters"""
@ -36,38 +39,34 @@ class LCMDataset(ProceduralDataset):
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
    """Generate a list of random integers together with their LCM.

    The amount of numbers is drawn from ``[config.min_numbers, config.max_numbers]``
    and every number from ``[config.min_value, config.max_value]``. Up to three
    candidate sets are tried in search of one whose LCM is strictly smaller than
    the product of its members; if none qualifies, a fourth set is drawn and
    returned unconditionally.

    Args:
        rng: Random number generator used for every draw.

    Returns:
        A ``(numbers, result)`` tuple where ``result`` is the LCM of ``numbers``.
    """
    # Function-local import keeps this block self-contained; the surrounding
    # module only imports ``lcm`` from ``math``.
    from math import prod

    def draw() -> Tuple[List[int], int]:
        """Draw one candidate set (count first, then the values) and its LCM."""
        num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
        candidate = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
        return candidate, reduce(lcm, candidate)

    for _ in range(3):  # Try up to 3 times to get LCM < product
        numbers, result = draw()
        if result < prod(numbers):
            return numbers, result

    # If we failed to find LCM < product after 3 tries, generate one final set
    return draw()
def __getitem__(self, idx: int) -> dict:
    """Generate a single LCM task.

    Args:
        idx: Index of the item. It is added to ``self.seed`` to seed a
            dedicated RNG, so a given index always produces the same task.

    Returns:
        A dict with the keys ``"question"`` (prompt text listing the numbers),
        ``"answer"`` (the LCM rendered as a string) and ``"metadata"`` (the raw
        ``numbers`` list and the integer ``result``).
    """
    # NOTE(review): ``self.seed`` is not defined in the visible code; presumably
    # it is provided by the ProceduralDataset base class - confirm.
    rng = Random(self.seed + idx)

    numbers, result = self._generate_numbers(rng)
    numbers_str = ", ".join(map(str, numbers))

    return {
        "question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
        "answer": str(result),
        "metadata": {"numbers": numbers, "result": result},
    }

View file

@ -1,7 +1,9 @@
"""Leg counting task generator"""
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from ..dataset import ProceduralDataset
ANIMALS = {
@ -52,14 +54,16 @@ ANIMALS = {
"woodlouse": 14,
}
@dataclass
class LegCountingConfig:
"""Configuration for leg counting task generation"""
min_animals: int = 2 # Minimum number of animals in problem
max_animals: int = 5 # Maximum number of animals
max_instances: int = 3 # Maximum instances of each animal
min_animals: int = 2 # Minimum number of animals in problem
max_animals: int = 5 # Maximum number of animals
max_instances: int = 3 # Maximum instances of each animal
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
size: int = 500 # Virtual dataset size
def validate(self):
"""Validate configuration parameters"""
@ -80,39 +84,36 @@ class LegCountingDataset(ProceduralDataset):
"""Generate a random set of animals and their counts"""
num_types = rng.randint(self.config.min_animals, self.config.max_animals)
animals = {}
# Select random animals
selected_animals = rng.sample(list(ANIMALS.keys()), num_types)
for animal in selected_animals:
count = rng.randint(1, self.config.max_instances)
animals[animal] = count
return animals
def __getitem__(self, idx: int) -> dict:
    """Generate a single leg counting task.

    Args:
        idx: Index of the item. It is added to ``self.seed`` to seed a
            dedicated RNG, so a given index always produces the same task.

    Returns:
        A dict with the keys ``"question"`` (prompt text listing each animal
        with its count), ``"answer"`` (total number of legs as a string) and
        ``"metadata"`` (the ``animals`` name-to-count mapping and the integer
        ``total_legs``).
    """
    # NOTE(review): ``self.seed`` is not defined in the visible code; presumably
    # it is provided by the ProceduralDataset base class - confirm.
    rng = Random(self.seed + idx)

    # Generate random animals and their counts
    animals = self._generate_animals(rng)

    # Calculate total legs: ANIMALS maps each animal name to its leg count
    total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())

    # Format animal counts for question; plural is formed by simply appending
    # "s" whenever the count is greater than one.
    animal_list = [f"{count} {animal}{'s' if count > 1 else ''}" for animal, count in animals.items()]

    question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"

    return {
        "question": question,
        "answer": str(total_legs),
        "metadata": {"animals": animals, "total_legs": total_legs},
    }

View file

@ -1,16 +1,20 @@
"""Prime factorization task generator"""
from dataclasses import dataclass
from random import Random
from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset
@dataclass
class PrimeFactorizationConfig:
"""Configuration for prime factorization task generation"""
min_value: int = 2 # Minimum number to factorize
max_value: int = 1000 # Maximum number to factorize
min_value: int = 2 # Minimum number to factorize
max_value: int = 1000 # Maximum number to factorize
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
size: int = 500 # Virtual dataset size
def validate(self):
"""Validate configuration parameters"""
@ -44,24 +48,23 @@ class PrimeFactorizationDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict:
    """Generate a single prime factorization task.

    Args:
        idx: Index of the item. It is added to ``self.seed`` to seed a
            dedicated RNG, so a given index always produces the same task.

    Returns:
        A dict with the keys ``"question"`` (prompt text including an example
        of the expected format), ``"answer"`` (the factors joined by `` × ``)
        and ``"metadata"`` (the integer ``number`` and its ``factors``).
    """
    # NOTE(review): ``self.seed`` is not defined in the visible code; presumably
    # it is provided by the ProceduralDataset base class - confirm.
    rng = Random(self.seed + idx)

    # Generate random number to factorize
    number = rng.randint(self.config.min_value, self.config.max_value)

    # Calculate prime factors
    factors = self._prime_factors(number)

    # Format answer as multiplication of prime factors
    answer = " × ".join(map(str, factors))

    return {
        "question": (
            f"Find the prime factorization of {number}. Write the factors separated by × "
            f"(Example: for 12 the answer would be: 2 × 2 × 3)"
        ),
        "answer": answer,
        "metadata": {"number": number, "factors": factors},
    }