formatting

This commit is contained in:
Andreas Koepf 2025-01-24 10:34:07 +01:00
parent 98988c8481
commit 20069b2a7d
37 changed files with 504 additions and 666 deletions

View file

@ -2,12 +2,7 @@
Reasoning Gym - A library of procedural dataset generators for training reasoning models Reasoning Gym - A library of procedural dataset generators for training reasoning models
""" """
from . import arithmetic from . import algorithmic, arithmetic, cognition, data, games, logic
from . import algorithmic
from . import cognition
from . import data
from . import games
from . import logic
__version__ = "0.1.0" __version__ = "0.1.0"
__all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"] __all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"]

View file

@ -8,6 +8,7 @@ Algorithmic tasks for training reasoning capabilities:
from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset
from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset
from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
@ -20,8 +21,8 @@ __all__ = [
"BaseConversionDataset", "BaseConversionDataset",
"base_conversion_dataset", "base_conversion_dataset",
"chain_sum_dataset", "chain_sum_dataset",
"LetterCountingConfig", "LetterCountingConfig",
"LetterCountingDataset", "LetterCountingDataset",
"letter_counting_dataset", "letter_counting_dataset",
"NumberFilteringConfig", "NumberFilteringConfig",
"NumberFilteringDataset", "NumberFilteringDataset",
@ -31,5 +32,5 @@ __all__ = [
"number_sorting_dataset", "number_sorting_dataset",
"WordReversalConfig", "WordReversalConfig",
"WordReversalDataset", "WordReversalDataset",
"word_reversal_dataset" "word_reversal_dataset",
] ]

View file

@ -1,17 +1,20 @@
"""Base conversion task generator""" """Base conversion task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional, Tuple from typing import Optional, Tuple
@dataclass @dataclass
class BaseConversionConfig: class BaseConversionConfig:
"""Configuration for base conversion task generation""" """Configuration for base conversion task generation"""
min_base: int = 2 # Minimum base (2=binary)
max_base: int = 16 # Maximum base (16=hex) min_base: int = 2 # Minimum base (2=binary)
min_value: int = 0 # Minimum decimal value to convert max_base: int = 16 # Maximum base (16=hex)
max_value: int = 1000 # Maximum decimal value to convert min_value: int = 0 # Minimum decimal value to convert
max_value: int = 1000 # Maximum decimal value to convert
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -55,37 +58,37 @@ class BaseConversionDataset:
def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]: def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]:
"""Generate random value and source/target bases""" """Generate random value and source/target bases"""
value = rng.randint(self.config.min_value, self.config.max_value) value = rng.randint(self.config.min_value, self.config.max_value)
# Choose source and target bases # Choose source and target bases
source_base = rng.randint(self.config.min_base, self.config.max_base) source_base = rng.randint(self.config.min_base, self.config.max_base)
target_base = rng.randint(self.config.min_base, self.config.max_base) target_base = rng.randint(self.config.min_base, self.config.max_base)
while target_base == source_base: # Ensure different bases while target_base == source_base: # Ensure different bases
target_base = rng.randint(self.config.min_base, self.config.max_base) target_base = rng.randint(self.config.min_base, self.config.max_base)
return value, source_base, target_base return value, source_base, target_base
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single base conversion task""" """Generate a single base conversion task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
value, source_base, target_base = self._generate_conversion(rng) value, source_base, target_base = self._generate_conversion(rng)
# Convert decimal to source base representation # Convert decimal to source base representation
source_repr = format(value, f'x' if source_base == 16 else f'b' if source_base == 2 else '').strip() source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip()
if source_base not in (2, 16): if source_base not in (2, 16):
source_repr = format(value, f'{source_base}x').lower().strip() source_repr = format(value, f"{source_base}x").lower().strip()
# Convert decimal to target base for answer # Convert decimal to target base for answer
target_repr = format(value, f'x' if target_base == 16 else f'b' if target_base == 2 else '').strip() target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip()
if target_base not in (2, 16): if target_base not in (2, 16):
target_repr = format(value, f'{target_base}x').lower().strip() target_repr = format(value, f"{target_base}x").lower().strip()
source_name = self._format_base_name(source_base) source_name = self._format_base_name(source_base)
target_name = self._format_base_name(target_base) target_name = self._format_base_name(target_base)
# Add hint for bases > 10 about using lowercase letters # Add hint for bases > 10 about using lowercase letters
hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else "" hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else ""
return { return {
"question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}", "question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}",
"answer": target_repr, "answer": target_repr,
@ -94,8 +97,8 @@ class BaseConversionDataset:
"source_base": source_base, "source_base": source_base,
"target_base": target_base, "target_base": target_base,
"source_repr": source_repr, "source_repr": source_repr,
"target_repr": target_repr "target_repr": target_repr,
} },
} }

View file

@ -1,18 +1,21 @@
"""Letter counting task generator""" """Letter counting task generator"""
from dataclasses import dataclass
import re import re
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional from typing import List, Optional
from reasoning_gym.data import read_data_file from reasoning_gym.data import read_data_file
@dataclass @dataclass
class LetterCountingConfig: class LetterCountingConfig:
"""Configuration for letter counting task generation""" """Configuration for letter counting task generation"""
min_words: int = 5 # Minimum words in span
max_words: int = 15 # Maximum words in span min_words: int = 5 # Minimum words in span
max_words: int = 15 # Maximum words in span
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -27,11 +30,11 @@ class LetterCountingDataset:
self.config = config self.config = config
self.config.validate() self.config.validate()
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
# Load and preprocess text # Load and preprocess text
text = read_data_file("in_the_year_2889.txt") text = read_data_file("in_the_year_2889.txt")
# Extract words and clean them to contain only alphanumeric characters # Extract words and clean them to contain only alphanumeric characters
self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()] self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]
def __len__(self) -> int: def __len__(self) -> int:
return self.config.size return self.config.size
@ -50,31 +53,27 @@ class LetterCountingDataset:
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single letter counting task""" """Generate a single letter counting task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
# Select random span of words # Select random span of words
span_length = rng.randint(self.config.min_words, self.config.max_words) span_length = rng.randint(self.config.min_words, self.config.max_words)
start_idx = rng.randint(0, len(self.words) - span_length) start_idx = rng.randint(0, len(self.words) - span_length)
span = self.words[start_idx:start_idx + span_length] span = self.words[start_idx : start_idx + span_length]
# Get all unique letters from span # Get all unique letters from span
letters = set(''.join(span).lower()) letters = set("".join(span).lower())
if not letters: if not letters:
letters = {'a'} # Fallback if span has no letters letters = {"a"} # Fallback if span has no letters
# Select random letter that appears in the span # Select random letter that appears in the span
target_letter = rng.choice(list(letters)) target_letter = rng.choice(list(letters))
# Count occurrences # Count occurrences
count = sum(word.lower().count(target_letter) for word in span) count = sum(word.lower().count(target_letter) for word in span)
return { return {
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?', "question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
"answer": str(count), "answer": str(count),
"metadata": { "metadata": {"span_length": span_length, "target_letter": target_letter, "span": span},
"span_length": span_length,
"target_letter": target_letter,
"span": span
}
} }

View file

@ -1,20 +1,23 @@
"""Number filtering task generator""" """Number filtering task generator"""
from dataclasses import dataclass
import random import random
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@dataclass @dataclass
class NumberFilteringConfig: class NumberFilteringConfig:
"""Configuration for number filtering task generation""" """Configuration for number filtering task generation"""
min_numbers: int = 3 # Minimum numbers in list
max_numbers: int = 10 # Maximum numbers in list min_numbers: int = 3 # Minimum numbers in list
min_decimals: int = 0 # Minimum decimal places max_numbers: int = 10 # Maximum numbers in list
max_decimals: int = 4 # Maximum decimal places min_decimals: int = 0 # Minimum decimal places
min_value: float = -100.0 # Minimum number value max_decimals: int = 4 # Maximum decimal places
max_value: float = 100.0 # Maximum number value min_value: float = -100.0 # Minimum number value
max_value: float = 100.0 # Maximum number value
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -56,23 +59,23 @@ class NumberFilteringDataset:
count = rng.randint(self.config.min_numbers, self.config.max_numbers) count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [] numbers = []
str_numbers = [] str_numbers = []
for _ in range(count): for _ in range(count):
num = rng.uniform(self.config.min_value, self.config.max_value) num = rng.uniform(self.config.min_value, self.config.max_value)
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
str_num = self._format_number(num, decimals) str_num = self._format_number(num, decimals)
numbers.append(float(str_num)) # Convert back to simulate precision loss numbers.append(float(str_num)) # Convert back to simulate precision loss
str_numbers.append(str_num) str_numbers.append(str_num)
return numbers, str_numbers return numbers, str_numbers
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single number filtering task""" """Generate a single number filtering task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
# Generate numbers and their string representations # Generate numbers and their string representations
numbers, str_numbers = self._generate_numbers(rng) numbers, str_numbers = self._generate_numbers(rng)
# Determine filter value between min and max of generated numbers # Determine filter value between min and max of generated numbers
min_val = min(numbers) min_val = min(numbers)
max_val = max(numbers) max_val = max(numbers)
@ -80,31 +83,33 @@ class NumberFilteringDataset:
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
filter_str = self._format_number(filter_value, decimals) filter_str = self._format_number(filter_value, decimals)
filter_value = float(filter_str) # Convert back to simulate precision loss filter_value = float(filter_str) # Convert back to simulate precision loss
# Randomly choose filter operation # Randomly choose filter operation
keep_larger = rng.choice([True, False]) keep_larger = rng.choice([True, False])
larger_smaller = "larger" if keep_larger else "smaller" larger_smaller = "larger" if keep_larger else "smaller"
keep_remove = "keep" if rng.choice([True, False]) else "remove" keep_remove = "keep" if rng.choice([True, False]) else "remove"
# Apply filter based on chosen operation # Apply filter based on chosen operation
if keep_remove == "keep": if keep_remove == "keep":
result = [n for n in numbers if (n > filter_value if keep_larger else n < filter_value)] result = [n for n in numbers if (n > filter_value if keep_larger else n < filter_value)]
else: # remove else: # remove
result = [n for n in numbers if (n <= filter_value if keep_larger else n >= filter_value)] result = [n for n in numbers if (n <= filter_value if keep_larger else n >= filter_value)]
# Format results as strings with original precision # Format results as strings with original precision
result_strs = [str_numbers[numbers.index(n)] for n in result] result_strs = [str_numbers[numbers.index(n)] for n in result]
return { return {
"question": (f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} " "question": (
f"in this list: {str_numbers}"), f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} "
f"in this list: {str_numbers}"
),
"answer": str(result_strs) if result_strs else "[]", "answer": str(result_strs) if result_strs else "[]",
"metadata": { "metadata": {
"original_numbers": str_numbers, "original_numbers": str_numbers,
"filter_value": filter_str, "filter_value": filter_str,
"operation": f"{keep_remove}_{larger_smaller}", "operation": f"{keep_remove}_{larger_smaller}",
"result": result_strs "result": result_strs,
} },
} }

View file

@ -1,20 +1,23 @@
"""Number sorting task generator""" """Number sorting task generator"""
from dataclasses import dataclass
import random import random
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@dataclass @dataclass
class NumberSortingConfig: class NumberSortingConfig:
"""Configuration for number sorting task generation""" """Configuration for number sorting task generation"""
min_numbers: int = 3 # Minimum numbers to sort
max_numbers: int = 10 # Maximum numbers to sort min_numbers: int = 3 # Minimum numbers to sort
min_decimals: int = 0 # Minimum decimal places max_numbers: int = 10 # Maximum numbers to sort
max_decimals: int = 2 # Maximum decimal places min_decimals: int = 0 # Minimum decimal places
max_decimals: int = 2 # Maximum decimal places
min_value: float = -100.0 # Minimum value min_value: float = -100.0 # Minimum value
max_value: float = 100.0 # Maximum value max_value: float = 100.0 # Maximum value
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -57,10 +60,10 @@ class NumberSortingDataset:
"""Generate list of numbers and their string representations""" """Generate list of numbers and their string representations"""
count = rng.randint(self.config.min_numbers, self.config.max_numbers) count = rng.randint(self.config.min_numbers, self.config.max_numbers)
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
numbers = [] numbers = []
number_strs = [] number_strs = []
for _ in range(count): for _ in range(count):
num = rng.uniform(self.config.min_value, self.config.max_value) num = rng.uniform(self.config.min_value, self.config.max_value)
num_str = self._format_number(num, decimals) num_str = self._format_number(num, decimals)
@ -68,37 +71,33 @@ class NumberSortingDataset:
num = float(num_str) num = float(num_str)
numbers.append(num) numbers.append(num)
number_strs.append(num_str) number_strs.append(num_str)
return numbers, number_strs return numbers, number_strs
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single sorting task""" """Generate a single sorting task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
numbers, number_strs = self._generate_numbers(rng) numbers, number_strs = self._generate_numbers(rng)
# Generate both ascending and descending answers # Generate both ascending and descending answers
asc_numbers = sorted(numbers) asc_numbers = sorted(numbers)
desc_numbers = sorted(numbers, reverse=True) desc_numbers = sorted(numbers, reverse=True)
# Format answers as string lists # Format answers as string lists
decimals = len(number_strs[0].split('.')[-1]) if '.' in number_strs[0] else 0 decimals = len(number_strs[0].split(".")[-1]) if "." in number_strs[0] else 0
asc_answer = [self._format_number(n, decimals) for n in asc_numbers] asc_answer = [self._format_number(n, decimals) for n in asc_numbers]
desc_answer = [self._format_number(n, decimals) for n in desc_numbers] desc_answer = [self._format_number(n, decimals) for n in desc_numbers]
# Randomly choose ascending or descending # Randomly choose ascending or descending
is_ascending = rng.choice([True, False]) is_ascending = rng.choice([True, False])
direction = "ascending" if is_ascending else "descending" direction = "ascending" if is_ascending else "descending"
answer = asc_answer if is_ascending else desc_answer answer = asc_answer if is_ascending else desc_answer
return { return {
"question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}", "question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
"answer": str(answer), "answer": str(answer),
"metadata": { "metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
"original_numbers": number_strs,
"direction": direction,
"sorted_numbers": answer
}
} }

View file

@ -1,18 +1,21 @@
"""Word reversal task generator""" """Word reversal task generator"""
from dataclasses import dataclass
import re import re
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional from typing import List, Optional
from reasoning_gym.data import read_data_file from reasoning_gym.data import read_data_file
@dataclass @dataclass
class WordReversalConfig: class WordReversalConfig:
"""Configuration for word reversal task generation""" """Configuration for word reversal task generation"""
min_words: int = 3 # Minimum words in list
max_words: int = 8 # Maximum words in list min_words: int = 3 # Minimum words in list
max_words: int = 8 # Maximum words in list
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -27,11 +30,11 @@ class WordReversalDataset:
self.config = config self.config = config
self.config.validate() self.config.validate()
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
# Load and preprocess text # Load and preprocess text
text = read_data_file("in_the_year_2889.txt") text = read_data_file("in_the_year_2889.txt")
# Extract words and clean them to contain only alphanumeric characters # Extract words and clean them to contain only alphanumeric characters
self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()] self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()]
def __len__(self) -> int: def __len__(self) -> int:
return self.config.size return self.config.size
@ -50,23 +53,20 @@ class WordReversalDataset:
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single word reversal task""" """Generate a single word reversal task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
# Select random words # Select random words
num_words = rng.randint(self.config.min_words, self.config.max_words) num_words = rng.randint(self.config.min_words, self.config.max_words)
word_indices = rng.sample(range(len(self.words)), num_words) word_indices = rng.sample(range(len(self.words)), num_words)
words = [self.words[i] for i in word_indices] words = [self.words[i] for i in word_indices]
# Create question and answer # Create question and answer
question = ", ".join(words) question = ", ".join(words)
answer = ", ".join(reversed(words)) answer = ", ".join(reversed(words))
return { return {
"question": f"Reverse this list of words: {question}", "question": f"Reverse this list of words: {question}",
"answer": answer, "answer": answer,
"metadata": { "metadata": {"num_words": num_words, "words": words},
"num_words": num_words,
"words": words
}
} }

View file

@ -8,7 +8,11 @@ Arithmetic tasks for training reasoning capabilities:
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset from .fraction_simplification import (
FractionSimplificationConfig,
FractionSimplificationDataset,
fraction_simplification_dataset,
)
from .gcd import GCDConfig, GCDDataset, gcd_dataset from .gcd import GCDConfig, GCDDataset, gcd_dataset
from .lcm import LCMConfig, LCMDataset, lcm_dataset from .lcm import LCMConfig, LCMDataset, lcm_dataset
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
@ -25,7 +29,7 @@ __all__ = [
"FractionSimplificationDataset", "FractionSimplificationDataset",
"fraction_simplification_dataset", "fraction_simplification_dataset",
"GCDConfig", "GCDConfig",
"GCDDataset", "GCDDataset",
"gcd_dataset", "gcd_dataset",
"LCMConfig", "LCMConfig",
"LCMDataset", "LCMDataset",
@ -35,5 +39,5 @@ __all__ = [
"leg_counting_dataset", "leg_counting_dataset",
"PrimeFactorizationConfig", "PrimeFactorizationConfig",
"PrimeFactorizationDataset", "PrimeFactorizationDataset",
"prime_factorization_dataset" "prime_factorization_dataset",
] ]

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
@ -145,7 +146,6 @@ class BasicArithmeticDataset(ProceduralDataset):
expression = " ".join(expression_parts) expression = " ".join(expression_parts)
return expression, result return expression, result
def _format_question(self, rng: Random, expression: str) -> str: def _format_question(self, rng: Random, expression: str) -> str:
"""Format the expression according to config style""" """Format the expression according to config style"""
if self.config.format_style == "simple": if self.config.format_style == "simple":

View file

@ -1,6 +1,7 @@
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
@ -70,7 +71,6 @@ class ChainSum(ProceduralDataset):
}, },
} }
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]: def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
"""Generate a chain sum task """Generate a chain sum task

View file

@ -1,21 +1,24 @@
"""Fraction simplification task generator""" """Fraction simplification task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from random import Random
from typing import Optional, Tuple, Sequence
from ..dataset import ProceduralDataset
from math import gcd from math import gcd
from random import Random
from typing import Optional, Sequence, Tuple
from ..dataset import ProceduralDataset
@dataclass @dataclass
class FractionSimplificationConfig: class FractionSimplificationConfig:
"""Configuration for fraction simplification task generation""" """Configuration for fraction simplification task generation"""
min_value: int = 1 # Minimum value for numerator/denominator
max_value: int = 1000 # Maximum value for numerator/denominator min_value: int = 1 # Minimum value for numerator/denominator
min_factor: int = 1 # Minimum multiplication factor max_value: int = 1000 # Maximum value for numerator/denominator
max_factor: int = 100 # Maximum multiplication factor min_factor: int = 1 # Minimum multiplication factor
max_factor: int = 100 # Maximum multiplication factor
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -23,7 +26,7 @@ class FractionSimplificationConfig:
assert self.max_value > self.min_value, "max_value must be > min_value" assert self.max_value > self.min_value, "max_value must be > min_value"
assert self.min_factor >= 1, "min_factor must be at least 1" assert self.min_factor >= 1, "min_factor must be at least 1"
assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor" assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor"
# Validate styles # Validate styles
valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"} valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"}
for style in self.styles: for style in self.styles:
@ -46,37 +49,38 @@ class FractionSimplificationDataset(ProceduralDataset):
# Generate the simplified fraction first # Generate the simplified fraction first
simplified_num = rng.randint(self.config.min_value, self.config.max_value) simplified_num = rng.randint(self.config.min_value, self.config.max_value)
simplified_den = rng.randint(self.config.min_value, self.config.max_value) simplified_den = rng.randint(self.config.min_value, self.config.max_value)
# Make sure they're coprime by dividing by their GCD # Make sure they're coprime by dividing by their GCD
common = gcd(simplified_num, simplified_den) common = gcd(simplified_num, simplified_den)
simplified_num //= common simplified_num //= common
simplified_den //= common simplified_den //= common
# Check if simplified fraction is within bounds # Check if simplified fraction is within bounds
if (self.config.min_value <= simplified_num <= self.config.max_value and if (
self.config.min_value <= simplified_den <= self.config.max_value): self.config.min_value <= simplified_num <= self.config.max_value
and self.config.min_value <= simplified_den <= self.config.max_value
):
# Ensure numerator is smaller than denominator # Ensure numerator is smaller than denominator
if simplified_num > simplified_den: if simplified_num > simplified_den:
simplified_num, simplified_den = simplified_den, simplified_num simplified_num, simplified_den = simplified_den, simplified_num
# Multiply both by a random factor to create the unsimplified version # Multiply both by a random factor to create the unsimplified version
factor = rng.randint(self.config.min_factor, self.config.max_factor) factor = rng.randint(self.config.min_factor, self.config.max_factor)
numerator = simplified_num * factor numerator = simplified_num * factor
denominator = simplified_den * factor denominator = simplified_den * factor
return numerator, denominator, simplified_num, simplified_den return numerator, denominator, simplified_num, simplified_den
# If we failed to find a good fraction after max attempts, # If we failed to find a good fraction after max attempts,
# generate one that's guaranteed to be within bounds # generate one that's guaranteed to be within bounds
simplified_num = rng.randint(self.config.min_value, self.config.max_value) simplified_num = rng.randint(self.config.min_value, self.config.max_value)
simplified_den = rng.randint(self.config.min_value, self.config.max_value) simplified_den = rng.randint(self.config.min_value, self.config.max_value)
# Ensure numerator is smaller than denominator # Ensure numerator is smaller than denominator
if simplified_num > simplified_den: if simplified_num > simplified_den:
simplified_num, simplified_den = simplified_den, simplified_num simplified_num, simplified_den = simplified_den, simplified_num
factor = rng.randint(self.config.min_factor, self.config.max_factor) factor = rng.randint(self.config.min_factor, self.config.max_factor)
return (simplified_num * factor, simplified_den * factor, return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
simplified_num, simplified_den)
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str: def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
"""Format a fraction in various styles""" """Format a fraction in various styles"""
@ -95,16 +99,16 @@ class FractionSimplificationDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single fraction simplification task""" """Generate a single fraction simplification task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
num, den, simple_num, simple_den = self._generate_fraction(rng) num, den, simple_num, simple_den = self._generate_fraction(rng)
# Choose a random style from configured styles # Choose a random style from configured styles
style = self.config.styles[rng.randint(0, len(self.config.styles)-1)] style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
# Format both question and answer in the same style # Format both question and answer in the same style
question_fraction = self._format_fraction(num, den, style) question_fraction = self._format_fraction(num, den, style)
answer_fraction = self._format_fraction(simple_num, simple_den, style) answer_fraction = self._format_fraction(simple_num, simple_den, style)
return { return {
"question": f"Simplify the fraction {question_fraction} to its lowest terms", "question": f"Simplify the fraction {question_fraction} to its lowest terms",
"answer": answer_fraction, "answer": answer_fraction,
@ -114,8 +118,8 @@ class FractionSimplificationDataset(ProceduralDataset):
"simplified_numerator": simple_num, "simplified_numerator": simple_num,
"simplified_denominator": simple_den, "simplified_denominator": simple_den,
"reduction_factor": num // simple_num, # Will be same as den // simple_den "reduction_factor": num // simple_num, # Will be same as den // simple_den
"style": style "style": style,
} },
} }

