formatting

This commit is contained in:
Andreas Koepf 2025-01-24 10:34:07 +01:00
parent 98988c8481
commit 20069b2a7d
37 changed files with 504 additions and 666 deletions

View file

@ -2,12 +2,7 @@
Reasoning Gym - A library of procedural dataset generators for training reasoning models Reasoning Gym - A library of procedural dataset generators for training reasoning models
""" """
from . import arithmetic from . import algorithmic, arithmetic, cognition, data, games, logic
from . import algorithmic
from . import cognition
from . import data
from . import games
from . import logic
__version__ = "0.1.0" __version__ = "0.1.0"
__all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"] __all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"]

View file

@ -8,6 +8,7 @@ Algorithmic tasks for training reasoning capabilities:
from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset
from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset
from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
@ -31,5 +32,5 @@ __all__ = [
"number_sorting_dataset", "number_sorting_dataset",
"WordReversalConfig", "WordReversalConfig",
"WordReversalDataset", "WordReversalDataset",
"word_reversal_dataset" "word_reversal_dataset",
] ]

View file

@ -1,17 +1,20 @@
"""Base conversion task generator""" """Base conversion task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional, Tuple from typing import Optional, Tuple
@dataclass @dataclass
class BaseConversionConfig: class BaseConversionConfig:
"""Configuration for base conversion task generation""" """Configuration for base conversion task generation"""
min_base: int = 2 # Minimum base (2=binary)
max_base: int = 16 # Maximum base (16=hex) min_base: int = 2 # Minimum base (2=binary)
min_value: int = 0 # Minimum decimal value to convert max_base: int = 16 # Maximum base (16=hex)
max_value: int = 1000 # Maximum decimal value to convert min_value: int = 0 # Minimum decimal value to convert
max_value: int = 1000 # Maximum decimal value to convert
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -71,14 +74,14 @@ class BaseConversionDataset:
value, source_base, target_base = self._generate_conversion(rng) value, source_base, target_base = self._generate_conversion(rng)
# Convert decimal to source base representation # Convert decimal to source base representation
source_repr = format(value, f'x' if source_base == 16 else f'b' if source_base == 2 else '').strip() source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip()
if source_base not in (2, 16): if source_base not in (2, 16):
source_repr = format(value, f'{source_base}x').lower().strip() source_repr = format(value, f"{source_base}x").lower().strip()
# Convert decimal to target base for answer # Convert decimal to target base for answer
target_repr = format(value, f'x' if target_base == 16 else f'b' if target_base == 2 else '').strip() target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip()
if target_base not in (2, 16): if target_base not in (2, 16):
target_repr = format(value, f'{target_base}x').lower().strip() target_repr = format(value, f"{target_base}x").lower().strip()
source_name = self._format_base_name(source_base) source_name = self._format_base_name(source_base)
target_name = self._format_base_name(target_base) target_name = self._format_base_name(target_base)
@ -94,8 +97,8 @@ class BaseConversionDataset:
"source_base": source_base, "source_base": source_base,
"target_base": target_base, "target_base": target_base,
"source_repr": source_repr, "source_repr": source_repr,
"target_repr": target_repr "target_repr": target_repr,
} },
} }

View file

@ -1,18 +1,21 @@
"""Letter counting task generator""" """Letter counting task generator"""
from dataclasses import dataclass
import re import re
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional from typing import List, Optional
from reasoning_gym.data import read_data_file from reasoning_gym.data import read_data_file
@dataclass @dataclass
class LetterCountingConfig: class LetterCountingConfig:
"""Configuration for letter counting task generation""" """Configuration for letter counting task generation"""
min_words: int = 5 # Minimum words in span
max_words: int = 15 # Maximum words in span min_words: int = 5 # Minimum words in span
max_words: int = 15 # Maximum words in span
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -31,7 +34,7 @@ class LetterCountingDataset:
# Load and preprocess text # Load and preprocess text
text = read_data_file("in_the_year_2889.txt") text = read_data_file("in_the_year_2889.txt")
# Extract words and clean them to contain only alphanumeric characters # Extract words and clean them to contain only alphanumeric characters
self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()] self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]
def __len__(self) -> int: def __len__(self) -> int:
return self.config.size return self.config.size
@ -54,12 +57,12 @@ class LetterCountingDataset:
# Select random span of words # Select random span of words
span_length = rng.randint(self.config.min_words, self.config.max_words) span_length = rng.randint(self.config.min_words, self.config.max_words)
start_idx = rng.randint(0, len(self.words) - span_length) start_idx = rng.randint(0, len(self.words) - span_length)
span = self.words[start_idx:start_idx + span_length] span = self.words[start_idx : start_idx + span_length]
# Get all unique letters from span # Get all unique letters from span
letters = set(''.join(span).lower()) letters = set("".join(span).lower())
if not letters: if not letters:
letters = {'a'} # Fallback if span has no letters letters = {"a"} # Fallback if span has no letters
# Select random letter that appears in the span # Select random letter that appears in the span
target_letter = rng.choice(list(letters)) target_letter = rng.choice(list(letters))
@ -70,11 +73,7 @@ class LetterCountingDataset:
return { return {
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?', "question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
"answer": str(count), "answer": str(count),
"metadata": { "metadata": {"span_length": span_length, "target_letter": target_letter, "span": span},
"span_length": span_length,
"target_letter": target_letter,
"span": span
}
} }

View file

@ -1,20 +1,23 @@
"""Number filtering task generator""" """Number filtering task generator"""
from dataclasses import dataclass
import random import random
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@dataclass @dataclass
class NumberFilteringConfig: class NumberFilteringConfig:
"""Configuration for number filtering task generation""" """Configuration for number filtering task generation"""
min_numbers: int = 3 # Minimum numbers in list
max_numbers: int = 10 # Maximum numbers in list min_numbers: int = 3 # Minimum numbers in list
min_decimals: int = 0 # Minimum decimal places max_numbers: int = 10 # Maximum numbers in list
max_decimals: int = 4 # Maximum decimal places min_decimals: int = 0 # Minimum decimal places
min_value: float = -100.0 # Minimum number value max_decimals: int = 4 # Maximum decimal places
max_value: float = 100.0 # Maximum number value min_value: float = -100.0 # Minimum number value
max_value: float = 100.0 # Maximum number value
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -96,15 +99,17 @@ class NumberFilteringDataset:
result_strs = [str_numbers[numbers.index(n)] for n in result] result_strs = [str_numbers[numbers.index(n)] for n in result]
return { return {
"question": (f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} " "question": (
f"in this list: {str_numbers}"), f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
f"in this list: {str_numbers}"
),
"answer": str(result_strs) if result_strs else "[]", "answer": str(result_strs) if result_strs else "[]",
"metadata": { "metadata": {
"original_numbers": str_numbers, "original_numbers": str_numbers,
"filter_value": filter_str, "filter_value": filter_str,
"operation": f"{keep_remove}_{larger_smaller}", "operation": f"{keep_remove}_{larger_smaller}",
"result": result_strs "result": result_strs,
} },
} }

View file

@ -1,20 +1,23 @@
"""Number sorting task generator""" """Number sorting task generator"""
from dataclasses import dataclass
import random import random
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@dataclass @dataclass
class NumberSortingConfig: class NumberSortingConfig:
"""Configuration for number sorting task generation""" """Configuration for number sorting task generation"""
min_numbers: int = 3 # Minimum numbers to sort
max_numbers: int = 10 # Maximum numbers to sort min_numbers: int = 3 # Minimum numbers to sort
min_decimals: int = 0 # Minimum decimal places max_numbers: int = 10 # Maximum numbers to sort
max_decimals: int = 2 # Maximum decimal places min_decimals: int = 0 # Minimum decimal places
max_decimals: int = 2 # Maximum decimal places
min_value: float = -100.0 # Minimum value min_value: float = -100.0 # Minimum value
max_value: float = 100.0 # Maximum value max_value: float = 100.0 # Maximum value
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -82,7 +85,7 @@ class NumberSortingDataset:
desc_numbers = sorted(numbers, reverse=True) desc_numbers = sorted(numbers, reverse=True)
# Format answers as string lists # Format answers as string lists
decimals = len(number_strs[0].split('.')[-1]) if '.' in number_strs[0] else 0 decimals = len(number_strs[0].split(".")[-1]) if "." in number_strs[0] else 0
asc_answer = [self._format_number(n, decimals) for n in asc_numbers] asc_answer = [self._format_number(n, decimals) for n in asc_numbers]
desc_answer = [self._format_number(n, decimals) for n in desc_numbers] desc_answer = [self._format_number(n, decimals) for n in desc_numbers]
@ -94,11 +97,7 @@ class NumberSortingDataset:
return { return {
"question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}", "question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
"answer": str(answer), "answer": str(answer),
"metadata": { "metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
"original_numbers": number_strs,
"direction": direction,
"sorted_numbers": answer
}
} }

View file

@ -1,18 +1,21 @@
"""Word reversal task generator""" """Word reversal task generator"""
from dataclasses import dataclass
import re import re
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional from typing import List, Optional
from reasoning_gym.data import read_data_file from reasoning_gym.data import read_data_file
@dataclass @dataclass
class WordReversalConfig: class WordReversalConfig:
"""Configuration for word reversal task generation""" """Configuration for word reversal task generation"""
min_words: int = 3 # Minimum words in list
max_words: int = 8 # Maximum words in list min_words: int = 3 # Minimum words in list
max_words: int = 8 # Maximum words in list
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -31,7 +34,7 @@ class WordReversalDataset:
# Load and preprocess text # Load and preprocess text
text = read_data_file("in_the_year_2889.txt") text = read_data_file("in_the_year_2889.txt")
# Extract words and clean them to contain only alphanumeric characters # Extract words and clean them to contain only alphanumeric characters
self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()] self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]
def __len__(self) -> int: def __len__(self) -> int:
return self.config.size return self.config.size
@ -63,10 +66,7 @@ class WordReversalDataset:
return { return {
"question": f"Reverse this list of words: {question}", "question": f"Reverse this list of words: {question}",
"answer": answer, "answer": answer,
"metadata": { "metadata": {"num_words": num_words, "words": words},
"num_words": num_words,
"words": words
}
} }

View file

@ -8,7 +8,11 @@ Arithmetic tasks for training reasoning capabilities:
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset from .fraction_simplification import (
FractionSimplificationConfig,
FractionSimplificationDataset,
fraction_simplification_dataset,
)
from .gcd import GCDConfig, GCDDataset, gcd_dataset from .gcd import GCDConfig, GCDDataset, gcd_dataset
from .lcm import LCMConfig, LCMDataset, lcm_dataset from .lcm import LCMConfig, LCMDataset, lcm_dataset
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
@ -35,5 +39,5 @@ __all__ = [
"leg_counting_dataset", "leg_counting_dataset",
"PrimeFactorizationConfig", "PrimeFactorizationConfig",
"PrimeFactorizationDataset", "PrimeFactorizationDataset",
"prime_factorization_dataset" "prime_factorization_dataset",
] ]

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
@ -145,7 +146,6 @@ class BasicArithmeticDataset(ProceduralDataset):
expression = " ".join(expression_parts) expression = " ".join(expression_parts)
return expression, result return expression, result
def _format_question(self, rng: Random, expression: str) -> str: def _format_question(self, rng: Random, expression: str) -> str:
"""Format the expression according to config style""" """Format the expression according to config style"""
if self.config.format_style == "simple": if self.config.format_style == "simple":

View file

@ -1,6 +1,7 @@
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
@ -70,7 +71,6 @@ class ChainSum(ProceduralDataset):
}, },
} }
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]: def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
"""Generate a chain sum task """Generate a chain sum task

View file

@ -1,21 +1,24 @@
"""Fraction simplification task generator""" """Fraction simplification task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from random import Random
from typing import Optional, Tuple, Sequence
from ..dataset import ProceduralDataset
from math import gcd from math import gcd
from random import Random
from typing import Optional, Sequence, Tuple
from ..dataset import ProceduralDataset
@dataclass @dataclass
class FractionSimplificationConfig: class FractionSimplificationConfig:
"""Configuration for fraction simplification task generation""" """Configuration for fraction simplification task generation"""
min_value: int = 1 # Minimum value for numerator/denominator
max_value: int = 1000 # Maximum value for numerator/denominator min_value: int = 1 # Minimum value for numerator/denominator
min_factor: int = 1 # Minimum multiplication factor max_value: int = 1000 # Maximum value for numerator/denominator
max_factor: int = 100 # Maximum multiplication factor min_factor: int = 1 # Minimum multiplication factor
max_factor: int = 100 # Maximum multiplication factor
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -53,8 +56,10 @@ class FractionSimplificationDataset(ProceduralDataset):
simplified_den //= common simplified_den //= common
# Check if simplified fraction is within bounds # Check if simplified fraction is within bounds
if (self.config.min_value <= simplified_num <= self.config.max_value and if (
self.config.min_value <= simplified_den <= self.config.max_value): self.config.min_value <= simplified_num <= self.config.max_value
and self.config.min_value <= simplified_den <= self.config.max_value
):
# Ensure numerator is smaller than denominator # Ensure numerator is smaller than denominator
if simplified_num > simplified_den: if simplified_num > simplified_den:
simplified_num, simplified_den = simplified_den, simplified_num simplified_num, simplified_den = simplified_den, simplified_num
@ -75,8 +80,7 @@ class FractionSimplificationDataset(ProceduralDataset):
simplified_num, simplified_den = simplified_den, simplified_num simplified_num, simplified_den = simplified_den, simplified_num
factor = rng.randint(self.config.min_factor, self.config.max_factor) factor = rng.randint(self.config.min_factor, self.config.max_factor)
return (simplified_num * factor, simplified_den * factor, return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
simplified_num, simplified_den)
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str: def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
"""Format a fraction in various styles""" """Format a fraction in various styles"""
@ -99,7 +103,7 @@ class FractionSimplificationDataset(ProceduralDataset):
num, den, simple_num, simple_den = self._generate_fraction(rng) num, den, simple_num, simple_den = self._generate_fraction(rng)
# Choose a random style from configured styles # Choose a random style from configured styles
style = self.config.styles[rng.randint(0, len(self.config.styles)-1)] style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
# Format both question and answer in the same style # Format both question and answer in the same style
question_fraction = self._format_fraction(num, den, style) question_fraction = self._format_fraction(num, den, style)
@ -114,8 +118,8 @@ class FractionSimplificationDataset(ProceduralDataset):
"simplified_numerator": simple_num, "simplified_numerator": simple_num,
"simplified_denominator": simple_den, "simplified_denominator": simple_den,
"reduction_factor": num // simple_num, # Will be same as den // simple_den "reduction_factor": num // simple_num, # Will be same as den // simple_den
"style": style "style": style,
} },
} }