View file

@ -1,21 +1,24 @@
"""Greatest Common Divisor (GCD) task generator""" """Greatest Common Divisor (GCD) task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
from math import gcd
from functools import reduce
@dataclass @dataclass
class GCDConfig: class GCDConfig:
"""Configuration for GCD task generation""" """Configuration for GCD task generation"""
min_numbers: int = 2 # Minimum numbers to find GCD of
max_numbers: int = 2 # Maximum numbers to find GCD of min_numbers: int = 2 # Minimum numbers to find GCD of
min_value: int = 1 # Minimum value for each number max_numbers: int = 2 # Maximum numbers to find GCD of
max_value: int = 1000 # Maximum value for each number min_value: int = 1 # Minimum value for each number
max_value: int = 1000 # Maximum value for each number
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -38,33 +41,28 @@ class GCDDataset(ProceduralDataset):
Will try up to 3 times to find numbers with GCD > 1.""" Will try up to 3 times to find numbers with GCD > 1."""
for _ in range(3): # Try up to 3 times to get GCD > 1 for _ in range(3): # Try up to 3 times to get GCD > 1
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
for _ in range(num_count)]
result = reduce(gcd, numbers) result = reduce(gcd, numbers)
if result > 1: if result > 1:
return numbers, result return numbers, result
# If we failed to find GCD > 1 after 3 tries, generate one final set # If we failed to find GCD > 1 after 3 tries, generate one final set
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
for _ in range(num_count)]
result = reduce(gcd, numbers) result = reduce(gcd, numbers)
return numbers, result return numbers, result
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single GCD task""" """Generate a single GCD task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
numbers, result = self._generate_numbers(rng) numbers, result = self._generate_numbers(rng)
numbers_str = ", ".join(str(n) for n in numbers) numbers_str = ", ".join(str(n) for n in numbers)
return { return {
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}", "question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {"numbers": numbers, "result": result},
"numbers": numbers,
"result": result
}
} }

View file

@ -1,21 +1,24 @@
"""Least Common Multiple (LCM) task generator""" """Least Common Multiple (LCM) task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce
from math import lcm
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
from math import lcm
from functools import reduce
@dataclass @dataclass
class LCMConfig: class LCMConfig:
"""Configuration for LCM task generation""" """Configuration for LCM task generation"""
min_numbers: int = 2 # Minimum numbers to find LCM of
max_numbers: int = 2 # Maximum numbers to find LCM of min_numbers: int = 2 # Minimum numbers to find LCM of
min_value: int = 1 # Minimum value for each number max_numbers: int = 2 # Maximum numbers to find LCM of
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast) min_value: int = 1 # Minimum value for each number
max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast)
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -36,38 +39,34 @@ class LCMDataset(ProceduralDataset):
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
"""Generate a list of random positive integers and their LCM. """Generate a list of random positive integers and their LCM.
Will try up to 3 times to find numbers with LCM < product.""" Will try up to 3 times to find numbers with LCM < product."""
def calculate_product(nums: List[int]) -> int: def calculate_product(nums: List[int]) -> int:
return reduce(lambda x, y: x * y, nums) return reduce(lambda x, y: x * y, nums)
for _ in range(3): # Try up to 3 times to get LCM < product for _ in range(3): # Try up to 3 times to get LCM < product
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
for _ in range(num_count)]
result = reduce(lcm, numbers) result = reduce(lcm, numbers)
if result < calculate_product(numbers): if result < calculate_product(numbers):
return numbers, result return numbers, result
# If we failed to find LCM < product after 3 tries, generate one final set # If we failed to find LCM < product after 3 tries, generate one final set
num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) num_count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = [rng.randint(self.config.min_value, self.config.max_value) numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)]
for _ in range(num_count)]
result = reduce(lcm, numbers) result = reduce(lcm, numbers)
return numbers, result return numbers, result
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single LCM task""" """Generate a single LCM task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
numbers, result = self._generate_numbers(rng) numbers, result = self._generate_numbers(rng)
numbers_str = ", ".join(str(n) for n in numbers) numbers_str = ", ".join(str(n) for n in numbers)
return { return {
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}", "question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {"numbers": numbers, "result": result},
"numbers": numbers,
"result": result
}
} }

View file

@ -1,7 +1,9 @@
"""Leg counting task generator""" """Leg counting task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Dict, Optional from typing import Dict, Optional
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
ANIMALS = { ANIMALS = {
@ -52,14 +54,16 @@ ANIMALS = {
"woodlouse": 14, "woodlouse": 14,
} }
@dataclass @dataclass
class LegCountingConfig: class LegCountingConfig:
"""Configuration for leg counting task generation""" """Configuration for leg counting task generation"""
min_animals: int = 2 # Minimum number of animals in problem
max_animals: int = 5 # Maximum number of animals min_animals: int = 2 # Minimum number of animals in problem
max_instances: int = 3 # Maximum instances of each animal max_animals: int = 5 # Maximum number of animals
max_instances: int = 3 # Maximum instances of each animal
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -80,39 +84,36 @@ class LegCountingDataset(ProceduralDataset):
"""Generate a random set of animals and their counts""" """Generate a random set of animals and their counts"""
num_types = rng.randint(self.config.min_animals, self.config.max_animals) num_types = rng.randint(self.config.min_animals, self.config.max_animals)
animals = {} animals = {}
# Select random animals # Select random animals
selected_animals = rng.sample(list(ANIMALS.keys()), num_types) selected_animals = rng.sample(list(ANIMALS.keys()), num_types)
for animal in selected_animals: for animal in selected_animals:
count = rng.randint(1, self.config.max_instances) count = rng.randint(1, self.config.max_instances)
animals[animal] = count animals[animal] = count
return animals return animals
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single leg counting task""" """Generate a single leg counting task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
# Generate random animals and their counts # Generate random animals and their counts
animals = self._generate_animals(rng) animals = self._generate_animals(rng)
# Calculate total legs # Calculate total legs
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items()) total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
# Format animal counts for question # Format animal counts for question
animal_list = [] animal_list = []
for animal, count in animals.items(): for animal, count in animals.items():
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}") animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?" question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
return { return {
"question": question, "question": question,
"answer": str(total_legs), "answer": str(total_legs),
"metadata": { "metadata": {"animals": animals, "total_legs": total_legs},
"animals": animals,
"total_legs": total_legs
}
} }

View file

@ -1,16 +1,20 @@
"""Prime factorization task generator""" """Prime factorization task generator"""
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..dataset import ProceduralDataset
@dataclass @dataclass
class PrimeFactorizationConfig: class PrimeFactorizationConfig:
"""Configuration for prime factorization task generation""" """Configuration for prime factorization task generation"""
min_value: int = 2 # Minimum number to factorize
max_value: int = 1000 # Maximum number to factorize min_value: int = 2 # Minimum number to factorize
max_value: int = 1000 # Maximum number to factorize
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -44,24 +48,23 @@ class PrimeFactorizationDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single prime factorization task""" """Generate a single prime factorization task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
# Generate random number to factorize # Generate random number to factorize
number = rng.randint(self.config.min_value, self.config.max_value) number = rng.randint(self.config.min_value, self.config.max_value)
# Calculate prime factors # Calculate prime factors
factors = self._prime_factors(number) factors = self._prime_factors(number)
# Format answer as multiplication of prime factors # Format answer as multiplication of prime factors
answer = " × ".join(map(str, factors)) answer = " × ".join(map(str, factors))
return { return {
"question": (f"Find the prime factorization of {number}. Write the factors separated by × " "question": (
f"(Example: for 12 the answer would be: 2 × 2 × 3)"), f"Find the prime factorization of {number}. Write the factors separated by × "
f"(Example: for 12 the answer would be: 2 × 2 × 3)"
),
"answer": answer, "answer": answer,
"metadata": { "metadata": {"number": number, "factors": factors},
"number": number,
"factors": factors
}
} }

View file

@ -40,7 +40,7 @@ class SequenceConfig:
class PatternRule: class PatternRule:
"""Represents a composable sequence pattern rule""" """Represents a composable sequence pattern rule"""
def __init__(self, operations: List[Operation], parameters: List[int], subrules: List['PatternRule'] = None): def __init__(self, operations: List[Operation], parameters: List[int], subrules: List["PatternRule"] = None):
self.operations = operations self.operations = operations
self.parameters = parameters self.parameters = parameters
self.subrules = subrules or [] self.subrules = subrules or []
@ -66,14 +66,14 @@ class PatternRule:
elif op == Operation.COMPOSE: elif op == Operation.COMPOSE:
# Apply each subrule in sequence, passing the result through # Apply each subrule in sequence, passing the result through
for subrule in self.subrules: for subrule in self.subrules:
temp_sequence = sequence[:position + 1] temp_sequence = sequence[: position + 1]
temp_sequence[-1] = result # Use current result as input temp_sequence[-1] = result # Use current result as input
result = subrule.apply(temp_sequence, position) result = subrule.apply(temp_sequence, position)
return result return result
@classmethod @classmethod
def compose(cls, rules: List['PatternRule']) -> 'PatternRule': def compose(cls, rules: List["PatternRule"]) -> "PatternRule":
"""Create a new rule that composes multiple rules together""" """Create a new rule that composes multiple rules together"""
return cls([Operation.COMPOSE], [0], subrules=rules) return cls([Operation.COMPOSE], [0], subrules=rules)

View file

@ -4,34 +4,37 @@ from importlib import resources
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
def get_data_file_path(filename: str) -> Path: def get_data_file_path(filename: str) -> Path:
"""Get the path to a data file in the package. """Get the path to a data file in the package.
Args: Args:
filename: Name of the file in the data directory filename: Name of the file in the data directory
Returns: Returns:
Path object pointing to the data file Path object pointing to the data file
Example: Example:
>>> path = get_data_file_path("pg19362.txt") >>> path = get_data_file_path("pg19362.txt")
>>> with open(path) as f: >>> with open(path) as f:
... content = f.read() ... content = f.read()
""" """
return resources.files('reasoning_gym.data').joinpath(filename) return resources.files("reasoning_gym.data").joinpath(filename)
def read_data_file(filename: str) -> str: def read_data_file(filename: str) -> str:
"""Read the contents of a data file in the package. """Read the contents of a data file in the package.
Args: Args:
filename: Name of the file in the data directory filename: Name of the file in the data directory
Returns: Returns:
String contents of the file String contents of the file
Example: Example:
>>> content = read_data_file("pg19362.txt") >>> content = read_data_file("pg19362.txt")
""" """
return resources.files('reasoning_gym.data').joinpath(filename).read_text() return resources.files("reasoning_gym.data").joinpath(filename).read_text()
__all__ = ['get_data_file_path', 'read_data_file']
__all__ = ["get_data_file_path", "read_data_file"]

View file

@ -1,5 +1,5 @@
The Project Gutenberg eBook of In the year 2889 The Project Gutenberg eBook of In the year 2889
This ebook is for the use of anyone anywhere in the United States and This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restrictions most other parts of the world at no cost and with almost no restrictions
whatsoever. You may copy it, give it away or re-use it under the terms whatsoever. You may copy it, give it away or re-use it under the terms
@ -702,7 +702,7 @@ End of Project Gutenberg's In the Year 2889, by Jules Verne and Michel Verne
*** END OF THE PROJECT GUTENBERG EBOOK IN THE YEAR 2889 *** *** END OF THE PROJECT GUTENBERG EBOOK IN THE YEAR 2889 ***
Updated editions will replace the previous one—the old editions will Updated editions will replace the previous one—the old editions will
be renamed. be renamed.
@ -807,7 +807,7 @@ performed, viewed, copied or distributed:
at www.gutenberg.org. If you at www.gutenberg.org. If you
are not located in the United States, you will have to check the laws are not located in the United States, you will have to check the laws
of the country where you are located before using this eBook. of the country where you are located before using this eBook.
1.E.2. If an individual Project Gutenberg™ electronic work is 1.E.2. If an individual Project Gutenberg™ electronic work is
derived from texts not protected by U.S. copyright law (does not derived from texts not protected by U.S. copyright law (does not
contain a notice indicating that it is posted with permission of the contain a notice indicating that it is posted with permission of the
@ -869,7 +869,7 @@ provided that:
Gutenberg Literary Archive Foundation at the address specified in Gutenberg Literary Archive Foundation at the address specified in
Section 4, “Information about donations to the Project Gutenberg Section 4, “Information about donations to the Project Gutenberg
Literary Archive Foundation.” Literary Archive Foundation.”
• You provide a full refund of any money paid by a user who notifies • You provide a full refund of any money paid by a user who notifies
you in writing (or by e-mail) within 30 days of receipt that s/he you in writing (or by e-mail) within 30 days of receipt that s/he
does not agree to the terms of the full Project Gutenberg™ does not agree to the terms of the full Project Gutenberg™
@ -877,15 +877,15 @@ provided that:
copies of the works possessed in a physical medium and discontinue copies of the works possessed in a physical medium and discontinue
all use of and all access to other copies of Project Gutenberg™ all use of and all access to other copies of Project Gutenberg™
works. works.
• You provide, in accordance with paragraph 1.F.3, a full refund of • You provide, in accordance with paragraph 1.F.3, a full refund of
any money paid for a work or a replacement copy, if a defect in the any money paid for a work or a replacement copy, if a defect in the
electronic work is discovered and reported to you within 90 days of electronic work is discovered and reported to you within 90 days of
receipt of the work. receipt of the work.
• You comply with all other terms of this agreement for free • You comply with all other terms of this agreement for free
distribution of Project Gutenberg™ works. distribution of Project Gutenberg™ works.
1.E.9. If you wish to charge a fee or distribute a Project 1.E.9. If you wish to charge a fee or distribute a Project
Gutenberg™ electronic work or group of works on different terms than Gutenberg™ electronic work or group of works on different terms than
@ -1048,5 +1048,3 @@ This website includes information about Project Gutenberg™,
including how to make donations to the Project Gutenberg Literary including how to make donations to the Project Gutenberg Literary
Archive Foundation, how to help produce our new eBooks, and how to Archive Foundation, how to help produce our new eBooks, and how to
subscribe to our email newsletter to hear about new eBooks. subscribe to our email newsletter to hear about new eBooks.

View file

@ -1,27 +1,28 @@
"""Base class for procedural dataset generators""" """Base class for procedural dataset generators"""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sized, Iterable from collections.abc import Iterable, Sized
from random import Random from random import Random
from typing import Optional, Iterator, Dict, Any from typing import Any, Dict, Iterator, Optional
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
"""Abstract base class for procedural dataset generators""" """Abstract base class for procedural dataset generators"""
def __init__(self, seed: Optional[int] = None, size: int = 500): def __init__(self, seed: Optional[int] = None, size: int = 500):
"""Initialize the dataset with optional seed and size""" """Initialize the dataset with optional seed and size"""
self.size = size self.size = size
self.seed = seed if seed is not None else Random().randint(0, 2**32) self.seed = seed if seed is not None else Random().randint(0, 2**32)
def __len__(self) -> int: def __len__(self) -> int:
"""Return the virtual size of the dataset""" """Return the virtual size of the dataset"""
return self.size return self.size
def __iter__(self): def __iter__(self):
"""Make the dataset iterable""" """Make the dataset iterable"""
self._current_idx = 0 self._current_idx = 0
return self return self
def __next__(self) -> Dict[str, Any]: def __next__(self) -> Dict[str, Any]:
"""Get next item in iteration""" """Get next item in iteration"""
if self._current_idx >= self.size: if self._current_idx >= self.size:
@ -29,14 +30,14 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
item = self[self._current_idx] item = self[self._current_idx]
self._current_idx += 1 self._current_idx += 1
return item return item
@abstractmethod @abstractmethod
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single dataset item """Generate a single dataset item
Args: Args:
idx: Index of the item to generate idx: Index of the item to generate
Returns: Returns:
dict containing at least: dict containing at least:
- question: str - question: str