View file

@ -1,21 +1,24 @@
"""Greatest Common Divisor (GCD) task generator""" """Greatest Common Divisor (GCD) task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
from math import gcd
from functools import reduce
@dataclass @dataclass
class GCDConfig: class GCDConfig:
"""Configuration for GCD task generation""" """Configuration for GCD task generation"""
min_numbers: int = 2 # Minimum numbers to find GCD of
max_numbers: int = 2 # Maximum numbers to find GCD of min_numbers: int = 2 # Minimum numbers to find GCD of
min_value: int = 1 # Minimum value for each number max_numbers: int = 2 # Maximum numbers to find GCD of
max_value: int = 1000 # Maximum value for each number min_value: int = 1 # Minimum value for each number
max_value: int = 1000 # Maximum value for each number
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -38,16 +41,14 @@ class GCDDataset(ProceduralDataset):
Will try up to 3 times to find numbers with GCD > 1.""" Will try up to 3 times to find numbers with GCD > 1."""
for _ in range(3): # Try up to 3 times to get GCD > 1 for _ in range(3): # Try up to 3 times to get GCD > 1
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
for _ in range(num_count)]
result = reduce(gcd, numbers) result = reduce(gcd, numbers)
if result > 1: if result > 1:
return numbers, result return numbers, result
# If we failed to find GCD > 1 after 3 tries, generate one final set # If we failed to find GCD > 1 after 3 tries, generate one final set
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
for _ in range(num_count)]
result = reduce(gcd, numbers) result = reduce(gcd, numbers)
return numbers, result return numbers, result
@ -61,10 +62,7 @@ class GCDDataset(ProceduralDataset):
return { return {
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}", "question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {"numbers": numbers, "result": result},
"numbers": numbers,
"result": result
}
} }

View file

@ -1,21 +1,24 @@
"""Least Common Multiple (LCM) task generator""" """Least Common Multiple (LCM) task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce
from math import lcm
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
from math import lcm
from functools import reduce
@dataclass @dataclass
class LCMConfig: class LCMConfig:
"""Configuration for LCM task generation""" """Configuration for LCM task generation"""
min_numbers: int = 2 # Minimum numbers to find LCM of
max_numbers: int = 2 # Maximum numbers to find LCM of min_numbers: int = 2 # Minimum numbers to find LCM of
min_value: int = 1 # Minimum value for each number max_numbers: int = 2 # Maximum numbers to find LCM of
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast) min_value: int = 1 # Minimum value for each number
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -36,21 +39,20 @@ class LCMDataset(ProceduralDataset):
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
"""Generate a list of random positive integers and their LCM. """Generate a list of random positive integers and their LCM.
Will try up to 3 times to find numbers with LCM < product.""" Will try up to 3 times to find numbers with LCM < product."""
def calculate_product(nums: List[int]) -> int: def calculate_product(nums: List[int]) -> int:
return reduce(lambda x, y: x * y, nums) return reduce(lambda x, y: x * y, nums)
for _ in range(3): # Try up to 3 times to get LCM < product for _ in range(3): # Try up to 3 times to get LCM < product
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
for _ in range(num_count)]
result = reduce(lcm, numbers) result = reduce(lcm, numbers)
if result < calculate_product(numbers): if result < calculate_product(numbers):
return numbers, result return numbers, result
# If we failed to find LCM < product after 3 tries, generate one final set # If we failed to find LCM < product after 3 tries, generate one final set
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
for _ in range(num_count)]
result = reduce(lcm, numbers) result = reduce(lcm, numbers)
return numbers, result return numbers, result
@ -64,10 +66,7 @@ class LCMDataset(ProceduralDataset):
return { return {
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}", "question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {"numbers": numbers, "result": result},
"numbers": numbers,
"result": result
}
} }

View file

@ -1,7 +1,9 @@
"""Leg counting task generator""" """Leg counting task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Dict, Optional from typing import Dict, Optional
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
ANIMALS = { ANIMALS = {
@ -52,14 +54,16 @@ ANIMALS = {
"woodlouse": 14, "woodlouse": 14,
} }
@dataclass @dataclass
class LegCountingConfig: class LegCountingConfig:
"""Configuration for leg counting task generation""" """Configuration for leg counting task generation"""
min_animals: int = 2 # Minimum number of animals in problem
max_animals: int = 5 # Maximum number of animals min_animals: int = 2 # Minimum number of animals in problem
max_instances: int = 3 # Maximum instances of each animal max_animals: int = 5 # Maximum number of animals
max_instances: int = 3 # Maximum instances of each animal
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -109,10 +113,7 @@ class LegCountingDataset(ProceduralDataset):
return { return {
"question": question, "question": question,
"answer": str(total_legs), "answer": str(total_legs),
"metadata": { "metadata": {"animals": animals, "total_legs": total_legs},
"animals": animals,
"total_legs": total_legs
}
} }

View file

@ -1,16 +1,20 @@
"""Prime factorization task generator""" """Prime factorization task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
@dataclass @dataclass
class PrimeFactorizationConfig: class PrimeFactorizationConfig:
"""Configuration for prime factorization task generation""" """Configuration for prime factorization task generation"""
min_value: int = 2 # Minimum number to factorize
max_value: int = 1000 # Maximum number to factorize min_value: int = 2 # Minimum number to factorize
max_value: int = 1000 # Maximum number to factorize
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -55,13 +59,12 @@ class PrimeFactorizationDataset(ProceduralDataset):
answer = " × ".join(map(str, factors)) answer = " × ".join(map(str, factors))
return { return {
"question": (f"Find the prime factorization of {number}. Write the factors separated by × " "question": (
f"(Example: for 12 the answer would be: 2 × 2 × 3)"), f"Find the prime factorization of {number}. Write the factors separated by × "
f"(Example: for 12 the answer would be: 2 × 2 × 3)"
),
"answer": answer, "answer": answer,
"metadata": { "metadata": {"number": number, "factors": factors},
"number": number,
"factors": factors
}
} }