View file

@ -14,5 +14,5 @@ __all__ = [
"mini_sudoku_dataset", "mini_sudoku_dataset",
"SudokuConfig", "SudokuConfig",
"SudokuDataset", "SudokuDataset",
"sudoku_dataset" "sudoku_dataset",
] ]

View file

@ -1,16 +1,19 @@
"""Mini Sudoku (4x4) puzzle generator""" """Mini Sudoku (4x4) puzzle generator"""
from dataclasses import dataclass
import random import random
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
@dataclass @dataclass
class MiniSudokuConfig: class MiniSudokuConfig:
"""Configuration for 4x4 sudoku puzzle generation""" """Configuration for 4x4 sudoku puzzle generation"""
min_empty: int = 8 # Minimum number of empty cells
max_empty: int = 12 # Maximum number of empty cells min_empty: int = 8 # Minimum number of empty cells
max_empty: int = 12 # Maximum number of empty cells
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -45,11 +48,11 @@ class MiniSudokuDataset:
# Check row # Check row
if num in board[row]: if num in board[row]:
return False return False
# Check column # Check column
if num in [board[i][col] for i in range(4)]: if num in [board[i][col] for i in range(4)]:
return False return False
# Check 2x2 box # Check 2x2 box
box_row, box_col = 2 * (row // 2), 2 * (col // 2) box_row, box_col = 2 * (row // 2), 2 * (col // 2)
for i in range(box_row, box_row + 2): for i in range(box_row, box_row + 2):
@ -63,7 +66,7 @@ class MiniSudokuDataset:
empty = self._find_empty(board) empty = self._find_empty(board)
if not empty: if not empty:
return True return True
row, col = empty row, col = empty
for num in range(1, 5): for num in range(1, 5):
if self._is_valid(board, row, col, num): if self._is_valid(board, row, col, num):
@ -84,7 +87,7 @@ class MiniSudokuDataset:
def _generate_solved_board(self, rng: Random) -> List[List[int]]: def _generate_solved_board(self, rng: Random) -> List[List[int]]:
"""Generate a complete solved mini sudoku board""" """Generate a complete solved mini sudoku board"""
board = [[0] * 4 for _ in range(4)] board = [[0] * 4 for _ in range(4)]
# Try multiple times to generate a valid board # Try multiple times to generate a valid board
max_attempts = 100 max_attempts = 100
for _ in range(max_attempts): for _ in range(max_attempts):
@ -92,7 +95,7 @@ class MiniSudokuDataset:
for i in range(4): for i in range(4):
for j in range(4): for j in range(4):
board[i][j] = 0 board[i][j] = 0
# Fill diagonal boxes first (they are independent) # Fill diagonal boxes first (they are independent)
for i in range(0, 4, 2): for i in range(0, 4, 2):
nums = list(range(1, 5)) nums = list(range(1, 5))
@ -102,11 +105,11 @@ class MiniSudokuDataset:
for c in range(i, i + 2): for c in range(i, i + 2):
board[r][c] = nums[pos] board[r][c] = nums[pos]
pos += 1 pos += 1
# Try to solve the rest # Try to solve the rest
if self._solve(board): if self._solve(board):
return board return board
raise RuntimeError("Failed to generate valid mini sudoku board") raise RuntimeError("Failed to generate valid mini sudoku board")
def _create_puzzle(self, solved_board: List[List[int]], num_empty: int, rng: Random) -> List[List[int]]: def _create_puzzle(self, solved_board: List[List[int]], num_empty: int, rng: Random) -> List[List[int]]:
@ -114,10 +117,10 @@ class MiniSudokuDataset:
puzzle = [row[:] for row in solved_board] puzzle = [row[:] for row in solved_board]
cells = [(i, j) for i in range(4) for j in range(4)] cells = [(i, j) for i in range(4) for j in range(4)]
rng.shuffle(cells) rng.shuffle(cells)
for i, j in cells[:num_empty]: for i, j in cells[:num_empty]:
puzzle[i][j] = 0 puzzle[i][j] = 0
return puzzle return puzzle
def _board_to_string(self, board: List[List[int]]) -> str: def _board_to_string(self, board: List[List[int]]) -> str:
@ -127,26 +130,22 @@ class MiniSudokuDataset:
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single mini sudoku puzzle""" """Generate a single mini sudoku puzzle"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
# Generate solved board # Generate solved board
solved_board = self._generate_solved_board(rng) solved_board = self._generate_solved_board(rng)
# Create puzzle by removing numbers # Create puzzle by removing numbers
num_empty = rng.randint(self.config.min_empty, self.config.max_empty) num_empty = rng.randint(self.config.min_empty, self.config.max_empty)
puzzle = self._create_puzzle(solved_board, num_empty, rng) puzzle = self._create_puzzle(solved_board, num_empty, rng)
# Format as strings # Format as strings
puzzle_str = self._board_to_string(puzzle) puzzle_str = self._board_to_string(puzzle)
solution_str = self._board_to_string(solved_board) solution_str = self._board_to_string(solved_board)
return { return {
"question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}", "question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}",
"answer": solution_str, "answer": solution_str,
"metadata": { "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
"puzzle": puzzle,
"solution": solved_board,
"num_empty": num_empty
}
} }

View file

@ -1,16 +1,19 @@
"""Sudoku puzzle generator""" """Sudoku puzzle generator"""
from dataclasses import dataclass
import random import random
from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
@dataclass @dataclass
class SudokuConfig: class SudokuConfig:
"""Configuration for sudoku puzzle generation""" """Configuration for sudoku puzzle generation"""
min_empty: int = 30 # Minimum number of empty cells
max_empty: int = 50 # Maximum number of empty cells min_empty: int = 30 # Minimum number of empty cells
max_empty: int = 50 # Maximum number of empty cells
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -45,11 +48,11 @@ class SudokuDataset:
# Check row # Check row
if num in board[row]: if num in board[row]:
return False return False
# Check column # Check column
if num in [board[i][col] for i in range(9)]: if num in [board[i][col] for i in range(9)]:
return False return False
# Check 3x3 box # Check 3x3 box
box_row, box_col = 3 * (row // 3), 3 * (col // 3) box_row, box_col = 3 * (row // 3), 3 * (col // 3)
for i in range(box_row, box_row + 3): for i in range(box_row, box_row + 3):
@ -63,7 +66,7 @@ class SudokuDataset:
empty = self._find_empty(board) empty = self._find_empty(board)
if not empty: if not empty:
return True return True
row, col = empty row, col = empty
for num in range(1, 10): for num in range(1, 10):
if self._is_valid(board, row, col, num): if self._is_valid(board, row, col, num):
@ -84,7 +87,7 @@ class SudokuDataset:
def _generate_solved_board(self, rng: Random) -> List[List[int]]: def _generate_solved_board(self, rng: Random) -> List[List[int]]:
"""Generate a complete solved sudoku board""" """Generate a complete solved sudoku board"""
board = [[0] * 9 for _ in range(9)] board = [[0] * 9 for _ in range(9)]
# Fill diagonal boxes first (they are independent) # Fill diagonal boxes first (they are independent)
for i in range(0, 9, 3): for i in range(0, 9, 3):
nums = list(range(1, 10)) nums = list(range(1, 10))
@ -94,7 +97,7 @@ class SudokuDataset:
for c in range(i, i + 3): for c in range(i, i + 3):
board[r][c] = nums[pos] board[r][c] = nums[pos]
pos += 1 pos += 1
# Solve the rest # Solve the rest
self._solve(board) self._solve(board)
return board return board
@ -104,10 +107,10 @@ class SudokuDataset:
puzzle = [row[:] for row in solved_board] puzzle = [row[:] for row in solved_board]
cells = [(i, j) for i in range(9) for j in range(9)] cells = [(i, j) for i in range(9) for j in range(9)]
rng.shuffle(cells) rng.shuffle(cells)
for i, j in cells[:num_empty]: for i, j in cells[:num_empty]:
puzzle[i][j] = 0 puzzle[i][j] = 0
return puzzle return puzzle
def _board_to_string(self, board: List[List[int]]) -> str: def _board_to_string(self, board: List[List[int]]) -> str:
@ -117,26 +120,22 @@ class SudokuDataset:
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single sudoku puzzle""" """Generate a single sudoku puzzle"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
# Generate solved board # Generate solved board
solved_board = self._generate_solved_board(rng) solved_board = self._generate_solved_board(rng)
# Create puzzle by removing numbers # Create puzzle by removing numbers
num_empty = rng.randint(self.config.min_empty, self.config.max_empty) num_empty = rng.randint(self.config.min_empty, self.config.max_empty)
puzzle = self._create_puzzle(solved_board, num_empty, rng) puzzle = self._create_puzzle(solved_board, num_empty, rng)
# Format as strings # Format as strings
puzzle_str = self._board_to_string(puzzle) puzzle_str = self._board_to_string(puzzle)
solution_str = self._board_to_string(solved_board) solution_str = self._board_to_string(solved_board)
return { return {
"question": f"Solve this Sudoku puzzle:\n{puzzle_str}", "question": f"Solve this Sudoku puzzle:\n{puzzle_str}",
"answer": solution_str, "answer": solution_str,
"metadata": { "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
"puzzle": puzzle,
"solution": solved_board,
"num_empty": num_empty
}
} }

View file

@ -1,5 +1,7 @@
import pytest
from random import Random from random import Random
import pytest
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
@ -8,11 +10,11 @@ def test_arithmetic_dataset_config_validation():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = BasicArithmeticDatasetConfig(min_terms=0) config = BasicArithmeticDatasetConfig(min_terms=0)
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = BasicArithmeticDatasetConfig(min_terms=3, max_terms=2) config = BasicArithmeticDatasetConfig(min_terms=3, max_terms=2)
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = BasicArithmeticDatasetConfig(operators=["^"]) # Invalid operator config = BasicArithmeticDatasetConfig(operators=["^"]) # Invalid operator
config.validate() config.validate()
@ -23,30 +25,23 @@ def test_arithmetic_dataset_deterministic():
config = BasicArithmeticDatasetConfig(seed=42, size=10) config = BasicArithmeticDatasetConfig(seed=42, size=10)
dataset1 = BasicArithmeticDataset(config) dataset1 = BasicArithmeticDataset(config)
dataset2 = BasicArithmeticDataset(config) dataset2 = BasicArithmeticDataset(config)
for i in range(len(dataset1)): for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i] assert dataset1[i] == dataset2[i]
def test_arithmetic_dataset_items(): def test_arithmetic_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = BasicArithmeticDatasetConfig( config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
min_terms=2,
max_terms=4,
min_digits=1,
max_digits=2,
size=100,
seed=42
)
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
assert isinstance(item, dict) assert isinstance(item, dict)
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Verify the answer matches the expression # Verify the answer matches the expression
expression = item["metadata"]["expression"] expression = item["metadata"]["expression"]
answer = eval(expression) # Safe here as we control the expression answer = eval(expression) # Safe here as we control the expression
@ -62,11 +57,11 @@ def test_arithmetic_dataset_format_styles():
min_terms=2, min_terms=2,
max_terms=3, # Keep expressions simple for testing max_terms=3, # Keep expressions simple for testing
min_digits=1, min_digits=1,
max_digits=2 max_digits=2,
) )
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
assert all(item["question"].endswith("=") for item in dataset) assert all(item["question"].endswith("=") for item in dataset)
config.format_style = "natural" config.format_style = "natural"
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
assert all("=" not in item["question"] for item in dataset) assert all("=" not in item["question"] for item in dataset)
@ -74,24 +69,19 @@ def test_arithmetic_dataset_format_styles():
def test_arithmetic_dataset_iteration(): def test_arithmetic_dataset_iteration():
"""Test that iteration respects dataset size""" """Test that iteration respects dataset size"""
config = BasicArithmeticDatasetConfig( config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
min_terms=2,
max_terms=2,
size=5, # Small size for testing
seed=42
)
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
# Test manual iteration # Test manual iteration
items = [] items = []
for item in dataset: for item in dataset:
items.append(item) items.append(item)
assert len(items) == config.size, "Iterator should yield exactly size items" assert len(items) == config.size, "Iterator should yield exactly size items"
# Test list conversion # Test list conversion
items = list(dataset) items = list(dataset)
assert len(items) == config.size, "Iterator should yield exactly size items" assert len(items) == config.size, "Iterator should yield exactly size items"
# Test multiple iterations # Test multiple iterations
first_items = list(dataset) first_items = list(dataset)
second_items = list(dataset) second_items = list(dataset)

View file

@ -1,10 +1,8 @@
"""Tests for base conversion task generation""" """Tests for base conversion task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.base_conversion import ( from reasoning_gym.algorithmic.base_conversion import BaseConversionConfig, BaseConversionDataset
BaseConversionConfig,
BaseConversionDataset,
)
def test_base_conversion_config_validation(): def test_base_conversion_config_validation():
@ -38,14 +36,7 @@ def test_base_conversion_dataset_deterministic():
def test_base_conversion_dataset_items(): def test_base_conversion_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = BaseConversionConfig( config = BaseConversionConfig(min_base=2, max_base=16, min_value=0, max_value=1000, size=10, seed=42)
min_base=2,
max_base=16,
min_value=0,
max_value=1000,
size=10,
seed=42
)
dataset = BaseConversionDataset(config) dataset = BaseConversionDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -55,28 +46,28 @@ def test_base_conversion_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "decimal_value" in item["metadata"] assert "decimal_value" in item["metadata"]
assert "source_base" in item["metadata"] assert "source_base" in item["metadata"]
assert "target_base" in item["metadata"] assert "target_base" in item["metadata"]
assert "source_repr" in item["metadata"] assert "source_repr" in item["metadata"]
assert "target_repr" in item["metadata"] assert "target_repr" in item["metadata"]
# Verify value range # Verify value range
assert config.min_value <= item["metadata"]["decimal_value"] <= config.max_value assert config.min_value <= item["metadata"]["decimal_value"] <= config.max_value
# Verify base range # Verify base range
assert config.min_base <= item["metadata"]["source_base"] <= config.max_base assert config.min_base <= item["metadata"]["source_base"] <= config.max_base
assert config.min_base <= item["metadata"]["target_base"] <= config.max_base assert config.min_base <= item["metadata"]["target_base"] <= config.max_base
assert item["metadata"]["source_base"] != item["metadata"]["target_base"] assert item["metadata"]["source_base"] != item["metadata"]["target_base"]
# Verify conversion correctness # Verify conversion correctness
decimal_value = item["metadata"]["decimal_value"] decimal_value = item["metadata"]["decimal_value"]
target_base = item["metadata"]["target_base"] target_base = item["metadata"]["target_base"]
expected = format(decimal_value, 'x' if target_base == 16 else 'b' if target_base == 2 else '').strip() expected = format(decimal_value, "x" if target_base == 16 else "b" if target_base == 2 else "").strip()
if target_base not in (2, 16): if target_base not in (2, 16):
expected = format(decimal_value, f'{target_base}x').lower().strip() expected = format(decimal_value, f"{target_base}x").lower().strip()
assert item["answer"] == expected assert item["answer"] == expected
@ -100,24 +91,24 @@ def test_base_conversion_special_bases():
min_value=0, min_value=0,
max_value=255, # Use small range for predictable results max_value=255, # Use small range for predictable results
size=100, size=100,
seed=42 seed=42,
) )
dataset = BaseConversionDataset(config) dataset = BaseConversionDataset(config)
binary_found = False binary_found = False
hex_found = False hex_found = False
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
if item["metadata"]["target_base"] == 2: if item["metadata"]["target_base"] == 2:
binary_found = True binary_found = True
# Verify binary format # Verify binary format
assert all(c in '01' for c in item["answer"]) assert all(c in "01" for c in item["answer"])
elif item["metadata"]["target_base"] == 16: elif item["metadata"]["target_base"] == 16:
hex_found = True hex_found = True
# Verify hex format # Verify hex format
assert all(c in '0123456789abcdef' for c in item["answer"]) assert all(c in "0123456789abcdef" for c in item["answer"])
assert binary_found, "No binary conversion tasks generated" assert binary_found, "No binary conversion tasks generated"
assert hex_found, "No hexadecimal conversion tasks generated" assert hex_found, "No hexadecimal conversion tasks generated"
@ -130,10 +121,10 @@ def test_base_conversion_formatting():
min_value=10, # Ensure multi-digit numbers min_value=10, # Ensure multi-digit numbers
max_value=1000, max_value=1000,
size=10, size=10,
seed=42 seed=42,
) )
dataset = BaseConversionDataset(config) dataset = BaseConversionDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
# Verify lowercase letters are used # Verify lowercase letters are used

View file

@ -1,4 +1,5 @@
import pytest import pytest
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
@ -7,7 +8,7 @@ def test_chain_sum_config_validation():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = ChainSumConfig(min_terms=0) config = ChainSumConfig(min_terms=0)
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = ChainSumConfig(min_terms=3, max_terms=2) config = ChainSumConfig(min_terms=3, max_terms=2)
config.validate() config.validate()
@ -18,34 +19,27 @@ def test_chain_sum_deterministic():
config = ChainSumConfig(seed=42, size=10) config = ChainSumConfig(seed=42, size=10)
dataset1 = ChainSum(config) dataset1 = ChainSum(config)
dataset2 = ChainSum(config) dataset2 = ChainSum(config)
for i in range(len(dataset1)): for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i] assert dataset1[i] == dataset2[i]
def test_chain_sum_items(): def test_chain_sum_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = ChainSumConfig( config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
min_terms=2,
max_terms=4,
min_digits=1,
max_digits=2,
size=100,
seed=42
)
dataset = ChainSum(config) dataset = ChainSum(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
assert isinstance(item, dict) assert isinstance(item, dict)
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Verify only + and - are used # Verify only + and - are used
expression = item["metadata"]["expression"] expression = item["metadata"]["expression"]
assert all(op in ["+", "-", " "] or op.isdigit() for op in expression) assert all(op in ["+", "-", " "] or op.isdigit() for op in expression)
# Verify the answer matches the expression # Verify the answer matches the expression
answer = eval(expression) # Safe here as we control the expression answer = eval(expression) # Safe here as we control the expression
assert str(answer) == item["answer"] assert str(answer) == item["answer"]
@ -60,10 +54,10 @@ def test_chain_sum_number_ranges():
min_digits=3, # Should generate numbers >= 100 min_digits=3, # Should generate numbers >= 100
max_digits=3, # Should generate numbers <= 999 max_digits=3, # Should generate numbers <= 999
size=50, size=50,
seed=42 seed=42,
) )
dataset = ChainSum(config) dataset = ChainSum(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
expression = item["metadata"]["expression"] expression = item["metadata"]["expression"]
@ -74,16 +68,8 @@ def test_chain_sum_number_ranges():
else: else:
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
# Test 1-digit numbers # Test 1-digit numbers
config = ChainSumConfig( config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
min_terms=2,
max_terms=2,
min_digits=1,
max_digits=1,
size=50,
seed=42
)
dataset = ChainSum(config) dataset = ChainSum(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
@ -95,58 +81,48 @@ def test_chain_sum_number_ranges():
else: else:
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
def test_chain_sum_negation(): def test_chain_sum_negation():
"""Test that allow_negation controls number ranges""" """Test that allow_negation controls number ranges"""
config = ChainSumConfig( config = ChainSumConfig(
min_terms=2, min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True
max_terms=2,
min_digits=2,
max_digits=2,
size=100,
seed=42,
allow_negation=True
) )
dataset = ChainSum(config) dataset = ChainSum(config)
# Track if we see both positive and negative numbers # Track if we see both positive and negative numbers
has_positive = False has_positive = False
has_negative = False has_negative = False
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
expression = item["metadata"]["expression"] expression = item["metadata"]["expression"]
numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith('-') and n[1:].isdigit())] numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith("-") and n[1:].isdigit())]
for num in numbers: for num in numbers:
if num > 0: if num > 0:
has_positive = True has_positive = True
if num < 0: if num < 0:
has_negative = True has_negative = True
# With enough samples and allow_negation=True, we should see both positive and negative numbers # With enough samples and allow_negation=True, we should see both positive and negative numbers
assert has_positive and has_negative, "Expected both positive and negative numbers with allow_negation=True" assert has_positive and has_negative, "Expected both positive and negative numbers with allow_negation=True"
def test_chain_sum_iteration(): def test_chain_sum_iteration():
"""Test that iteration respects dataset size""" """Test that iteration respects dataset size"""
config = ChainSumConfig( config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
min_terms=2,
max_terms=2,
size=5, # Small size for testing
seed=42
)
dataset = ChainSum(config) dataset = ChainSum(config)
# Test manual iteration # Test manual iteration
items = [] items = []
for item in dataset: for item in dataset:
items.append(item) items.append(item)
assert len(items) == config.size, "Iterator should yield exactly size items" assert len(items) == config.size, "Iterator should yield exactly size items"
# Test list conversion # Test list conversion
items = list(dataset) items = list(dataset)
assert len(items) == config.size, "Iterator should yield exactly size items" assert len(items) == config.size, "Iterator should yield exactly size items"
# Test multiple iterations # Test multiple iterations
first_items = list(dataset) first_items = list(dataset)
second_items = list(dataset) second_items = list(dataset)

View file

@ -1,6 +1,8 @@
import pytest
from math import gcd from math import gcd
from reasoning_gym.arithmetic import FractionSimplificationDataset, FractionSimplificationConfig
import pytest
from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset
def test_fraction_config_validation(): def test_fraction_config_validation():
@ -8,15 +10,15 @@ def test_fraction_config_validation():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = FractionSimplificationConfig(min_value=0) # Should be positive config = FractionSimplificationConfig(min_value=0) # Should be positive
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = FractionSimplificationConfig(min_value=100, max_value=50) # max should be > min config = FractionSimplificationConfig(min_value=100, max_value=50) # max should be > min
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = FractionSimplificationConfig(min_factor=0) # Should be >= 1 config = FractionSimplificationConfig(min_factor=0) # Should be >= 1
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = FractionSimplificationConfig(min_factor=5, max_factor=3) # max should be >= min config = FractionSimplificationConfig(min_factor=5, max_factor=3) # max should be >= min
config.validate() config.validate()
@ -27,30 +29,23 @@ def test_fraction_deterministic():
config = FractionSimplificationConfig(seed=42, size=10) config = FractionSimplificationConfig(seed=42, size=10)
dataset1 = FractionSimplificationDataset(config) dataset1 = FractionSimplificationDataset(config)
dataset2 = FractionSimplificationDataset(config) dataset2 = FractionSimplificationDataset(config)
for i in range(len(dataset1)): for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i] assert dataset1[i] == dataset2[i]
def test_fraction_items(): def test_fraction_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = FractionSimplificationConfig( config = FractionSimplificationConfig(min_value=1, max_value=20, min_factor=2, max_factor=5, size=50, seed=42)
min_value=1,
max_value=20,
min_factor=2,
max_factor=5,
size=50,
seed=42
)
dataset = FractionSimplificationDataset(config) dataset = FractionSimplificationDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
assert isinstance(item, dict) assert isinstance(item, dict)
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Verify the metadata contains all expected fields # Verify the metadata contains all expected fields
metadata = item["metadata"] metadata = item["metadata"]
assert "numerator" in metadata assert "numerator" in metadata
@ -58,45 +53,38 @@ def test_fraction_items():
assert "simplified_numerator" in metadata assert "simplified_numerator" in metadata
assert "simplified_denominator" in metadata assert "simplified_denominator" in metadata
assert "reduction_factor" in metadata assert "reduction_factor" in metadata
# Verify the numbers are within configured range # Verify the numbers are within configured range
assert config.min_value <= metadata["simplified_numerator"] <= config.max_value assert config.min_value <= metadata["simplified_numerator"] <= config.max_value
assert config.min_value <= metadata["simplified_denominator"] <= config.max_value assert config.min_value <= metadata["simplified_denominator"] <= config.max_value
# Verify the reduction is correct # Verify the reduction is correct
num = metadata["numerator"] num = metadata["numerator"]
den = metadata["denominator"] den = metadata["denominator"]
simple_num = metadata["simplified_numerator"] simple_num = metadata["simplified_numerator"]
simple_den = metadata["simplified_denominator"] simple_den = metadata["simplified_denominator"]
factor = metadata["reduction_factor"] factor = metadata["reduction_factor"]
assert num == simple_num * factor assert num == simple_num * factor
assert den == simple_den * factor assert den == simple_den * factor
# Verify the simplified fraction is actually in lowest terms # Verify the simplified fraction is actually in lowest terms
assert gcd(simple_num, simple_den) == 1 assert gcd(simple_num, simple_den) == 1
def test_fraction_ranges(): def test_fraction_ranges():
"""Test that generated numbers respect value constraints""" """Test that generated numbers respect value constraints"""
config = FractionSimplificationConfig( config = FractionSimplificationConfig(min_value=5, max_value=15, min_factor=3, max_factor=4, size=20, seed=42)
min_value=5,
max_value=15,
min_factor=3,
max_factor=4,
size=20,
seed=42
)
dataset = FractionSimplificationDataset(config) dataset = FractionSimplificationDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
metadata = item["metadata"] metadata = item["metadata"]
factor = metadata["reduction_factor"] factor = metadata["reduction_factor"]
# Check factor is within bounds # Check factor is within bounds
assert 3 <= factor <= 4 assert 3 <= factor <= 4
# Check simplified values are within bounds # Check simplified values are within bounds
assert 5 <= metadata["simplified_numerator"] <= 15 assert 5 <= metadata["simplified_numerator"] <= 15
assert 5 <= metadata["simplified_denominator"] <= 15 assert 5 <= metadata["simplified_denominator"] <= 15
@ -106,17 +94,17 @@ def test_fraction_iteration():
"""Test that iteration works correctly""" """Test that iteration works correctly"""
config = FractionSimplificationConfig(size=5, seed=42) config = FractionSimplificationConfig(size=5, seed=42)
dataset = FractionSimplificationDataset(config) dataset = FractionSimplificationDataset(config)
# Test manual iteration # Test manual iteration
items = [] items = []
for item in dataset: for item in dataset:
items.append(item) items.append(item)
assert len(items) == config.size assert len(items) == config.size
# Test list conversion # Test list conversion
items = list(dataset) items = list(dataset)
assert len(items) == config.size assert len(items) == config.size
# Test multiple iterations yield same results # Test multiple iterations yield same results
first_items = list(dataset) first_items = list(dataset)
second_items = list(dataset) second_items = list(dataset)
@ -125,24 +113,19 @@ def test_fraction_iteration():
def test_fraction_numerator_smaller(): def test_fraction_numerator_smaller():
"""Test that numerators are always smaller than denominators""" """Test that numerators are always smaller than denominators"""
config = FractionSimplificationConfig( config = FractionSimplificationConfig(min_value=1, max_value=100, min_factor=2, max_factor=5, size=50, seed=42)
min_value=1,
max_value=100,
min_factor=2,
max_factor=5,
size=50,
seed=42
)
dataset = FractionSimplificationDataset(config) dataset = FractionSimplificationDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
metadata = item["metadata"] metadata = item["metadata"]
# Check original fraction # Check original fraction
assert metadata["numerator"] <= metadata["denominator"], \ assert (
f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}" metadata["numerator"] <= metadata["denominator"]
), f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}"
# Check simplified fraction # Check simplified fraction
assert metadata["simplified_numerator"] <= metadata["simplified_denominator"], \ assert (
f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}" metadata["simplified_numerator"] <= metadata["simplified_denominator"]
), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"

View file

@ -1,7 +1,9 @@
import pytest
from math import gcd
from functools import reduce from functools import reduce
from reasoning_gym.arithmetic import GCDDataset, GCDConfig from math import gcd
import pytest
from reasoning_gym.arithmetic import GCDConfig, GCDDataset
def test_gcd_config_validation(): def test_gcd_config_validation():
@ -9,15 +11,15 @@ def test_gcd_config_validation():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = GCDConfig(min_numbers=1) # Should be >= 2 config = GCDConfig(min_numbers=1) # Should be >= 2
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = GCDConfig(min_numbers=3, max_numbers=2) # max should be >= min config = GCDConfig(min_numbers=3, max_numbers=2) # max should be >= min
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = GCDConfig(min_value=0) # Should be positive config = GCDConfig(min_value=0) # Should be positive
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = GCDConfig(min_value=100, max_value=50) # max should be > min config = GCDConfig(min_value=100, max_value=50) # max should be > min
config.validate() config.validate()
@ -28,40 +30,33 @@ def test_gcd_deterministic():
config = GCDConfig(seed=42, size=10) config = GCDConfig(seed=42, size=10)
dataset1 = GCDDataset(config) dataset1 = GCDDataset(config)
dataset2 = GCDDataset(config) dataset2 = GCDDataset(config)
for i in range(len(dataset1)): for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i] assert dataset1[i] == dataset2[i]
def test_gcd_items(): def test_gcd_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = GCDConfig( config = GCDConfig(min_numbers=2, max_numbers=4, min_value=1, max_value=100, size=50, seed=42)
min_numbers=2,
max_numbers=4,
min_value=1,
max_value=100,
size=50,
seed=42
)
dataset = GCDDataset(config) dataset = GCDDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
assert isinstance(item, dict) assert isinstance(item, dict)
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Verify the numbers and result are in metadata # Verify the numbers and result are in metadata
metadata = item["metadata"] metadata = item["metadata"]
assert "numbers" in metadata assert "numbers" in metadata
assert "result" in metadata assert "result" in metadata
# Verify the numbers are within configured range # Verify the numbers are within configured range
numbers = metadata["numbers"] numbers = metadata["numbers"]
assert all(config.min_value <= n <= config.max_value for n in numbers) assert all(config.min_value <= n <= config.max_value for n in numbers)
assert config.min_numbers <= len(numbers) <= config.max_numbers assert config.min_numbers <= len(numbers) <= config.max_numbers
# Verify the GCD calculation is correct # Verify the GCD calculation is correct
result = metadata["result"] result = metadata["result"]
assert str(result) == item["answer"] assert str(result) == item["answer"]
@ -70,16 +65,9 @@ def test_gcd_items():
def test_gcd_number_ranges(): def test_gcd_number_ranges():
"""Test that generated numbers respect value constraints""" """Test that generated numbers respect value constraints"""
config = GCDConfig( config = GCDConfig(min_numbers=2, max_numbers=2, min_value=50, max_value=100, size=20, seed=42)
min_numbers=2,
max_numbers=2,
min_value=50,
max_value=100,
size=20,
seed=42
)
dataset = GCDDataset(config) dataset = GCDDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
numbers = item["metadata"]["numbers"] numbers = item["metadata"]["numbers"]
@ -90,17 +78,17 @@ def test_gcd_iteration():
"""Test that iteration works correctly""" """Test that iteration works correctly"""
config = GCDConfig(size=5, seed=42) config = GCDConfig(size=5, seed=42)
dataset = GCDDataset(config) dataset = GCDDataset(config)
# Test manual iteration # Test manual iteration
items = [] items = []
for item in dataset: for item in dataset:
items.append(item) items.append(item)
assert len(items) == config.size assert len(items) == config.size
# Test list conversion # Test list conversion
items = list(dataset) items = list(dataset)
assert len(items) == config.size assert len(items) == config.size
# Test multiple iterations yield same results # Test multiple iterations yield same results
first_items = list(dataset) first_items = list(dataset)
second_items = list(dataset) second_items = list(dataset)
@ -109,20 +97,13 @@ def test_gcd_iteration():
def test_gcd_special_cases(): def test_gcd_special_cases():
"""Test some special GCD cases""" """Test some special GCD cases"""
config = GCDConfig( config = GCDConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=100, size=100, seed=42)
min_numbers=2,
max_numbers=2,
min_value=1,
max_value=100,
size=100,
seed=42
)
dataset = GCDDataset(config) dataset = GCDDataset(config)
# Track if we see some interesting GCD cases # Track if we see some interesting GCD cases
seen_gcd_1 = False # Coprime numbers seen_gcd_1 = False # Coprime numbers
seen_large_gcd = False # GCD > 1 seen_large_gcd = False # GCD > 1
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
result = int(item["answer"]) result = int(item["answer"])
@ -130,7 +111,7 @@ def test_gcd_special_cases():
seen_gcd_1 = True seen_gcd_1 = True
if result > 1: if result > 1:
seen_large_gcd = True seen_large_gcd = True
# With enough samples, we should see both coprime and non-coprime numbers # With enough samples, we should see both coprime and non-coprime numbers
assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)" assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)"
assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)" assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)"

View file

@ -1,7 +1,9 @@
import pytest
from math import lcm
from functools import reduce from functools import reduce
from reasoning_gym.arithmetic import LCMDataset, LCMConfig from math import lcm
import pytest
from reasoning_gym.arithmetic import LCMConfig, LCMDataset
def test_lcm_config_validation(): def test_lcm_config_validation():
@ -9,15 +11,15 @@ def test_lcm_config_validation():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = LCMConfig(min_numbers=1) # Should be >= 2 config = LCMConfig(min_numbers=1) # Should be >= 2
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = LCMConfig(min_numbers=3, max_numbers=2) # max should be >= min config = LCMConfig(min_numbers=3, max_numbers=2) # max should be >= min
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = LCMConfig(min_value=0) # Should be positive config = LCMConfig(min_value=0) # Should be positive
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = LCMConfig(min_value=100, max_value=50) # max should be > min config = LCMConfig(min_value=100, max_value=50) # max should be > min
config.validate() config.validate()
@ -28,7 +30,7 @@ def test_lcm_deterministic():
config = LCMConfig(seed=42, size=10) config = LCMConfig(seed=42, size=10)
dataset1 = LCMDataset(config) dataset1 = LCMDataset(config)
dataset2 = LCMDataset(config) dataset2 = LCMDataset(config)
for i in range(len(dataset1)): for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i] assert dataset1[i] == dataset2[i]
@ -36,32 +38,27 @@ def test_lcm_deterministic():
def test_lcm_items(): def test_lcm_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = LCMConfig( config = LCMConfig(
min_numbers=2, min_numbers=2, max_numbers=4, min_value=1, max_value=20, size=50, seed=42 # Keep small for testing
max_numbers=4,
min_value=1,
max_value=20, # Keep small for testing
size=50,
seed=42
) )
dataset = LCMDataset(config) dataset = LCMDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
assert isinstance(item, dict) assert isinstance(item, dict)
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Verify the numbers and result are in metadata # Verify the numbers and result are in metadata
metadata = item["metadata"] metadata = item["metadata"]
assert "numbers" in metadata assert "numbers" in metadata
assert "result" in metadata assert "result" in metadata
# Verify the numbers are within configured range # Verify the numbers are within configured range
numbers = metadata["numbers"] numbers = metadata["numbers"]
assert all(config.min_value <= n <= config.max_value for n in numbers) assert all(config.min_value <= n <= config.max_value for n in numbers)
assert config.min_numbers <= len(numbers) <= config.max_numbers assert config.min_numbers <= len(numbers) <= config.max_numbers
# Verify the LCM calculation is correct # Verify the LCM calculation is correct
result = metadata["result"] result = metadata["result"]
assert str(result) == item["answer"] assert str(result) == item["answer"]
@ -70,16 +67,9 @@ def test_lcm_items():
def test_lcm_number_ranges(): def test_lcm_number_ranges():
"""Test that generated numbers respect value constraints""" """Test that generated numbers respect value constraints"""
config = LCMConfig( config = LCMConfig(min_numbers=2, max_numbers=2, min_value=5, max_value=15, size=20, seed=42)
min_numbers=2,
max_numbers=2,
min_value=5,
max_value=15,
size=20,
seed=42
)
dataset = LCMDataset(config) dataset = LCMDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
numbers = item["metadata"]["numbers"] numbers = item["metadata"]["numbers"]
@ -90,17 +80,17 @@ def test_lcm_iteration():
"""Test that iteration works correctly""" """Test that iteration works correctly"""
config = LCMConfig(size=5, seed=42) config = LCMConfig(size=5, seed=42)
dataset = LCMDataset(config) dataset = LCMDataset(config)
# Test manual iteration # Test manual iteration
items = [] items = []
for item in dataset: for item in dataset:
items.append(item) items.append(item)
assert len(items) == config.size assert len(items) == config.size
# Test list conversion # Test list conversion
items = list(dataset) items = list(dataset)
assert len(items) == config.size assert len(items) == config.size
# Test multiple iterations yield same results # Test multiple iterations yield same results
first_items = list(dataset) first_items = list(dataset)
second_items = list(dataset) second_items = list(dataset)
@ -109,31 +99,24 @@ def test_lcm_iteration():
def test_lcm_special_cases(): def test_lcm_special_cases():
"""Test some special LCM cases""" """Test some special LCM cases"""
config = LCMConfig( config = LCMConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=20, size=100, seed=42)
min_numbers=2,
max_numbers=2,
min_value=1,
max_value=20,
size=100,
seed=42
)
dataset = LCMDataset(config) dataset = LCMDataset(config)
# Track if we see some interesting LCM cases # Track if we see some interesting LCM cases
seen_equal_to_product = False # When numbers are coprime seen_equal_to_product = False # When numbers are coprime
seen_less_than_product = False # When numbers share factors seen_less_than_product = False # When numbers share factors
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
numbers = item["metadata"]["numbers"] numbers = item["metadata"]["numbers"]
result = int(item["answer"]) result = int(item["answer"])
product = reduce(lambda x, y: x * y, numbers) product = reduce(lambda x, y: x * y, numbers)
if result == product: if result == product:
seen_equal_to_product = True seen_equal_to_product = True
if result < product: if result < product:
seen_less_than_product = True seen_less_than_product = True
# With enough samples, we should see both cases # With enough samples, we should see both cases
assert seen_equal_to_product, "Expected to see some coprime numbers (LCM = product)" assert seen_equal_to_product, "Expected to see some coprime numbers (LCM = product)"
assert seen_less_than_product, "Expected to see some numbers with common factors (LCM < product)" assert seen_less_than_product, "Expected to see some numbers with common factors (LCM < product)"

View file

@ -1,11 +1,8 @@
"""Tests for leg counting task generation""" """Tests for leg counting task generation"""
import pytest import pytest
from reasoning_gym.arithmetic.leg_counting import ( from reasoning_gym.arithmetic.leg_counting import ANIMALS, LegCountingConfig, LegCountingDataset
LegCountingConfig,
LegCountingDataset,
ANIMALS,
)
def test_leg_counting_config_validation(): def test_leg_counting_config_validation():
@ -35,13 +32,7 @@ def test_leg_counting_dataset_deterministic():
def test_leg_counting_dataset_items(): def test_leg_counting_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = LegCountingConfig( config = LegCountingConfig(min_animals=2, max_animals=4, max_instances=2, size=10, seed=42)
min_animals=2,
max_animals=4,
max_instances=2,
size=10,
seed=42
)
dataset = LegCountingDataset(config) dataset = LegCountingDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -51,19 +42,19 @@ def test_leg_counting_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "animals" in item["metadata"] assert "animals" in item["metadata"]
assert "total_legs" in item["metadata"] assert "total_legs" in item["metadata"]
# Verify animal count constraints # Verify animal count constraints
animals = item["metadata"]["animals"] animals = item["metadata"]["animals"]
assert len(animals) >= config.min_animals assert len(animals) >= config.min_animals
assert len(animals) <= config.max_animals assert len(animals) <= config.max_animals
# Verify instance count constraints # Verify instance count constraints
assert all(1 <= count <= config.max_instances for count in animals.values()) assert all(1 <= count <= config.max_instances for count in animals.values())
# Verify leg counting is correct # Verify leg counting is correct
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items()) total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
assert str(total_legs) == item["answer"] assert str(total_legs) == item["answer"]
@ -86,7 +77,7 @@ def test_leg_counting_animal_validation():
"""Test that all animals have valid leg counts""" """Test that all animals have valid leg counts"""
# Verify all animals have non-negative leg counts # Verify all animals have non-negative leg counts
assert all(legs >= 0 for legs in ANIMALS.values()) assert all(legs >= 0 for legs in ANIMALS.values())
# Verify common animals have expected leg counts # Verify common animals have expected leg counts
assert ANIMALS["spider"] == 8 assert ANIMALS["spider"] == 8
assert ANIMALS["insect"] == 6 assert ANIMALS["insect"] == 6

View file

@ -1,10 +1,8 @@
"""Tests for letter counting task generation""" """Tests for letter counting task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.letter_counting import ( from reasoning_gym.algorithmic.letter_counting import LetterCountingConfig, LetterCountingDataset
LetterCountingConfig,
LetterCountingDataset,
)
def test_letter_counting_config_validation(): def test_letter_counting_config_validation():
@ -30,12 +28,7 @@ def test_letter_counting_dataset_deterministic():
def test_letter_counting_dataset_items(): def test_letter_counting_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = LetterCountingConfig( config = LetterCountingConfig(min_words=3, max_words=6, size=10, seed=42)
min_words=3,
max_words=6,
size=10,
seed=42
)
dataset = LetterCountingDataset(config) dataset = LetterCountingDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -45,17 +38,17 @@ def test_letter_counting_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "span_length" in item["metadata"] assert "span_length" in item["metadata"]
assert "target_letter" in item["metadata"] assert "target_letter" in item["metadata"]
assert "span" in item["metadata"] assert "span" in item["metadata"]
# Verify span length constraints # Verify span length constraints
span = item["metadata"]["span"] span = item["metadata"]["span"]
assert len(span) >= config.min_words assert len(span) >= config.min_words
assert len(span) <= config.max_words assert len(span) <= config.max_words
# Verify letter counting # Verify letter counting
target_letter = item["metadata"]["target_letter"] target_letter = item["metadata"]["target_letter"]
count = sum(word.lower().count(target_letter) for word in span) count = sum(word.lower().count(target_letter) for word in span)
@ -78,7 +71,7 @@ def test_letter_counting_text_preprocessing():
"""Test that text preprocessing handles edge cases""" """Test that text preprocessing handles edge cases"""
config = LetterCountingConfig(size=1, seed=42) config = LetterCountingConfig(size=1, seed=42)
dataset = LetterCountingDataset(config) dataset = LetterCountingDataset(config)
# Verify words were extracted from text # Verify words were extracted from text
assert len(dataset.words) > 0 assert len(dataset.words) > 0
# Verify words contain only word characters # Verify words contain only word characters