View file

@ -40,7 +40,7 @@ class SequenceConfig:
class PatternRule: class PatternRule:
"""Represents a composable sequence pattern rule""" """Represents a composable sequence pattern rule"""
def __init__(self, operations: List[Operation], parameters: List[int], subrules: List['PatternRule'] = None): def __init__(self, operations: List[Operation], parameters: List[int], subrules: List["PatternRule"] = None):
self.operations = operations self.operations = operations
self.parameters = parameters self.parameters = parameters
self.subrules = subrules or [] self.subrules = subrules or []
@ -66,14 +66,14 @@ class PatternRule:
elif op == Operation.COMPOSE: elif op == Operation.COMPOSE:
# Apply each subrule in sequence, passing the result through # Apply each subrule in sequence, passing the result through
for subrule in self.subrules: for subrule in self.subrules:
temp_sequence = sequence[:position + 1] temp_sequence = sequence[: position + 1]
temp_sequence[-1] = result # Use current result as input temp_sequence[-1] = result # Use current result as input
result = subrule.apply(temp_sequence, position) result = subrule.apply(temp_sequence, position)
return result return result
@classmethod @classmethod
def compose(cls, rules: List['PatternRule']) -> 'PatternRule': def compose(cls, rules: List["PatternRule"]) -> "PatternRule":
"""Create a new rule that composes multiple rules together""" """Create a new rule that composes multiple rules together"""
return cls([Operation.COMPOSE], [0], subrules=rules) return cls([Operation.COMPOSE], [0], subrules=rules)

View file

@ -4,6 +4,7 @@ from importlib import resources
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
def get_data_file_path(filename: str) -> Path: def get_data_file_path(filename: str) -> Path:
"""Get the path to a data file in the package. """Get the path to a data file in the package.
@ -18,7 +19,8 @@ def get_data_file_path(filename: str) -> Path:
>>> with open(path) as f: >>> with open(path) as f:
... content = f.read() ... content = f.read()
""" """
return resources.files('reasoning_gym.data').joinpath(filename) return resources.files("reasoning_gym.data").joinpath(filename)
def read_data_file(filename: str) -> str: def read_data_file(filename: str) -> str:
"""Read the contents of a data file in the package. """Read the contents of a data file in the package.
@ -32,6 +34,7 @@ def read_data_file(filename: str) -> str:
Example: Example:
>>> content = read_data_file("pg19362.txt") >>> content = read_data_file("pg19362.txt")
""" """
return resources.files('reasoning_gym.data').joinpath(filename).read_text() return resources.files("reasoning_gym.data").joinpath(filename).read_text()
__all__ = ['get_data_file_path', 'read_data_file']
__all__ = ["get_data_file_path", "read_data_file"]

View file

@ -1048,5 +1048,3 @@ This website includes information about Project Gutenberg™,
including how to make donations to the Project Gutenberg Literary including how to make donations to the Project Gutenberg Literary
Archive Foundation, how to help produce our new eBooks, and how to Archive Foundation, how to help produce our new eBooks, and how to
subscribe to our email newsletter to hear about new eBooks. subscribe to our email newsletter to hear about new eBooks.

View file

@ -1,8 +1,9 @@
"""Base class for procedural dataset generators""" """Base class for procedural dataset generators"""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sized, Iterable from collections.abc import Iterable, Sized
from random import Random from random import Random
from typing import Optional, Iterator, Dict, Any from typing import Any, Dict, Iterator, Optional
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):

View file

@ -14,5 +14,5 @@ __all__ = [
"mini_sudoku_dataset", "mini_sudoku_dataset",
"SudokuConfig", "SudokuConfig",
"SudokuDataset", "SudokuDataset",
"sudoku_dataset" "sudoku_dataset",
] ]

View file

@ -1,16 +1,19 @@
"""Mini Sudoku (4x4) puzzle generator""" """Mini Sudoku (4x4) puzzle generator"""
from dataclasses import dataclass
import random import random
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
@dataclass @dataclass
class MiniSudokuConfig: class MiniSudokuConfig:
"""Configuration for 4x4 sudoku puzzle generation""" """Configuration for 4x4 sudoku puzzle generation"""
min_empty: int = 8 # Minimum number of empty cells
max_empty: int = 12 # Maximum number of empty cells min_empty: int = 8 # Minimum number of empty cells
max_empty: int = 12 # Maximum number of empty cells
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -142,11 +145,7 @@ class MiniSudokuDataset:
return { return {
"question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}", "question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}",
"answer": solution_str, "answer": solution_str,
"metadata": { "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
"puzzle": puzzle,
"solution": solved_board,
"num_empty": num_empty
}
} }