View file

@ -1,10 +1,8 @@
"""Tests for mini sudoku puzzle generation""" """Tests for mini sudoku puzzle generation"""
import pytest import pytest
from reasoning_gym.games.mini_sudoku import ( from reasoning_gym.games.mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
MiniSudokuConfig,
MiniSudokuDataset,
)
def test_mini_sudoku_config_validation(): def test_mini_sudoku_config_validation():
@ -34,12 +32,7 @@ def test_mini_sudoku_dataset_deterministic():
def test_mini_sudoku_dataset_items(): def test_mini_sudoku_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = MiniSudokuConfig( config = MiniSudokuConfig(min_empty=8, max_empty=12, size=10, seed=42)
min_empty=8,
max_empty=12,
size=10,
seed=42
)
dataset = MiniSudokuDataset(config) dataset = MiniSudokuDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -49,30 +42,30 @@ def test_mini_sudoku_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "puzzle" in item["metadata"] assert "puzzle" in item["metadata"]
assert "solution" in item["metadata"] assert "solution" in item["metadata"]
assert "num_empty" in item["metadata"] assert "num_empty" in item["metadata"]
puzzle = item["metadata"]["puzzle"] puzzle = item["metadata"]["puzzle"]
solution = item["metadata"]["solution"] solution = item["metadata"]["solution"]
num_empty = item["metadata"]["num_empty"] num_empty = item["metadata"]["num_empty"]
# Verify board dimensions # Verify board dimensions
assert len(puzzle) == 4 assert len(puzzle) == 4
assert all(len(row) == 4 for row in puzzle) assert all(len(row) == 4 for row in puzzle)
assert len(solution) == 4 assert len(solution) == 4
assert all(len(row) == 4 for row in solution) assert all(len(row) == 4 for row in solution)
# Verify empty cell count # Verify empty cell count
empty_count = sum(1 for row in puzzle for cell in row if cell == 0) empty_count = sum(1 for row in puzzle for cell in row if cell == 0)
assert config.min_empty <= empty_count <= config.max_empty assert config.min_empty <= empty_count <= config.max_empty
assert empty_count == num_empty assert empty_count == num_empty
# Verify solution validity # Verify solution validity
assert is_valid_solution(solution) assert is_valid_solution(solution)
# Verify puzzle matches solution where filled # Verify puzzle matches solution where filled
for i in range(4): for i in range(4):
for j in range(4): for j in range(4):
@ -94,14 +87,9 @@ def test_mini_sudoku_dataset_iteration():
def test_mini_sudoku_board_generation(): def test_mini_sudoku_board_generation():
"""Test that generated boards are valid""" """Test that generated boards are valid"""
config = MiniSudokuConfig( config = MiniSudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board
min_empty=0, # Force complete board
max_empty=0,
size=5,
seed=42
)
dataset = MiniSudokuDataset(config) dataset = MiniSudokuDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
board = item["metadata"]["solution"] board = item["metadata"]["solution"]
@ -114,21 +102,21 @@ def is_valid_solution(board: list[list[int]]) -> bool:
for row in board: for row in board:
if set(row) != set(range(1, 5)): if set(row) != set(range(1, 5)):
return False return False
# Check columns # Check columns
for j in range(4): for j in range(4):
column = [board[i][j] for i in range(4)] column = [board[i][j] for i in range(4)]
if set(column) != set(range(1, 5)): if set(column) != set(range(1, 5)):
return False return False
# Check 2x2 boxes # Check 2x2 boxes
for box_i in range(2): for box_i in range(2):
for box_j in range(2): for box_j in range(2):
box = [] box = []
for i in range(2): for i in range(2):
for j in range(2): for j in range(2):
box.append(board[box_i*2 + i][box_j*2 + j]) box.append(board[box_i * 2 + i][box_j * 2 + j])
if set(box) != set(range(1, 5)): if set(box) != set(range(1, 5)):
return False return False
return True return True

View file

@ -1,10 +1,8 @@
"""Tests for number filtering task generation""" """Tests for number filtering task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.number_filtering import ( from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset
NumberFilteringConfig,
NumberFilteringDataset,
)
def test_number_filtering_config_validation(): def test_number_filtering_config_validation():
@ -16,11 +14,11 @@ def test_number_filtering_config_validation():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = NumberFilteringConfig(min_numbers=10, max_numbers=5) config = NumberFilteringConfig(min_numbers=10, max_numbers=5)
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = NumberFilteringConfig(min_decimals=-1) config = NumberFilteringConfig(min_decimals=-1)
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = NumberFilteringConfig(min_value=100, max_value=0) config = NumberFilteringConfig(min_value=100, max_value=0)
config.validate() config.validate()
@ -39,14 +37,7 @@ def test_number_filtering_dataset_deterministic():
def test_number_filtering_dataset_items(): def test_number_filtering_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = NumberFilteringConfig( config = NumberFilteringConfig(
min_numbers=3, min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
max_numbers=6,
min_decimals=1,
max_decimals=3,
min_value=-10.0,
max_value=10.0,
size=10,
seed=42
) )
dataset = NumberFilteringDataset(config) dataset = NumberFilteringDataset(config)
@ -57,34 +48,34 @@ def test_number_filtering_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "original_numbers" in item["metadata"] assert "original_numbers" in item["metadata"]
assert "filter_value" in item["metadata"] assert "filter_value" in item["metadata"]
assert "operation" in item["metadata"] assert "operation" in item["metadata"]
assert "result" in item["metadata"] assert "result" in item["metadata"]
# Verify number count constraints # Verify number count constraints
numbers = item["metadata"]["original_numbers"] numbers = item["metadata"]["original_numbers"]
assert len(numbers) >= config.min_numbers assert len(numbers) >= config.min_numbers
assert len(numbers) <= config.max_numbers assert len(numbers) <= config.max_numbers
# Verify decimal places # Verify decimal places
for num in numbers: for num in numbers:
decimal_places = len(num.split('.')[-1]) if '.' in num else 0 decimal_places = len(num.split(".")[-1]) if "." in num else 0
assert decimal_places >= config.min_decimals assert decimal_places >= config.min_decimals
assert decimal_places <= config.max_decimals assert decimal_places <= config.max_decimals
# Verify value range # Verify value range
for num in numbers: for num in numbers:
value = float(num) value = float(num)
assert config.min_value <= value <= config.max_value assert config.min_value <= value <= config.max_value
# Verify filtering operation # Verify filtering operation
operation = item["metadata"]["operation"] operation = item["metadata"]["operation"]
filter_value = float(item["metadata"]["filter_value"]) filter_value = float(item["metadata"]["filter_value"])
result = [float(x) for x in eval(item["answer"])] if item["answer"] != "[]" else [] result = [float(x) for x in eval(item["answer"])] if item["answer"] != "[]" else []
if operation == "keep_larger": if operation == "keep_larger":
assert all(x > filter_value for x in result) assert all(x > filter_value for x in result)
elif operation == "keep_smaller": elif operation == "keep_smaller":
@ -117,11 +108,11 @@ def test_number_filtering_precision():
min_value=0.0, min_value=0.0,
max_value=1.0, max_value=1.0,
size=1, size=1,
seed=42 seed=42,
) )
dataset = NumberFilteringDataset(config) dataset = NumberFilteringDataset(config)
item = dataset[0] item = dataset[0]
# Check that string representations maintain precision # Check that string representations maintain precision
for num in item["metadata"]["original_numbers"]: for num in item["metadata"]["original_numbers"]:
assert len(num.split('.')[-1]) == 2 assert len(num.split(".")[-1]) == 2

View file

@ -1,10 +1,8 @@
"""Tests for number sorting task generation""" """Tests for number sorting task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.number_sorting import ( from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingDataset
NumberSortingConfig,
NumberSortingDataset,
)
def test_number_sorting_config_validation(): def test_number_sorting_config_validation():
@ -16,11 +14,11 @@ def test_number_sorting_config_validation():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = NumberSortingConfig(min_numbers=10, max_numbers=5) config = NumberSortingConfig(min_numbers=10, max_numbers=5)
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = NumberSortingConfig(min_decimals=-1) config = NumberSortingConfig(min_decimals=-1)
config.validate() config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = NumberSortingConfig(min_value=100, max_value=0) config = NumberSortingConfig(min_value=100, max_value=0)
config.validate() config.validate()
@ -39,14 +37,7 @@ def test_number_sorting_dataset_deterministic():
def test_number_sorting_dataset_items(): def test_number_sorting_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = NumberSortingConfig( config = NumberSortingConfig(
min_numbers=3, min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42
max_numbers=6,
min_decimals=1,
max_decimals=3,
min_value=-10.0,
max_value=10.0,
size=10,
seed=42
) )
dataset = NumberSortingDataset(config) dataset = NumberSortingDataset(config)
@ -57,28 +48,28 @@ def test_number_sorting_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "original_numbers" in item["metadata"] assert "original_numbers" in item["metadata"]
assert "direction" in item["metadata"] assert "direction" in item["metadata"]
assert "sorted_numbers" in item["metadata"] assert "sorted_numbers" in item["metadata"]
# Verify number count constraints # Verify number count constraints
numbers = item["metadata"]["original_numbers"] numbers = item["metadata"]["original_numbers"]
assert len(numbers) >= config.min_numbers assert len(numbers) >= config.min_numbers
assert len(numbers) <= config.max_numbers assert len(numbers) <= config.max_numbers
# Verify decimal places # Verify decimal places
for num in numbers: for num in numbers:
decimal_places = len(num.split('.')[-1]) if '.' in num else 0 decimal_places = len(num.split(".")[-1]) if "." in num else 0
assert decimal_places >= config.min_decimals assert decimal_places >= config.min_decimals
assert decimal_places <= config.max_decimals assert decimal_places <= config.max_decimals
# Verify value range # Verify value range
for num in numbers: for num in numbers:
value = float(num) value = float(num)
assert config.min_value <= value <= config.max_value assert config.min_value <= value <= config.max_value
# Verify sorting # Verify sorting
direction = item["metadata"]["direction"] direction = item["metadata"]["direction"]
sorted_numbers = [float(x) for x in eval(item["answer"])] sorted_numbers = [float(x) for x in eval(item["answer"])]