View file

@ -1,16 +1,19 @@
"""Sudoku puzzle generator""" """Sudoku puzzle generator"""
from dataclasses import dataclass
import random import random
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
@dataclass @dataclass
class SudokuConfig: class SudokuConfig:
"""Configuration for sudoku puzzle generation""" """Configuration for sudoku puzzle generation"""
min_empty: int = 30 # Minimum number of empty cells
max_empty: int = 50 # Maximum number of empty cells min_empty: int = 30 # Minimum number of empty cells
max_empty: int = 50 # Maximum number of empty cells
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -132,11 +135,7 @@ class SudokuDataset:
return { return {
"question": f"Solve this Sudoku puzzle:\n{puzzle_str}", "question": f"Solve this Sudoku puzzle:\n{puzzle_str}",
"answer": solution_str, "answer": solution_str,
"metadata": { "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
"puzzle": puzzle,
"solution": solved_board,
"num_empty": num_empty
}
} }

View file

@ -1,5 +1,7 @@
import pytest
from random import Random from random import Random
import pytest
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
@ -30,14 +32,7 @@ def test_arithmetic_dataset_deterministic():
def test_arithmetic_dataset_items(): def test_arithmetic_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = BasicArithmeticDatasetConfig( config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
min_terms=2,
max_terms=4,
min_digits=1,
max_digits=2,
size=100,
seed=42
)
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -62,7 +57,7 @@ def test_arithmetic_dataset_format_styles():
min_terms=2, min_terms=2,
max_terms=3, # Keep expressions simple for testing max_terms=3, # Keep expressions simple for testing
min_digits=1, min_digits=1,
max_digits=2 max_digits=2,
) )
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
assert all(item["question"].endswith("=") for item in dataset) assert all(item["question"].endswith("=") for item in dataset)
@ -74,12 +69,7 @@ def test_arithmetic_dataset_format_styles():
def test_arithmetic_dataset_iteration(): def test_arithmetic_dataset_iteration():
"""Test that iteration respects dataset size""" """Test that iteration respects dataset size"""
config = BasicArithmeticDatasetConfig( config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
min_terms=2,
max_terms=2,
size=5, # Small size for testing
seed=42
)
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
# Test manual iteration # Test manual iteration

View file