View file

@ -1,10 +1,8 @@
"""Tests for prime factorization task generation""" """Tests for prime factorization task generation"""
import pytest import pytest
from reasoning_gym.arithmetic.prime_factorization import ( from reasoning_gym.arithmetic.prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
PrimeFactorizationConfig,
PrimeFactorizationDataset,
)
def test_prime_factorization_config_validation(): def test_prime_factorization_config_validation():
@ -30,12 +28,7 @@ def test_prime_factorization_dataset_deterministic():
def test_prime_factorization_dataset_items(): def test_prime_factorization_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = PrimeFactorizationConfig( config = PrimeFactorizationConfig(min_value=2, max_value=100, size=10, seed=42)
min_value=2,
max_value=100,
size=10,
seed=42
)
dataset = PrimeFactorizationDataset(config) dataset = PrimeFactorizationDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -45,26 +38,26 @@ def test_prime_factorization_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "number" in item["metadata"] assert "number" in item["metadata"]
assert "factors" in item["metadata"] assert "factors" in item["metadata"]
# Verify value range # Verify value range
number = item["metadata"]["number"] number = item["metadata"]["number"]
assert config.min_value <= number <= config.max_value assert config.min_value <= number <= config.max_value
# Verify factorization is correct # Verify factorization is correct
factors = item["metadata"]["factors"] factors = item["metadata"]["factors"]
product = 1 product = 1
for factor in factors: for factor in factors:
product *= factor product *= factor
assert product == number assert product == number
# Verify factors are prime # Verify factors are prime
for factor in factors: for factor in factors:
assert is_prime(factor), f"{factor} is not prime" assert is_prime(factor), f"{factor} is not prime"
# Verify answer format # Verify answer format
assert item["answer"] == " × ".join(map(str, factors)) assert item["answer"] == " × ".join(map(str, factors))
@ -83,15 +76,10 @@ def test_prime_factorization_dataset_iteration():
def test_prime_factorization_known_values(): def test_prime_factorization_known_values():
"""Test factorization of known values""" """Test factorization of known values"""
config = PrimeFactorizationConfig( config = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42) # Force specific number
min_value=12,
max_value=12, # Force specific number
size=1,
seed=42
)
dataset = PrimeFactorizationDataset(config) dataset = PrimeFactorizationDataset(config)
item = dataset[0] item = dataset[0]
assert item["metadata"]["number"] == 12 assert item["metadata"]["number"] == 12
assert item["metadata"]["factors"] == [2, 2, 3] assert item["metadata"]["factors"] == [2, 2, 3]
assert item["answer"] == "2 × 2 × 3" assert item["answer"] == "2 × 2 × 3"
@ -101,7 +89,7 @@ def is_prime(n: int) -> bool:
"""Helper function to check if a number is prime""" """Helper function to check if a number is prime"""
if n < 2: if n < 2:
return False return False
for i in range(2, int(n ** 0.5) + 1): for i in range(2, int(n**0.5) + 1):
if n % i == 0: if n % i == 0:
return False return False
return True return True

View file

@ -23,14 +23,14 @@ def test_pattern_rule():
# Test simple addition # Test simple addition
rule = PatternRule([Operation.ADD], [2]) rule = PatternRule([Operation.ADD], [2])
assert rule.apply([1, 3], 1) == 5 assert rule.apply([1, 3], 1) == 5
# Test composition # Test composition
rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3]) rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3])
assert rule.apply([1, 4], 1) == 11 # (4 * 2) + 3 assert rule.apply([1, 4], 1) == 11 # (4 * 2) + 3
# Test rule composition # Test rule composition
rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number
rule2 = PatternRule([Operation.ADD], [3]) # Add 3 rule2 = PatternRule([Operation.ADD], [3]) # Add 3
composed = PatternRule.compose([rule1, rule2]) composed = PatternRule.compose([rule1, rule2])
assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3 assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3

View file

@ -1,10 +1,8 @@
"""Tests for sudoku puzzle generation""" """Tests for sudoku puzzle generation"""
import pytest import pytest
from reasoning_gym.games.sudoku import ( from reasoning_gym.games.sudoku import SudokuConfig, SudokuDataset
SudokuConfig,
SudokuDataset,
)
def test_sudoku_config_validation(): def test_sudoku_config_validation():
@ -34,12 +32,7 @@ def test_sudoku_dataset_deterministic():
def test_sudoku_dataset_items(): def test_sudoku_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = SudokuConfig( config = SudokuConfig(min_empty=30, max_empty=40, size=10, seed=42)
min_empty=30,
max_empty=40,
size=10,
seed=42
)
dataset = SudokuDataset(config) dataset = SudokuDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -49,30 +42,30 @@ def test_sudoku_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "puzzle" in item["metadata"] assert "puzzle" in item["metadata"]
assert "solution" in item["metadata"] assert "solution" in item["metadata"]
assert "num_empty" in item["metadata"] assert "num_empty" in item["metadata"]
puzzle = item["metadata"]["puzzle"] puzzle = item["metadata"]["puzzle"]
solution = item["metadata"]["solution"] solution = item["metadata"]["solution"]
num_empty = item["metadata"]["num_empty"] num_empty = item["metadata"]["num_empty"]
# Verify board dimensions # Verify board dimensions
assert len(puzzle) == 9 assert len(puzzle) == 9
assert all(len(row) == 9 for row in puzzle) assert all(len(row) == 9 for row in puzzle)
assert len(solution) == 9 assert len(solution) == 9
assert all(len(row) == 9 for row in solution) assert all(len(row) == 9 for row in solution)
# Verify empty cell count # Verify empty cell count
empty_count = sum(1 for row in puzzle for cell in row if cell == 0) empty_count = sum(1 for row in puzzle for cell in row if cell == 0)
assert config.min_empty <= empty_count <= config.max_empty assert config.min_empty <= empty_count <= config.max_empty
assert empty_count == num_empty assert empty_count == num_empty
# Verify solution validity # Verify solution validity
assert is_valid_solution(solution) assert is_valid_solution(solution)
# Verify puzzle matches solution where filled # Verify puzzle matches solution where filled
for i in range(9): for i in range(9):
for j in range(9): for j in range(9):
@ -94,14 +87,9 @@ def test_sudoku_dataset_iteration():
def test_sudoku_board_generation(): def test_sudoku_board_generation():
"""Test that generated boards are valid""" """Test that generated boards are valid"""
config = SudokuConfig( config = SudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board
min_empty=0, # Force complete board
max_empty=0,
size=5,
seed=42
)
dataset = SudokuDataset(config) dataset = SudokuDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
board = item["metadata"]["solution"] board = item["metadata"]["solution"]
@ -114,21 +102,21 @@ def is_valid_solution(board: list[list[int]]) -> bool:
for row in board: for row in board:
if set(row) != set(range(1, 10)): if set(row) != set(range(1, 10)):
return False return False
# Check columns # Check columns
for j in range(9): for j in range(9):
column = [board[i][j] for i in range(9)] column = [board[i][j] for i in range(9)]
if set(column) != set(range(1, 10)): if set(column) != set(range(1, 10)):
return False return False
# Check 3x3 boxes # Check 3x3 boxes
for box_i in range(3): for box_i in range(3):
for box_j in range(3): for box_j in range(3):
box = [] box = []
for i in range(3): for i in range(3):
for j in range(3): for j in range(3):
box.append(board[box_i*3 + i][box_j*3 + j]) box.append(board[box_i * 3 + i][box_j * 3 + j])
if set(box) != set(range(1, 10)): if set(box) != set(range(1, 10)):
return False return False
return True return True

View file

@ -1,10 +1,8 @@
"""Tests for word reversal task generation""" """Tests for word reversal task generation"""
import pytest import pytest
from reasoning_gym.algorithmic.word_reversal import ( from reasoning_gym.algorithmic.word_reversal import WordReversalConfig, WordReversalDataset
WordReversalConfig,
WordReversalDataset,
)
def test_word_reversal_config_validation(): def test_word_reversal_config_validation():
@ -30,12 +28,7 @@ def test_word_reversal_dataset_deterministic():
def test_word_reversal_dataset_items(): def test_word_reversal_dataset_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = WordReversalConfig( config = WordReversalConfig(min_words=3, max_words=6, size=10, seed=42)
min_words=3,
max_words=6,
size=10,
seed=42
)
dataset = WordReversalDataset(config) dataset = WordReversalDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
@ -45,16 +38,16 @@ def test_word_reversal_dataset_items():
assert "question" in item assert "question" in item
assert "answer" in item assert "answer" in item
assert "metadata" in item assert "metadata" in item
# Check metadata # Check metadata
assert "num_words" in item["metadata"] assert "num_words" in item["metadata"]
assert "words" in item["metadata"] assert "words" in item["metadata"]
# Verify word count constraints # Verify word count constraints
words = item["metadata"]["words"] words = item["metadata"]["words"]
assert len(words) >= config.min_words assert len(words) >= config.min_words
assert len(words) <= config.max_words assert len(words) <= config.max_words
# Verify reversal is correct # Verify reversal is correct
question_words = [w.strip() for w in item["question"].split(":")[1].strip().split(",")] question_words = [w.strip() for w in item["question"].split(":")[1].strip().split(",")]
answer_words = item["answer"].split(", ") answer_words = item["answer"].split(", ")
@ -77,7 +70,7 @@ def test_word_reversal_text_preprocessing():
"""Test that text preprocessing handles edge cases""" """Test that text preprocessing handles edge cases"""
config = WordReversalConfig(size=1, seed=42) config = WordReversalConfig(size=1, seed=42)
dataset = WordReversalDataset(config) dataset = WordReversalDataset(config)
# Verify words were extracted from text # Verify words were extracted from text
assert len(dataset.words) > 0 assert len(dataset.words) > 0
# Verify words contain only alphanumeric characters # Verify words contain only alphanumeric characters