@ -1,10 +1,8 @@
"""Tests for base conversion task generation""" """Tests for base conversion task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.base_conversion import ( from reasoning_gym.algorithmic.base_conversion import BaseConversionConfig, BaseConversionDataset
BaseConversionConfig,
BaseConversionDataset,
)
def test_base_conversion_config_validation(): def test_base_conversion_config_validation():
@ -38,14 +36,7 @@ def test_base_conversion_dataset_deterministic():
def test_base_conversion_dataset_items(): def test_base_conversion_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = BaseConversionConfig( config = BaseConversionConfig(min_base=2, max_base=16, min_value=0, max_value=1000, size=10, seed=42)
min_base=2,
max_base=16,
min_value=0,
max_value=1000,
size=10,
seed=42
)
dataset = BaseConversionDataset(config) dataset = BaseConversionDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -74,9 +65,9 @@ def test_base_conversion_dataset_items():
# Verify conversion correctness # Verify conversion correctness
decimal_value = item["metadata"]["decimal_value"] decimal_value = item["metadata"]["decimal_value"]
target_base = item["metadata"]["target_base"] target_base = item["metadata"]["target_base"]
expected = format(decimal_value, 'x' if target_base == 16 else 'b' if target_base == 2 else '').strip() expected = format(decimal_value, "x" if target_base == 16 else "b" if target_base == 2 else "").strip()
if target_base not in (2, 16): if target_base not in (2, 16):
expected = format(decimal_value, f'{target_base}x').lower().strip() expected = format(decimal_value, f"{target_base}x").lower().strip()
assert item["answer"] == expected assert item["answer"] == expected
@ -100,7 +91,7 @@ def test_base_conversion_special_bases():
min_value=0, min_value=0,
max_value=255, # Use small range for predictable results max_value=255, # Use small range for predictable results
size=100, size=100,
seed=42 seed=42,
) )
dataset = BaseConversionDataset(config) dataset = BaseConversionDataset(config)
@ -112,11 +103,11 @@ def test_base_conversion_special_bases():
if item["metadata"]["target_base"] == 2: if item["metadata"]["target_base"] == 2:
binary_found = True binary_found = True
# Verify binary format # Verify binary format
assert all(c in '01' for c in item["answer"]) assert all(c in "01" for c in item["answer"])
elif item["metadata"]["target_base"] == 16: elif item["metadata"]["target_base"] == 16:
hex_found = True hex_found = True
# Verify hex format # Verify hex format
assert all(c in '0123456789abcdef' for c in item["answer"]) assert all(c in "0123456789abcdef" for c in item["answer"])
assert binary_found, "No binary conversion tasks generated" assert binary_found, "No binary conversion tasks generated"
assert hex_found, "No hexadecimal conversion tasks generated" assert hex_found, "No hexadecimal conversion tasks generated"
@ -130,7 +121,7 @@ def test_base_conversion_formatting():
min_value=10, # Ensure multi-digit numbers min_value=10, # Ensure multi-digit numbers
max_value=1000, max_value=1000,
size=10, size=10,
seed=42 seed=42,
) )
dataset = BaseConversionDataset(config) dataset = BaseConversionDataset(config)

View file

@ -1,4 +1,5 @@
import pytest import pytest
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
@ -25,14 +26,7 @@ def test_chain_sum_deterministic():
def test_chain_sum_items(): def test_chain_sum_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = ChainSumConfig( config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
min_terms=2,
max_terms=4,
min_digits=1,
max_digits=2,
size=100,
seed=42
)
dataset = ChainSum(config) dataset = ChainSum(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -60,7 +54,7 @@ def test_chain_sum_number_ranges():
min_digits=3, # Should generate numbers >= 100 min_digits=3, # Should generate numbers >= 100
max_digits=3, # Should generate numbers <= 999 max_digits=3, # Should generate numbers <= 999
size=50, size=50,
seed=42 seed=42,
) )
dataset = ChainSum(config) dataset = ChainSum(config)
@ -74,16 +68,8 @@ def test_chain_sum_number_ranges():
else: else:
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
# Test 1-digit numbers # Test 1-digit numbers
config = ChainSumConfig( config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
min_terms=2,
max_terms=2,
min_digits=1,
max_digits=1,
size=50,
seed=42
)
dataset = ChainSum(config) dataset = ChainSum(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
@ -95,16 +81,11 @@ def test_chain_sum_number_ranges():
else: else:
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
def test_chain_sum_negation(): def test_chain_sum_negation():
"""Test that allow_negation controls number ranges""" """Test that allow_negation controls number ranges"""
config = ChainSumConfig( config = ChainSumConfig(
min_terms=2, min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True
max_terms=2,
min_digits=2,
max_digits=2,
size=100,
seed=42,
allow_negation=True
) )
dataset = ChainSum(config) dataset = ChainSum(config)
@ -115,7 +96,7 @@ def test_chain_sum_negation():
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
expression = item["metadata"]["expression"] expression = item["metadata"]["expression"]
numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith('-') and n[1:].isdigit())] numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith("-") and n[1:].isdigit())]
for num in numbers: for num in numbers:
if num > 0: if num > 0:
@ -129,12 +110,7 @@ def test_chain_sum_negation():
def test_chain_sum_iteration(): def test_chain_sum_iteration():
"""Test that iteration respects dataset size""" """Test that iteration respects dataset size"""
config = ChainSumConfig( config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
min_terms=2,
max_terms=2,
size=5, # Small size for testing
seed=42
)
dataset = ChainSum(config) dataset = ChainSum(config)
# Test manual iteration # Test manual iteration

View file

@ -1,6 +1,8 @@
import pytest
from math import gcd from math import gcd
from reasoning_gym.arithmetic import FractionSimplificationDataset, FractionSimplificationConfig
import pytest
from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset
def test_fraction_config_validation(): def test_fraction_config_validation():
@ -34,14 +36,7 @@ def test_fraction_deterministic():
def test_fraction_items(): def test_fraction_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = FractionSimplificationConfig( config = FractionSimplificationConfig(min_value=1, max_value=20, min_factor=2, max_factor=5, size=50, seed=42)
min_value=1,
max_value=20,
min_factor=2,
max_factor=5,
size=50,
seed=42
)
dataset = FractionSimplificationDataset(config) dataset = FractionSimplificationDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -79,14 +74,7 @@ def test_fraction_items():
def test_fraction_ranges(): def test_fraction_ranges():
"""Test that generated numbers respect value constraints""" """Test that generated numbers respect value constraints"""
config = FractionSimplificationConfig( config = FractionSimplificationConfig(min_value=5, max_value=15, min_factor=3, max_factor=4, size=20, seed=42)
min_value=5,
max_value=15,
min_factor=3,
max_factor=4,
size=20,
seed=42
)
dataset = FractionSimplificationDataset(config) dataset = FractionSimplificationDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -125,14 +113,7 @@ def test_fraction_iteration():
def test_fraction_numerator_smaller(): def test_fraction_numerator_smaller():
"""Test that numerators are always smaller than denominators""" """Test that numerators are always smaller than denominators"""
config = FractionSimplificationConfig( config = FractionSimplificationConfig(min_value=1, max_value=100, min_factor=2, max_factor=5, size=50, seed=42)
min_value=1,
max_value=100,
min_factor=2,
max_factor=5,
size=50,
seed=42
)
dataset = FractionSimplificationDataset(config) dataset = FractionSimplificationDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -140,9 +121,11 @@ def test_fraction_numerator_smaller():
metadata = item["metadata"] metadata = item["metadata"]
# Check original fraction # Check original fraction
assert metadata["numerator"] <= metadata["denominator"], \ assert (
f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}" metadata["numerator"] <= metadata["denominator"]
), f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}"
# Check simplified fraction # Check simplified fraction
assert metadata["simplified_numerator"] <= metadata["simplified_denominator"], \ assert (
f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}" metadata["simplified_numerator"] <= metadata["simplified_denominator"]
), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"

View file

@ -1,7 +1,9 @@
import pytest
from math import gcd
from functools import reduce from functools import reduce
from reasoning_gym.arithmetic import GCDDataset, GCDConfig from math import gcd
import pytest
from reasoning_gym.arithmetic import GCDConfig, GCDDataset
def test_gcd_config_validation(): def test_gcd_config_validation():
@ -35,14 +37,7 @@ def test_gcd_deterministic():
def test_gcd_items(): def test_gcd_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = GCDConfig( config = GCDConfig(min_numbers=2, max_numbers=4, min_value=1, max_value=100, size=50, seed=42)
min_numbers=2,
max_numbers=4,
min_value=1,
max_value=100,
size=50,
seed=42
)
dataset = GCDDataset(config) dataset = GCDDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -70,14 +65,7 @@ def test_gcd_items():
def test_gcd_number_ranges(): def test_gcd_number_ranges():
"""Test that generated numbers respect value constraints""" """Test that generated numbers respect value constraints"""
config = GCDConfig( config = GCDConfig(min_numbers=2, max_numbers=2, min_value=50, max_value=100, size=20, seed=42)
min_numbers=2,
max_numbers=2,
min_value=50,
max_value=100,
size=20,
seed=42
)
dataset = GCDDataset(config) dataset = GCDDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -109,14 +97,7 @@ def test_gcd_iteration():
def test_gcd_special_cases(): def test_gcd_special_cases():
"""Test some special GCD cases""" """Test some special GCD cases"""
config = GCDConfig( config = GCDConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=100, size=100, seed=42)
min_numbers=2,
max_numbers=2,
min_value=1,
max_value=100,
size=100,
seed=42
)
dataset = GCDDataset(config) dataset = GCDDataset(config)
# Track if we see some interesting GCD cases # Track if we see some interesting GCD cases

View file

@ -1,7 +1,9 @@
import pytest
from math import lcm
from functools import reduce from functools import reduce
from reasoning_gym.arithmetic import LCMDataset, LCMConfig from math import lcm
import pytest
from reasoning_gym.arithmetic import LCMConfig, LCMDataset
def test_lcm_config_validation(): def test_lcm_config_validation():
@ -36,12 +38,7 @@ def test_lcm_deterministic():
def test_lcm_items(): def test_lcm_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = LCMConfig( config = LCMConfig(
min_numbers=2, min_numbers=2, max_numbers=4, min_value=1, max_value=20, size=50, seed=42 # Keep small for testing
max_numbers=4,
min_value=1,
max_value=20, # Keep small for testing
size=50,
seed=42
) )
dataset = LCMDataset(config) dataset = LCMDataset(config)
@ -70,14 +67,7 @@ def test_lcm_items():
def test_lcm_number_ranges(): def test_lcm_number_ranges():
"""Test that generated numbers respect value constraints""" """Test that generated numbers respect value constraints"""
config = LCMConfig( config = LCMConfig(min_numbers=2, max_numbers=2, min_value=5, max_value=15, size=20, seed=42)
min_numbers=2,
max_numbers=2,
min_value=5,
max_value=15,
size=20,
seed=42
)
dataset = LCMDataset(config) dataset = LCMDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -109,14 +99,7 @@ def test_lcm_iteration():
def test_lcm_special_cases(): def test_lcm_special_cases():
"""Test some special LCM cases""" """Test some special LCM cases"""
config = LCMConfig( config = LCMConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=20, size=100, seed=42)
min_numbers=2,
max_numbers=2,
min_value=1,
max_value=20,
size=100,
seed=42
)
dataset = LCMDataset(config) dataset = LCMDataset(config)
# Track if we see some interesting LCM cases # Track if we see some interesting LCM cases

View file

@ -1,11 +1,8 @@
"""Tests for leg counting task generation""" """Tests for leg counting task generation"""
import pytest import pytest
from reasoning_gym.arithmetic.leg_counting import ( from reasoning_gym.arithmetic.leg_counting import ANIMALS, LegCountingConfig, LegCountingDataset
LegCountingConfig,
LegCountingDataset,
ANIMALS,
)
def test_leg_counting_config_validation(): def test_leg_counting_config_validation():
@ -35,13 +32,7 @@ def test_leg_counting_dataset_deterministic():
def test_leg_counting_dataset_items(): def test_leg_counting_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = LegCountingConfig( config = LegCountingConfig(min_animals=2, max_animals=4, max_instances=2, size=10, seed=42)
min_animals=2,
max_animals=4,
max_instances=2,
size=10,
seed=42
)
dataset = LegCountingDataset(config) dataset = LegCountingDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):

View file

@ -1,10 +1,8 @@
"""Tests for letter counting task generation""" """Tests for letter counting task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.letter_counting import ( from reasoning_gym.algorithmic.letter_counting import LetterCountingConfig, LetterCountingDataset
LetterCountingConfig,
LetterCountingDataset,
)
def test_letter_counting_config_validation(): def test_letter_counting_config_validation():
@ -30,12 +28,7 @@ def test_letter_counting_dataset_deterministic():
def test_letter_counting_dataset_items(): def test_letter_counting_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = LetterCountingConfig( config = LetterCountingConfig(min_words=3, max_words=6, size=10, seed=42)
min_words=3,
max_words=6,
size=10,
seed=42
)
dataset = LetterCountingDataset(config) dataset = LetterCountingDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):

View file

@ -1,10 +1,8 @@
"""Tests for mini sudoku puzzle generation""" """Tests for mini sudoku puzzle generation"""
import pytest import pytest
from reasoning_gym.games.mini_sudoku import ( from reasoning_gym.games.mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
MiniSudokuConfig,
MiniSudokuDataset,
)
def test_mini_sudoku_config_validation(): def test_mini_sudoku_config_validation():
@ -34,12 +32,7 @@ def test_mini_sudoku_dataset_deterministic():
def test_mini_sudoku_dataset_items(): def test_mini_sudoku_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = MiniSudokuConfig( config = MiniSudokuConfig(min_empty=8, max_empty=12, size=10, seed=42)
min_empty=8,
max_empty=12,
size=10,
seed=42
)
dataset = MiniSudokuDataset(config) dataset = MiniSudokuDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -94,12 +87,7 @@ def test_mini_sudoku_dataset_iteration():
def test_mini_sudoku_board_generation(): def test_mini_sudoku_board_generation():
"""Test that generated boards are valid""" """Test that generated boards are valid"""
config = MiniSudokuConfig( config = MiniSudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board
min_empty=0, # Force complete board
max_empty=0,
size=5,
seed=42
)
dataset = MiniSudokuDataset(config) dataset = MiniSudokuDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -127,7 +115,7 @@ def is_valid_solution(board: list[list[int]]) -> bool:
box = [] box = []
for i in range(2): for i in range(2):
for j in range(2): for j in range(2):
box.append(board[box_i*2 + i][box_j*2 + j]) box.append(board[box_i * 2 + i][box_j * 2 + j])
if set(box) != set(range(1, 5)): if set(box) != set(range(1, 5)):
return False return False

View file

@ -1,10 +1,8 @@
"""Tests for number filtering task generation""" """Tests for number filtering task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.number_filtering import ( from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset
NumberFilteringConfig,
NumberFilteringDataset,
)
def test_number_filtering_config_validation(): def test_number_filtering_config_validation():
@ -39,14 +37,7 @@ def test_number_filtering_dataset_deterministic():
def test_number_filtering_dataset_items(): def test_number_filtering_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = NumberFilteringConfig( config = NumberFilteringConfig(
min_numbers=3, min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
max_numbers=6,
min_decimals=1,
max_decimals=3,
min_value=-10.0,
max_value=10.0,
size=10,
seed=42
) )
dataset = NumberFilteringDataset(config) dataset = NumberFilteringDataset(config)
@ -71,7 +62,7 @@ def test_number_filtering_dataset_items():
# Verify decimal places # Verify decimal places
for num in numbers: for num in numbers:
decimal_places = len(num.split('.')[-1]) if '.' in num else 0 decimal_places = len(num.split(".")[-1]) if "." in num else 0
assert decimal_places >= config.min_decimals assert decimal_places >= config.min_decimals
assert decimal_places <= config.max_decimals assert decimal_places <= config.max_decimals
@ -117,11 +108,11 @@ def test_number_filtering_precision():
min_value=0.0, min_value=0.0,
max_value=1.0, max_value=1.0,
size=1, size=1,
seed=42 seed=42,
) )
dataset = NumberFilteringDataset(config) dataset = NumberFilteringDataset(config)
item = dataset[0] item = dataset[0]
# Check that string representations maintain precision # Check that string representations maintain precision
for num in item["metadata"]["original_numbers"]: for num in item["metadata"]["original_numbers"]:
assert len(num.split('.')[-1]) == 2 assert len(num.split(".")[-1]) == 2

View file

@ -1,10 +1,8 @@
"""Tests for number sorting task generation""" """Tests for number sorting task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.number_sorting import ( from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingDataset
NumberSortingConfig,
NumberSortingDataset,
)
def test_number_sorting_config_validation(): def test_number_sorting_config_validation():
@ -39,14 +37,7 @@ def test_number_sorting_dataset_deterministic():
def test_number_sorting_dataset_items(): def test_number_sorting_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = NumberSortingConfig( config = NumberSortingConfig(
min_numbers=3, min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
max_numbers=6,
min_decimals=1,
max_decimals=3,
min_value=-10.0,
max_value=10.0,
size=10,
seed=42
) )
dataset = NumberSortingDataset(config) dataset = NumberSortingDataset(config)
@ -70,7 +61,7 @@ def test_number_sorting_dataset_items():
# Verify decimal places # Verify decimal places
for num in numbers: for num in numbers:
decimal_places = len(num.split('.')[-1]) if '.' in num else 0 decimal_places = len(num.split(".")[-1]) if "." in num else 0
assert decimal_places >= config.min_decimals assert decimal_places >= config.min_decimals
assert decimal_places <= config.max_decimals assert decimal_places <= config.max_decimals

View file

@ -1,10 +1,8 @@
"""Tests for prime factorization task generation""" """Tests for prime factorization task generation"""
import pytest import pytest
from reasoning_gym.arithmetic.prime_factorization import ( from reasoning_gym.arithmetic.prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
PrimeFactorizationConfig,
PrimeFactorizationDataset,
)
def test_prime_factorization_config_validation(): def test_prime_factorization_config_validation():
@ -30,12 +28,7 @@ def test_prime_factorization_dataset_deterministic():
def test_prime_factorization_dataset_items(): def test_prime_factorization_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = PrimeFactorizationConfig( config = PrimeFactorizationConfig(min_value=2, max_value=100, size=10, seed=42)
min_value=2,
max_value=100,
size=10,
seed=42
)
dataset = PrimeFactorizationDataset(config) dataset = PrimeFactorizationDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -83,12 +76,7 @@ def test_prime_factorization_dataset_iteration():
def test_prime_factorization_known_values(): def test_prime_factorization_known_values():
"""Test factorization of known values""" """Test factorization of known values"""
config = PrimeFactorizationConfig( config = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42) # Force specific number
min_value=12,
max_value=12, # Force specific number
size=1,
seed=42
)
dataset = PrimeFactorizationDataset(config) dataset = PrimeFactorizationDataset(config)
item = dataset[0] item = dataset[0]
@ -101,7 +89,7 @@ def is_prime(n: int) -> bool:
"""Helper function to check if a number is prime""" """Helper function to check if a number is prime"""
if n < 2: if n < 2:
return False return False
for i in range(2, int(n ** 0.5) + 1): for i in range(2, int(n**0.5) + 1):
if n % i == 0: if n % i == 0:
return False return False
return True return True

View file

@ -30,7 +30,7 @@ def test_pattern_rule():
# Test rule composition # Test rule composition
rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number
rule2 = PatternRule([Operation.ADD], [3]) # Add 3 rule2 = PatternRule([Operation.ADD], [3]) # Add 3
composed = PatternRule.compose([rule1, rule2]) composed = PatternRule.compose([rule1, rule2])
assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3 assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3

View file

@ -1,10 +1,8 @@
"""Tests for sudoku puzzle generation""" """Tests for sudoku puzzle generation"""
import pytest import pytest
from reasoning_gym.games.sudoku import ( from reasoning_gym.games.sudoku import SudokuConfig, SudokuDataset
SudokuConfig,
SudokuDataset,
)
def test_sudoku_config_validation(): def test_sudoku_config_validation():
@ -34,12 +32,7 @@ def test_sudoku_dataset_deterministic():
def test_sudoku_dataset_items(): def test_sudoku_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = SudokuConfig( config = SudokuConfig(min_empty=30, max_empty=40, size=10, seed=42)
min_empty=30,
max_empty=40,
size=10,
seed=42
)
dataset = SudokuDataset(config) dataset = SudokuDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -94,12 +87,7 @@ def test_sudoku_dataset_iteration():
def test_sudoku_board_generation(): def test_sudoku_board_generation():
"""Test that generated boards are valid""" """Test that generated boards are valid"""
config = SudokuConfig( config = SudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board
min_empty=0, # Force complete board
max_empty=0,
size=5,
seed=42
)
dataset = SudokuDataset(config) dataset = SudokuDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -127,7 +115,7 @@ def is_valid_solution(board: list[list[int]]) -> bool:
box = [] box = []
for i in range(3): for i in range(3):
for j in range(3): for j in range(3):
box.append(board[box_i*3 + i][box_j*3 + j]) box.append(board[box_i * 3 + i][box_j * 3 + j])
if set(box) != set(range(1, 10)): if set(box) != set(range(1, 10)):
return False return False

View file

@ -1,10 +1,8 @@
"""Tests for word reversal task generation""" """Tests for word reversal task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.word_reversal import ( from reasoning_gym.algorithmic.word_reversal import WordReversalConfig, WordReversalDataset
WordReversalConfig,
WordReversalDataset,
)
def test_word_reversal_config_validation(): def test_word_reversal_config_validation():
@ -30,12 +28,7 @@ def test_word_reversal_dataset_deterministic():
def test_word_reversal_dataset_items(): def test_word_reversal_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = WordReversalConfig( config = WordReversalConfig(min_words=3, max_words=6, size=10, seed=42)
min_words=3,
max_words=6,
size=10,
seed=42
)
dataset = WordReversalDataset(config) dataset = WordReversalDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):