diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index ddf99615..e894ba12 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -2,12 +2,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasoning models """ -from . import arithmetic -from . import algorithmic -from . import cognition -from . import data -from . import games -from . import logic +from . import algorithmic, arithmetic, cognition, data, games, logic __version__ = "0.1.0" __all__ = ["arithmetic", "algorithmic", "cognition", "data", "games", "logic"] diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 5264a9b0..bf33e5f3 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -8,6 +8,7 @@ Algorithmic tasks for training reasoning capabilities: from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset + from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset @@ -20,8 +21,8 @@ __all__ = [ "BaseConversionDataset", "base_conversion_dataset", "chain_sum_dataset", - "LetterCountingConfig", - "LetterCountingDataset", + "LetterCountingConfig", + "LetterCountingDataset", "letter_counting_dataset", "NumberFilteringConfig", "NumberFilteringDataset", @@ -31,5 +32,5 @@ __all__ = [ "number_sorting_dataset", "WordReversalConfig", "WordReversalDataset", - "word_reversal_dataset" + "word_reversal_dataset", ] diff --git a/reasoning_gym/algorithmic/base_conversion.py b/reasoning_gym/algorithmic/base_conversion.py index 73f267ae..55c352fe 100644 --- a/reasoning_gym/algorithmic/base_conversion.py +++ b/reasoning_gym/algorithmic/base_conversion.py @@ -1,17 +1,20 @@ """Base conversion task generator""" + from dataclasses import dataclass from random import Random from typing import Optional, Tuple + @dataclass class BaseConversionConfig: """Configuration for base conversion task generation""" - min_base: int = 2 # Minimum base (2=binary) - max_base: int = 16 # Maximum base (16=hex) - min_value: int = 0 # Minimum decimal value to convert - max_value: int = 1000 # Maximum decimal value to convert + + min_base: int = 2 # Minimum base (2=binary) + max_base: int = 16 # Maximum base (16=hex) + min_value: int = 0 # Minimum decimal value to convert + max_value: int = 1000 # Maximum decimal value to convert seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -55,37 +58,37 @@ class BaseConversionDataset: def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]: """Generate random value and source/target bases""" value = rng.randint(self.config.min_value, self.config.max_value) - + # Choose source and target bases source_base = rng.randint(self.config.min_base, self.config.max_base) target_base = rng.randint(self.config.min_base, self.config.max_base) while target_base == source_base: # Ensure different bases target_base = rng.randint(self.config.min_base, self.config.max_base) - + return value, source_base, target_base def __getitem__(self, idx: int) -> dict: """Generate a single base conversion task""" rng = Random(self.seed + idx) - + value, source_base, target_base = self._generate_conversion(rng) - + # Convert decimal to source base representation - source_repr = format(value, f'x' if source_base == 16 else f'b' if source_base == 2 else '').strip() + source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip() if source_base not in (2, 16): - source_repr = format(value, f'{source_base}x').lower().strip() - + source_repr = format(value, f"{source_base}x").lower().strip() + # Convert decimal to target base for answer - target_repr = format(value, f'x' if target_base == 16 else f'b' if target_base == 2 else '').strip() + target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip() if target_base not in (2, 16): - target_repr = format(value, f'{target_base}x').lower().strip() - + target_repr = format(value, f"{target_base}x").lower().strip() + source_name = self._format_base_name(source_base) target_name = self._format_base_name(target_base) - + # Add hint for bases > 10 about using lowercase letters hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else "" - + return { "question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}", "answer": target_repr, @@ -94,8 +97,8 @@ class BaseConversionDataset: "source_base": source_base, "target_base": target_base, "source_repr": source_repr, - "target_repr": target_repr - } + "target_repr": target_repr, + }, } diff --git a/reasoning_gym/algorithmic/letter_counting.py b/reasoning_gym/algorithmic/letter_counting.py index 5f2c372c..5c620225 100644 --- a/reasoning_gym/algorithmic/letter_counting.py +++ b/reasoning_gym/algorithmic/letter_counting.py @@ -1,18 +1,21 @@ """Letter counting task generator""" -from dataclasses import dataclass + import re +from dataclasses import dataclass from random import Random from typing import List, Optional from reasoning_gym.data import read_data_file + @dataclass class LetterCountingConfig: """Configuration for letter counting task generation""" - min_words: int = 5 # Minimum words in span - max_words: int = 15 # Maximum words in span + + min_words: int = 5 # Minimum words in span + max_words: int = 15 # Maximum words in span seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -27,11 +30,11 @@ class LetterCountingDataset: self.config = config self.config.validate() self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) - + # Load and preprocess text text = read_data_file("in_the_year_2889.txt") # Extract words and clean them to contain only alphanumeric characters - self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()] + self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()] def __len__(self) -> int: return self.config.size @@ -50,31 +53,27 @@ class LetterCountingDataset: def __getitem__(self, idx: int) -> dict: """Generate a single letter counting task""" rng = Random(self.seed + idx) - + # Select random span of words span_length = rng.randint(self.config.min_words, self.config.max_words) start_idx = rng.randint(0, len(self.words) - span_length) - span = self.words[start_idx:start_idx + span_length] - + span = self.words[start_idx : start_idx + span_length] + # Get all unique letters from span - letters = set(''.join(span).lower()) + letters = set("".join(span).lower()) if not letters: - letters = {'a'} # Fallback if span has no letters - + letters = {"a"} # Fallback if span has no letters + # Select random letter that appears in the span target_letter = rng.choice(list(letters)) - + # Count occurrences count = sum(word.lower().count(target_letter) for word in span) - + return { "question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?', "answer": str(count), - "metadata": { - "span_length": span_length, - "target_letter": target_letter, - "span": span - } + "metadata": {"span_length": span_length, "target_letter": target_letter, "span": span}, } diff --git a/reasoning_gym/algorithmic/number_filtering.py b/reasoning_gym/algorithmic/number_filtering.py index 22d57f6c..e4329f82 100644 --- a/reasoning_gym/algorithmic/number_filtering.py +++ b/reasoning_gym/algorithmic/number_filtering.py @@ -1,20 +1,23 @@ """Number filtering task generator""" -from dataclasses import dataclass + import random +from dataclasses import dataclass from random import Random from typing import List, Optional, Tuple + @dataclass class NumberFilteringConfig: """Configuration for number filtering task generation""" - min_numbers: int = 3 # Minimum numbers in list - max_numbers: int = 10 # Maximum numbers in list - min_decimals: int = 0 # Minimum decimal places - max_decimals: int = 4 # Maximum decimal places - min_value: float = -100.0 # Minimum number value - max_value: float = 100.0 # Maximum number value + + min_numbers: int = 3 # Minimum numbers in list + max_numbers: int = 10 # Maximum numbers in list + min_decimals: int = 0 # Minimum decimal places + max_decimals: int = 4 # Maximum decimal places + min_value: float = -100.0 # Minimum number value + max_value: float = 100.0 # Maximum number value seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -56,23 +59,23 @@ class NumberFilteringDataset: count = rng.randint(self.config.min_numbers, self.config.max_numbers) numbers = [] str_numbers = [] - + for _ in range(count): num = rng.uniform(self.config.min_value, self.config.max_value) decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) str_num = self._format_number(num, decimals) numbers.append(float(str_num)) # Convert back to simulate precision loss str_numbers.append(str_num) - + return numbers, str_numbers def __getitem__(self, idx: int) -> dict: """Generate a single number filtering task""" rng = Random(self.seed + idx) - + # Generate numbers and their string representations numbers, str_numbers = self._generate_numbers(rng) - + # Determine filter value between min and max of generated numbers min_val = min(numbers) max_val = max(numbers) @@ -80,31 +83,33 @@ class NumberFilteringDataset: decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) filter_str = self._format_number(filter_value, decimals) filter_value = float(filter_str) # Convert back to simulate precision loss - + # Randomly choose filter operation keep_larger = rng.choice([True, False]) larger_smaller = "larger" if keep_larger else "smaller" keep_remove = "keep" if rng.choice([True, False]) else "remove" - + # Apply filter based on chosen operation if keep_remove == "keep": result = [n for n in numbers if (n > filter_value if keep_larger else n < filter_value)] else: # remove result = [n for n in numbers if (n <= filter_value if keep_larger else n >= filter_value)] - + # Format results as strings with original precision result_strs = [str_numbers[numbers.index(n)] for n in result] - + return { - "question": (f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} " - f"in this list: {str_numbers}"), + "question": ( + f"{keep_remove.capitalize()} all numbers {larger_smaller} than {filter_str} " + f"in this list: {str_numbers}" + ), "answer": str(result_strs) if result_strs else "[]", "metadata": { "original_numbers": str_numbers, "filter_value": filter_str, "operation": f"{keep_remove}_{larger_smaller}", - "result": result_strs - } + "result": result_strs, + }, } diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index 8fa8fb94..c2d9718d 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -1,20 +1,23 @@ """Number sorting task generator""" -from dataclasses import dataclass + import random +from dataclasses import dataclass from random import Random from typing import List, Optional, Tuple + @dataclass class NumberSortingConfig: """Configuration for number sorting task generation""" - min_numbers: int = 3 # Minimum numbers to sort - max_numbers: int = 10 # Maximum numbers to sort - min_decimals: int = 0 # Minimum decimal places - max_decimals: int = 2 # Maximum decimal places + + min_numbers: int = 3 # Minimum numbers to sort + max_numbers: int = 10 # Maximum numbers to sort + min_decimals: int = 0 # Minimum decimal places + max_decimals: int = 2 # Maximum decimal places min_value: float = -100.0 # Minimum value - max_value: float = 100.0 # Maximum value + max_value: float = 100.0 # Maximum value seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -57,10 +60,10 @@ class NumberSortingDataset: """Generate list of numbers and their string representations""" count = rng.randint(self.config.min_numbers, self.config.max_numbers) decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) - + numbers = [] number_strs = [] - + for _ in range(count): num = rng.uniform(self.config.min_value, self.config.max_value) num_str = self._format_number(num, decimals) @@ -68,37 +71,33 @@ class NumberSortingDataset: num = float(num_str) numbers.append(num) number_strs.append(num_str) - + return numbers, number_strs def __getitem__(self, idx: int) -> dict: """Generate a single sorting task""" rng = Random(self.seed + idx) - + numbers, number_strs = self._generate_numbers(rng) - + # Generate both ascending and descending answers asc_numbers = sorted(numbers) desc_numbers = sorted(numbers, reverse=True) - + # Format answers as string lists - decimals = len(number_strs[0].split('.')[-1]) if '.' in number_strs[0] else 0 + decimals = len(number_strs[0].split(".")[-1]) if "." in number_strs[0] else 0 asc_answer = [self._format_number(n, decimals) for n in asc_numbers] desc_answer = [self._format_number(n, decimals) for n in desc_numbers] - + # Randomly choose ascending or descending is_ascending = rng.choice([True, False]) direction = "ascending" if is_ascending else "descending" answer = asc_answer if is_ascending else desc_answer - + return { "question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}", "answer": str(answer), - "metadata": { - "original_numbers": number_strs, - "direction": direction, - "sorted_numbers": answer - } + "metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer}, } diff --git a/reasoning_gym/algorithmic/word_reversal.py b/reasoning_gym/algorithmic/word_reversal.py index 16144e17..7fa10332 100644 --- a/reasoning_gym/algorithmic/word_reversal.py +++ b/reasoning_gym/algorithmic/word_reversal.py @@ -1,18 +1,21 @@ """Word reversal task generator""" -from dataclasses import dataclass + import re +from dataclasses import dataclass from random import Random from typing import List, Optional from reasoning_gym.data import read_data_file + @dataclass class WordReversalConfig: """Configuration for word reversal task generation""" - min_words: int = 3 # Minimum words in list - max_words: int = 8 # Maximum words in list + + min_words: int = 3 # Minimum words in list + max_words: int = 8 # Maximum words in list seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -27,11 +30,11 @@ class WordReversalDataset: self.config = config self.config.validate() self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) - + # Load and preprocess text text = read_data_file("in_the_year_2889.txt") # Extract words and clean them to contain only alphanumeric characters - self.words = [word for word in re.findall(r'\b\w+\b', text) if word.isalnum()] + self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()] def __len__(self) -> int: return self.config.size @@ -50,23 +53,20 @@ class WordReversalDataset: def __getitem__(self, idx: int) -> dict: """Generate a single word reversal task""" rng = Random(self.seed + idx) - + # Select random words num_words = rng.randint(self.config.min_words, self.config.max_words) word_indices = rng.sample(range(len(self.words)), num_words) words = [self.words[i] for i in word_indices] - + # Create question and answer question = ", ".join(words) answer = ", ".join(reversed(words)) - + return { "question": f"Reverse this list of words: {question}", "answer": answer, - "metadata": { - "num_words": num_words, - "words": words - } + "metadata": {"num_words": num_words, "words": words}, } diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index cf93144a..2ac85f97 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -8,7 +8,11 @@ Arithmetic tasks for training reasoning capabilities: from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset -from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset +from .fraction_simplification import ( + FractionSimplificationConfig, + FractionSimplificationDataset, + fraction_simplification_dataset, +) from .gcd import GCDConfig, GCDDataset, gcd_dataset from .lcm import LCMConfig, LCMDataset, lcm_dataset from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset @@ -25,7 +29,7 @@ __all__ = [ "FractionSimplificationDataset", "fraction_simplification_dataset", "GCDConfig", - "GCDDataset", + "GCDDataset", "gcd_dataset", "LCMConfig", "LCMDataset", @@ -35,5 +39,5 @@ __all__ = [ "leg_counting_dataset", "PrimeFactorizationConfig", "PrimeFactorizationDataset", - "prime_factorization_dataset" + "prime_factorization_dataset", ] diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index ae9750f6..a1c21e70 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Literal, Optional + from ..dataset import ProceduralDataset @@ -145,7 +146,6 @@ class BasicArithmeticDataset(ProceduralDataset): expression = " ".join(expression_parts) return expression, result - def _format_question(self, rng: Random, expression: str) -> str: """Format the expression according to config style""" if self.config.format_style == "simple": diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index dc122091..6bfad5f9 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -1,6 +1,7 @@ import random from dataclasses import dataclass from typing import Optional + from ..dataset import ProceduralDataset @@ -70,7 +71,6 @@ class ChainSum(ProceduralDataset): }, } - def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]: """Generate a chain sum task diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index 2aa7bd19..21007fea 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -1,21 +1,24 @@ """Fraction simplification task generator""" + from dataclasses import dataclass -from random import Random -from typing import Optional, Tuple, Sequence -from ..dataset import ProceduralDataset from math import gcd +from random import Random +from typing import Optional, Sequence, Tuple + +from ..dataset import ProceduralDataset @dataclass class FractionSimplificationConfig: """Configuration for fraction simplification task generation""" - min_value: int = 1 # Minimum value for numerator/denominator - max_value: int = 1000 # Maximum value for numerator/denominator - min_factor: int = 1 # Minimum multiplication factor - max_factor: int = 100 # Maximum multiplication factor + + min_value: int = 1 # Minimum value for numerator/denominator + max_value: int = 1000 # Maximum value for numerator/denominator + min_factor: int = 1 # Minimum multiplication factor + max_factor: int = 100 # Maximum multiplication factor styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac") # Allowed fraction formatting styles seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -23,7 +26,7 @@ class FractionSimplificationConfig: assert self.max_value > self.min_value, "max_value must be > min_value" assert self.min_factor >= 1, "min_factor must be at least 1" assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor" - + # Validate styles valid_styles = {"plain", "latex_inline", "latex_frac", "latex_dfrac"} for style in self.styles: @@ -46,37 +49,38 @@ class FractionSimplificationDataset(ProceduralDataset): # Generate the simplified fraction first simplified_num = rng.randint(self.config.min_value, self.config.max_value) simplified_den = rng.randint(self.config.min_value, self.config.max_value) - + # Make sure they're coprime by dividing by their GCD common = gcd(simplified_num, simplified_den) simplified_num //= common simplified_den //= common - + # Check if simplified fraction is within bounds - if (self.config.min_value <= simplified_num <= self.config.max_value and - self.config.min_value <= simplified_den <= self.config.max_value): + if ( + self.config.min_value <= simplified_num <= self.config.max_value + and self.config.min_value <= simplified_den <= self.config.max_value + ): # Ensure numerator is smaller than denominator if simplified_num > simplified_den: simplified_num, simplified_den = simplified_den, simplified_num - + # Multiply both by a random factor to create the unsimplified version factor = rng.randint(self.config.min_factor, self.config.max_factor) numerator = simplified_num * factor denominator = simplified_den * factor return numerator, denominator, simplified_num, simplified_den - + # If we failed to find a good fraction after max attempts, # generate one that's guaranteed to be within bounds simplified_num = rng.randint(self.config.min_value, self.config.max_value) simplified_den = rng.randint(self.config.min_value, self.config.max_value) - + # Ensure numerator is smaller than denominator if simplified_num > simplified_den: simplified_num, simplified_den = simplified_den, simplified_num - + factor = rng.randint(self.config.min_factor, self.config.max_factor) - return (simplified_num * factor, simplified_den * factor, - simplified_num, simplified_den) + return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den) def _format_fraction(self, num: int, den: int, style: str = "plain") -> str: """Format a fraction in various styles""" @@ -95,16 +99,16 @@ class FractionSimplificationDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single fraction simplification task""" rng = Random(self.seed + idx) - + num, den, simple_num, simple_den = self._generate_fraction(rng) - + # Choose a random style from configured styles - style = self.config.styles[rng.randint(0, len(self.config.styles)-1)] - + style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)] + # Format both question and answer in the same style question_fraction = self._format_fraction(num, den, style) answer_fraction = self._format_fraction(simple_num, simple_den, style) - + return { "question": f"Simplify the fraction {question_fraction} to its lowest terms", "answer": answer_fraction, @@ -114,8 +118,8 @@ class FractionSimplificationDataset(ProceduralDataset): "simplified_numerator": simple_num, "simplified_denominator": simple_den, "reduction_factor": num // simple_num, # Will be same as den // simple_den - "style": style - } + "style": style, + }, } diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index ba0a69cc..26c12041 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -1,21 +1,24 @@ """Greatest Common Divisor (GCD) task generator""" + from dataclasses import dataclass +from functools import reduce +from math import gcd from random import Random from typing import List, Optional, Tuple + from ..dataset import ProceduralDataset -from math import gcd -from functools import reduce @dataclass class GCDConfig: """Configuration for GCD task generation""" - min_numbers: int = 2 # Minimum numbers to find GCD of - max_numbers: int = 2 # Maximum numbers to find GCD of - min_value: int = 1 # Minimum value for each number - max_value: int = 1000 # Maximum value for each number + + min_numbers: int = 2 # Minimum numbers to find GCD of + max_numbers: int = 2 # Maximum numbers to find GCD of + min_value: int = 1 # Minimum value for each number + max_value: int = 1000 # Maximum value for each number seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -38,33 +41,28 @@ class GCDDataset(ProceduralDataset): Will try up to 3 times to find numbers with GCD > 1.""" for _ in range(3): # Try up to 3 times to get GCD > 1 num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) - numbers = [rng.randint(self.config.min_value, self.config.max_value) - for _ in range(num_count)] + numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)] result = reduce(gcd, numbers) if result > 1: return numbers, result - + # If we failed to find GCD > 1 after 3 tries, generate one final set num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) - numbers = [rng.randint(self.config.min_value, self.config.max_value) - for _ in range(num_count)] + numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)] result = reduce(gcd, numbers) return numbers, result def __getitem__(self, idx: int) -> dict: """Generate a single GCD task""" rng = Random(self.seed + idx) - + numbers, result = self._generate_numbers(rng) numbers_str = ", ".join(str(n) for n in numbers) - + return { "question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}", "answer": str(result), - "metadata": { - "numbers": numbers, - "result": result - } + "metadata": {"numbers": numbers, "result": result}, } diff --git a/reasoning_gym/arithmetic/lcm.py b/reasoning_gym/arithmetic/lcm.py index 05804eda..85b52006 100644 --- a/reasoning_gym/arithmetic/lcm.py +++ b/reasoning_gym/arithmetic/lcm.py @@ -1,21 +1,24 @@ """Least Common Multiple (LCM) task generator""" + from dataclasses import dataclass +from functools import reduce +from math import lcm from random import Random from typing import List, Optional, Tuple + from ..dataset import ProceduralDataset -from math import lcm -from functools import reduce @dataclass class LCMConfig: """Configuration for LCM task generation""" - min_numbers: int = 2 # Minimum numbers to find LCM of - max_numbers: int = 2 # Maximum numbers to find LCM of - min_value: int = 1 # Minimum value for each number - max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast) + + min_numbers: int = 2 # Minimum numbers to find LCM of + max_numbers: int = 2 # Maximum numbers to find LCM of + min_value: int = 1 # Minimum value for each number + max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast) seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -36,38 +39,34 @@ class LCMDataset(ProceduralDataset): def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: """Generate a list of random positive integers and their LCM. Will try up to 3 times to find numbers with LCM < product.""" + def calculate_product(nums: List[int]) -> int: return reduce(lambda x, y: x * y, nums) - + for _ in range(3): # Try up to 3 times to get LCM < product num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) - numbers = [rng.randint(self.config.min_value, self.config.max_value) - for _ in range(num_count)] + numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)] result = reduce(lcm, numbers) if result < calculate_product(numbers): return numbers, result - + # If we failed to find LCM < product after 3 tries, generate one final set num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) - numbers = [rng.randint(self.config.min_value, self.config.max_value) - for _ in range(num_count)] + numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_count)] result = reduce(lcm, numbers) return numbers, result def __getitem__(self, idx: int) -> dict: """Generate a single LCM task""" rng = Random(self.seed + idx) - + numbers, result = self._generate_numbers(rng) numbers_str = ", ".join(str(n) for n in numbers) - + return { "question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}", "answer": str(result), - "metadata": { - "numbers": numbers, - "result": result - } + "metadata": {"numbers": numbers, "result": result}, } diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index 7bec5c5e..a1308b0c 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -1,7 +1,9 @@ """Leg counting task generator""" + from dataclasses import dataclass from random import Random from typing import Dict, Optional + from ..dataset import ProceduralDataset ANIMALS = { @@ -52,14 +54,16 @@ ANIMALS = { "woodlouse": 14, } + @dataclass class LegCountingConfig: """Configuration for leg counting task generation""" - min_animals: int = 2 # Minimum number of animals in problem - max_animals: int = 5 # Maximum number of animals - max_instances: int = 3 # Maximum instances of each animal + + min_animals: int = 2 # Minimum number of animals in problem + max_animals: int = 5 # Maximum number of animals + max_instances: int = 3 # Maximum instances of each animal seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -80,39 +84,36 @@ class LegCountingDataset(ProceduralDataset): """Generate a random set of animals and their counts""" num_types = rng.randint(self.config.min_animals, self.config.max_animals) animals = {} - + # Select random animals selected_animals = rng.sample(list(ANIMALS.keys()), num_types) for animal in selected_animals: count = rng.randint(1, self.config.max_instances) animals[animal] = count - + return animals def __getitem__(self, idx: int) -> dict: """Generate a single leg counting task""" rng = Random(self.seed + idx) - + # Generate random animals and their counts animals = self._generate_animals(rng) - + # Calculate total legs total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items()) - + # Format animal counts for question animal_list = [] for animal, count in animals.items(): animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}") - + question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?" - + return { "question": question, "answer": str(total_legs), - "metadata": { - "animals": animals, - "total_legs": total_legs - } + "metadata": {"animals": animals, "total_legs": total_legs}, } diff --git a/reasoning_gym/arithmetic/prime_factorization.py b/reasoning_gym/arithmetic/prime_factorization.py index 2d4418cd..ab228d7f 100644 --- a/reasoning_gym/arithmetic/prime_factorization.py +++ b/reasoning_gym/arithmetic/prime_factorization.py @@ -1,16 +1,20 @@ """Prime factorization task generator""" + from dataclasses import dataclass from random import Random from typing import List, Optional, Tuple + from ..dataset import ProceduralDataset + @dataclass class PrimeFactorizationConfig: """Configuration for prime factorization task generation""" - min_value: int = 2 # Minimum number to factorize - max_value: int = 1000 # Maximum number to factorize + + min_value: int = 2 # Minimum number to factorize + max_value: int = 1000 # Maximum number to factorize seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -44,24 +48,23 @@ class PrimeFactorizationDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single prime factorization task""" rng = Random(self.seed + idx) - + # Generate random number to factorize number = rng.randint(self.config.min_value, self.config.max_value) - + # Calculate prime factors factors = self._prime_factors(number) - + # Format answer as multiplication of prime factors answer = " × ".join(map(str, factors)) - + return { - "question": (f"Find the prime factorization of {number}. Write the factors separated by × " - f"(Example: for 12 the answer would be: 2 × 2 × 3)"), + "question": ( + f"Find the prime factorization of {number}. Write the factors separated by × " + f"(Example: for 12 the answer would be: 2 × 2 × 3)" + ), "answer": answer, - "metadata": { - "number": number, - "factors": factors - } + "metadata": {"number": number, "factors": factors}, } diff --git a/reasoning_gym/cognition/sequences.py b/reasoning_gym/cognition/sequences.py index 55288e36..56fbdbee 100644 --- a/reasoning_gym/cognition/sequences.py +++ b/reasoning_gym/cognition/sequences.py @@ -40,7 +40,7 @@ class SequenceConfig: class PatternRule: """Represents a composable sequence pattern rule""" - def __init__(self, operations: List[Operation], parameters: List[int], subrules: List['PatternRule'] = None): + def __init__(self, operations: List[Operation], parameters: List[int], subrules: List["PatternRule"] = None): self.operations = operations self.parameters = parameters self.subrules = subrules or [] @@ -66,14 +66,14 @@ class PatternRule: elif op == Operation.COMPOSE: # Apply each subrule in sequence, passing the result through for subrule in self.subrules: - temp_sequence = sequence[:position + 1] + temp_sequence = sequence[: position + 1] temp_sequence[-1] = result # Use current result as input result = subrule.apply(temp_sequence, position) return result @classmethod - def compose(cls, rules: List['PatternRule']) -> 'PatternRule': + def compose(cls, rules: List["PatternRule"]) -> "PatternRule": """Create a new rule that composes multiple rules together""" return cls([Operation.COMPOSE], [0], subrules=rules) diff --git a/reasoning_gym/data/__init__.py b/reasoning_gym/data/__init__.py index 7ee1121f..d0c4f943 100644 --- a/reasoning_gym/data/__init__.py +++ b/reasoning_gym/data/__init__.py @@ -4,34 +4,37 @@ from importlib import resources from pathlib import Path from typing import Union + def get_data_file_path(filename: str) -> Path: """Get the path to a data file in the package. - + Args: filename: Name of the file in the data directory - + Returns: Path object pointing to the data file - + Example: >>> path = get_data_file_path("pg19362.txt") >>> with open(path) as f: ... content = f.read() """ - return resources.files('reasoning_gym.data').joinpath(filename) + return resources.files("reasoning_gym.data").joinpath(filename) + def read_data_file(filename: str) -> str: """Read the contents of a data file in the package. - + Args: filename: Name of the file in the data directory - + Returns: String contents of the file - + Example: >>> content = read_data_file("pg19362.txt") """ - return resources.files('reasoning_gym.data').joinpath(filename).read_text() + return resources.files("reasoning_gym.data").joinpath(filename).read_text() -__all__ = ['get_data_file_path', 'read_data_file'] + +__all__ = ["get_data_file_path", "read_data_file"] diff --git a/reasoning_gym/data/in_the_year_2889.txt b/reasoning_gym/data/in_the_year_2889.txt index 1e295f80..7d004201 100644 --- a/reasoning_gym/data/in_the_year_2889.txt +++ b/reasoning_gym/data/in_the_year_2889.txt @@ -1,5 +1,5 @@ The Project Gutenberg eBook of In the year 2889 - + This ebook is for the use of anyone anywhere in the United States and most other parts of the world at no cost and with almost no restrictions whatsoever. You may copy it, give it away or re-use it under the terms @@ -702,7 +702,7 @@ End of Project Gutenberg's In the Year 2889, by Jules Verne and Michel Verne *** END OF THE PROJECT GUTENBERG EBOOK IN THE YEAR 2889 *** - + Updated editions will replace the previous one—the old editions will be renamed. @@ -807,7 +807,7 @@ performed, viewed, copied or distributed: at www.gutenberg.org. If you are not located in the United States, you will have to check the laws of the country where you are located before using this eBook. - + 1.E.2. If an individual Project Gutenberg™ electronic work is derived from texts not protected by U.S. copyright law (does not contain a notice indicating that it is posted with permission of the @@ -869,7 +869,7 @@ provided that: Gutenberg Literary Archive Foundation at the address specified in Section 4, “Information about donations to the Project Gutenberg Literary Archive Foundation.” - + • You provide a full refund of any money paid by a user who notifies you in writing (or by e-mail) within 30 days of receipt that s/he does not agree to the terms of the full Project Gutenberg™ @@ -877,15 +877,15 @@ provided that: copies of the works possessed in a physical medium and discontinue all use of and all access to other copies of Project Gutenberg™ works. - + • You provide, in accordance with paragraph 1.F.3, a full refund of any money paid for a work or a replacement copy, if a defect in the electronic work is discovered and reported to you within 90 days of receipt of the work. - + • You comply with all other terms of this agreement for free distribution of Project Gutenberg™ works. - + 1.E.9. If you wish to charge a fee or distribute a Project Gutenberg™ electronic work or group of works on different terms than @@ -1048,5 +1048,3 @@ This website includes information about Project Gutenberg™, including how to make donations to the Project Gutenberg Literary Archive Foundation, how to help produce our new eBooks, and how to subscribe to our email newsletter to hear about new eBooks. - - diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index cce6752d..6837e40f 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -1,27 +1,28 @@ """Base class for procedural dataset generators""" + from abc import ABC, abstractmethod -from collections.abc import Sized, Iterable +from collections.abc import Iterable, Sized from random import Random -from typing import Optional, Iterator, Dict, Any +from typing import Any, Dict, Iterator, Optional class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): """Abstract base class for procedural dataset generators""" - + def __init__(self, seed: Optional[int] = None, size: int = 500): """Initialize the dataset with optional seed and size""" self.size = size self.seed = seed if seed is not None else Random().randint(0, 2**32) - + def __len__(self) -> int: """Return the virtual size of the dataset""" return self.size - + def __iter__(self): """Make the dataset iterable""" self._current_idx = 0 return self - + def __next__(self) -> Dict[str, Any]: """Get next item in iteration""" if self._current_idx >= self.size: @@ -29,14 +30,14 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): item = self[self._current_idx] self._current_idx += 1 return item - + @abstractmethod def __getitem__(self, idx: int) -> dict: """Generate a single dataset item - + Args: idx: Index of the item to generate - + Returns: dict containing at least: - question: str diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index 0d46cdae..f8e166fa 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -14,5 +14,5 @@ __all__ = [ "mini_sudoku_dataset", "SudokuConfig", "SudokuDataset", - "sudoku_dataset" + "sudoku_dataset", ] diff --git a/reasoning_gym/games/mini_sudoku.py b/reasoning_gym/games/mini_sudoku.py index dafebf25..a08c8123 100644 --- a/reasoning_gym/games/mini_sudoku.py +++ b/reasoning_gym/games/mini_sudoku.py @@ -1,16 +1,19 @@ """Mini Sudoku (4x4) puzzle generator""" -from dataclasses import dataclass + import random +from dataclasses import dataclass from random import Random from typing import List, Optional, Set, Tuple + @dataclass class MiniSudokuConfig: """Configuration for 4x4 sudoku puzzle generation""" - min_empty: int = 8 # Minimum number of empty cells - max_empty: int = 12 # Maximum number of empty cells + + min_empty: int = 8 # Minimum number of empty cells + max_empty: int = 12 # Maximum number of empty cells seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -45,11 +48,11 @@ class MiniSudokuDataset: # Check row if num in board[row]: return False - + # Check column if num in [board[i][col] for i in range(4)]: return False - + # Check 2x2 box box_row, box_col = 2 * (row // 2), 2 * (col // 2) for i in range(box_row, box_row + 2): @@ -63,7 +66,7 @@ class MiniSudokuDataset: empty = self._find_empty(board) if not empty: return True - + row, col = empty for num in range(1, 5): if self._is_valid(board, row, col, num): @@ -84,7 +87,7 @@ class MiniSudokuDataset: def _generate_solved_board(self, rng: Random) -> List[List[int]]: """Generate a complete solved mini sudoku board""" board = [[0] * 4 for _ in range(4)] - + # Try multiple times to generate a valid board max_attempts = 100 for _ in range(max_attempts): @@ -92,7 +95,7 @@ class MiniSudokuDataset: for i in range(4): for j in range(4): board[i][j] = 0 - + # Fill diagonal boxes first (they are independent) for i in range(0, 4, 2): nums = list(range(1, 5)) @@ -102,11 +105,11 @@ class MiniSudokuDataset: for c in range(i, i + 2): board[r][c] = nums[pos] pos += 1 - + # Try to solve the rest if self._solve(board): return board - + raise RuntimeError("Failed to generate valid mini sudoku board") def _create_puzzle(self, solved_board: List[List[int]], num_empty: int, rng: Random) -> List[List[int]]: @@ -114,10 +117,10 @@ class MiniSudokuDataset: puzzle = [row[:] for row in solved_board] cells = [(i, j) for i in range(4) for j in range(4)] rng.shuffle(cells) - + for i, j in cells[:num_empty]: puzzle[i][j] = 0 - + return puzzle def _board_to_string(self, board: List[List[int]]) -> str: @@ -127,26 +130,22 @@ class MiniSudokuDataset: def __getitem__(self, idx: int) -> dict: """Generate a single mini sudoku puzzle""" rng = Random(self.seed + idx) - + # Generate solved board solved_board = self._generate_solved_board(rng) - + # Create puzzle by removing numbers num_empty = rng.randint(self.config.min_empty, self.config.max_empty) puzzle = self._create_puzzle(solved_board, num_empty, rng) - + # Format as strings puzzle_str = self._board_to_string(puzzle) solution_str = self._board_to_string(solved_board) - + return { "question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}", "answer": solution_str, - "metadata": { - "puzzle": puzzle, - "solution": solved_board, - "num_empty": num_empty - } + "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty}, } diff --git a/reasoning_gym/games/sudoku.py b/reasoning_gym/games/sudoku.py index 47ffd54f..a47b7fbf 100644 --- a/reasoning_gym/games/sudoku.py +++ b/reasoning_gym/games/sudoku.py @@ -1,16 +1,19 @@ """Sudoku puzzle generator""" -from dataclasses import dataclass + import random +from dataclasses import dataclass from random import Random from typing import List, Optional, Set, Tuple + @dataclass class SudokuConfig: """Configuration for sudoku puzzle generation""" - min_empty: int = 30 # Minimum number of empty cells - max_empty: int = 50 # Maximum number of empty cells + + min_empty: int = 30 # Minimum number of empty cells + max_empty: int = 50 # Maximum number of empty cells seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 # Virtual dataset size def validate(self): """Validate configuration parameters""" @@ -45,11 +48,11 @@ class SudokuDataset: # Check row if num in board[row]: return False - + # Check column if num in [board[i][col] for i in range(9)]: return False - + # Check 3x3 box box_row, box_col = 3 * (row // 3), 3 * (col // 3) for i in range(box_row, box_row + 3): @@ -63,7 +66,7 @@ class SudokuDataset: empty = self._find_empty(board) if not empty: return True - + row, col = empty for num in range(1, 10): if self._is_valid(board, row, col, num): @@ -84,7 +87,7 @@ class SudokuDataset: def _generate_solved_board(self, rng: Random) -> List[List[int]]: """Generate a complete solved sudoku board""" board = [[0] * 9 for _ in range(9)] - + # Fill diagonal boxes first (they are independent) for i in range(0, 9, 3): nums = list(range(1, 10)) @@ -94,7 +97,7 @@ class SudokuDataset: for c in range(i, i + 3): board[r][c] = nums[pos] pos += 1 - + # Solve the rest self._solve(board) return board @@ -104,10 +107,10 @@ class SudokuDataset: puzzle = [row[:] for row in solved_board] cells = [(i, j) for i in range(9) for j in range(9)] rng.shuffle(cells) - + for i, j in cells[:num_empty]: puzzle[i][j] = 0 - + return puzzle def _board_to_string(self, board: List[List[int]]) -> str: @@ -117,26 +120,22 @@ class SudokuDataset: def __getitem__(self, idx: int) -> dict: """Generate a single sudoku puzzle""" rng = Random(self.seed + idx) - + # Generate solved board solved_board = self._generate_solved_board(rng) - + # Create puzzle by removing numbers num_empty = rng.randint(self.config.min_empty, self.config.max_empty) puzzle = self._create_puzzle(solved_board, num_empty, rng) - + # Format as strings puzzle_str = self._board_to_string(puzzle) solution_str = self._board_to_string(solved_board) - + return { "question": f"Solve this Sudoku puzzle:\n{puzzle_str}", "answer": solution_str, - "metadata": { - "puzzle": puzzle, - "solution": solved_board, - "num_empty": num_empty - } + "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty}, } diff --git a/tests/test_arithmetic.py b/tests/test_arithmetic.py index 87d5f1f5..8472b5fa 100644 --- a/tests/test_arithmetic.py +++ b/tests/test_arithmetic.py @@ -1,5 +1,7 @@ -import pytest from random import Random + +import pytest + from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig @@ -8,11 +10,11 @@ def test_arithmetic_dataset_config_validation(): with pytest.raises(AssertionError): config = BasicArithmeticDatasetConfig(min_terms=0) config.validate() - + with pytest.raises(AssertionError): config = BasicArithmeticDatasetConfig(min_terms=3, max_terms=2) config.validate() - + with pytest.raises(AssertionError): config = BasicArithmeticDatasetConfig(operators=["^"]) # Invalid operator config.validate() @@ -23,30 +25,23 @@ def test_arithmetic_dataset_deterministic(): config = BasicArithmeticDatasetConfig(seed=42, size=10) dataset1 = BasicArithmeticDataset(config) dataset2 = BasicArithmeticDataset(config) - + for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] def test_arithmetic_dataset_items(): """Test basic properties of generated items""" - config = BasicArithmeticDatasetConfig( - min_terms=2, - max_terms=4, - min_digits=1, - max_digits=2, - size=100, - seed=42 - ) + config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) dataset = BasicArithmeticDataset(config) - + for i in range(len(dataset)): item = dataset[i] assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item - + # Verify the answer matches the expression expression = item["metadata"]["expression"] answer = eval(expression) # Safe here as we control the expression @@ -62,11 +57,11 @@ def test_arithmetic_dataset_format_styles(): min_terms=2, max_terms=3, # Keep expressions simple for testing min_digits=1, - max_digits=2 + max_digits=2, ) dataset = BasicArithmeticDataset(config) assert all(item["question"].endswith("=") for item in dataset) - + config.format_style = "natural" dataset = BasicArithmeticDataset(config) assert all("=" not in item["question"] for item in dataset) @@ -74,24 +69,19 @@ def test_arithmetic_dataset_format_styles(): def test_arithmetic_dataset_iteration(): """Test that iteration respects dataset size""" - config = BasicArithmeticDatasetConfig( - min_terms=2, - max_terms=2, - size=5, # Small size for testing - seed=42 - ) + config = BasicArithmeticDatasetConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing dataset = BasicArithmeticDataset(config) - + # Test manual iteration items = [] for item in dataset: items.append(item) assert len(items) == config.size, "Iterator should yield exactly size items" - + # Test list conversion items = list(dataset) assert len(items) == config.size, "Iterator should yield exactly size items" - + # Test multiple iterations first_items = list(dataset) second_items = list(dataset) diff --git a/tests/test_base_conversion.py b/tests/test_base_conversion.py index 94e79aa6..7c8edf1e 100644 --- a/tests/test_base_conversion.py +++ b/tests/test_base_conversion.py @@ -1,10 +1,8 @@ """Tests for base conversion task generation""" + import pytest -from reasoning_gym.algorithmic.base_conversion import ( - BaseConversionConfig, - BaseConversionDataset, -) +from reasoning_gym.algorithmic.base_conversion import BaseConversionConfig, BaseConversionDataset def test_base_conversion_config_validation(): @@ -38,14 +36,7 @@ def test_base_conversion_dataset_deterministic(): def test_base_conversion_dataset_items(): """Test basic properties of generated items""" - config = BaseConversionConfig( - min_base=2, - max_base=16, - min_value=0, - max_value=1000, - size=10, - seed=42 - ) + config = BaseConversionConfig(min_base=2, max_base=16, min_value=0, max_value=1000, size=10, seed=42) dataset = BaseConversionDataset(config) for i in range(len(dataset)): @@ -55,28 +46,28 @@ def test_base_conversion_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "decimal_value" in item["metadata"] assert "source_base" in item["metadata"] assert "target_base" in item["metadata"] assert "source_repr" in item["metadata"] assert "target_repr" in item["metadata"] - + # Verify value range assert config.min_value <= item["metadata"]["decimal_value"] <= config.max_value - + # Verify base range assert config.min_base <= item["metadata"]["source_base"] <= config.max_base assert config.min_base <= item["metadata"]["target_base"] <= config.max_base assert item["metadata"]["source_base"] != item["metadata"]["target_base"] - + # Verify conversion correctness decimal_value = item["metadata"]["decimal_value"] target_base = item["metadata"]["target_base"] - expected = format(decimal_value, 'x' if target_base == 16 else 'b' if target_base == 2 else '').strip() + expected = format(decimal_value, "x" if target_base == 16 else "b" if target_base == 2 else "").strip() if target_base not in (2, 16): - expected = format(decimal_value, f'{target_base}x').lower().strip() + expected = format(decimal_value, f"{target_base}x").lower().strip() assert item["answer"] == expected @@ -100,24 +91,24 @@ def test_base_conversion_special_bases(): min_value=0, max_value=255, # Use small range for predictable results size=100, - seed=42 + seed=42, ) dataset = BaseConversionDataset(config) - + binary_found = False hex_found = False - + for i in range(len(dataset)): item = dataset[i] if item["metadata"]["target_base"] == 2: binary_found = True # Verify binary format - assert all(c in '01' for c in item["answer"]) + assert all(c in "01" for c in item["answer"]) elif item["metadata"]["target_base"] == 16: hex_found = True # Verify hex format - assert all(c in '0123456789abcdef' for c in item["answer"]) - + assert all(c in "0123456789abcdef" for c in item["answer"]) + assert binary_found, "No binary conversion tasks generated" assert hex_found, "No hexadecimal conversion tasks generated" @@ -130,10 +121,10 @@ def test_base_conversion_formatting(): min_value=10, # Ensure multi-digit numbers max_value=1000, size=10, - seed=42 + seed=42, ) dataset = BaseConversionDataset(config) - + for i in range(len(dataset)): item = dataset[i] # Verify lowercase letters are used diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index aff20a6c..c1ddf641 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -1,4 +1,5 @@ import pytest + from reasoning_gym.arithmetic import ChainSum, ChainSumConfig @@ -7,7 +8,7 @@ def test_chain_sum_config_validation(): with pytest.raises(AssertionError): config = ChainSumConfig(min_terms=0) config.validate() - + with pytest.raises(AssertionError): config = ChainSumConfig(min_terms=3, max_terms=2) config.validate() @@ -18,34 +19,27 @@ def test_chain_sum_deterministic(): config = ChainSumConfig(seed=42, size=10) dataset1 = ChainSum(config) dataset2 = ChainSum(config) - + for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] def test_chain_sum_items(): """Test basic properties of generated items""" - config = ChainSumConfig( - min_terms=2, - max_terms=4, - min_digits=1, - max_digits=2, - size=100, - seed=42 - ) + config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) dataset = ChainSum(config) - + for i in range(len(dataset)): item = dataset[i] assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item - + # Verify only + and - are used expression = item["metadata"]["expression"] assert all(op in ["+", "-", " "] or op.isdigit() for op in expression) - + # Verify the answer matches the expression answer = eval(expression) # Safe here as we control the expression assert str(answer) == item["answer"] @@ -60,10 +54,10 @@ def test_chain_sum_number_ranges(): min_digits=3, # Should generate numbers >= 100 max_digits=3, # Should generate numbers <= 999 size=50, - seed=42 + seed=42, ) dataset = ChainSum(config) - + for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] @@ -74,16 +68,8 @@ def test_chain_sum_number_ranges(): else: assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" - # Test 1-digit numbers - config = ChainSumConfig( - min_terms=2, - max_terms=2, - min_digits=1, - max_digits=1, - size=50, - seed=42 - ) + config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) dataset = ChainSum(config) for i in range(len(dataset)): item = dataset[i] @@ -95,58 +81,48 @@ def test_chain_sum_number_ranges(): else: assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" + def test_chain_sum_negation(): """Test that allow_negation controls number ranges""" config = ChainSumConfig( - min_terms=2, - max_terms=2, - min_digits=2, - max_digits=2, - size=100, - seed=42, - allow_negation=True + min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True ) dataset = ChainSum(config) - + # Track if we see both positive and negative numbers has_positive = False has_negative = False - + for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] - numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith('-') and n[1:].isdigit())] - + numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith("-") and n[1:].isdigit())] + for num in numbers: if num > 0: has_positive = True if num < 0: has_negative = True - + # With enough samples and allow_negation=True, we should see both positive and negative numbers assert has_positive and has_negative, "Expected both positive and negative numbers with allow_negation=True" def test_chain_sum_iteration(): """Test that iteration respects dataset size""" - config = ChainSumConfig( - min_terms=2, - max_terms=2, - size=5, # Small size for testing - seed=42 - ) + config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing dataset = ChainSum(config) - + # Test manual iteration items = [] for item in dataset: items.append(item) assert len(items) == config.size, "Iterator should yield exactly size items" - + # Test list conversion items = list(dataset) assert len(items) == config.size, "Iterator should yield exactly size items" - + # Test multiple iterations first_items = list(dataset) second_items = list(dataset) diff --git a/tests/test_fraction_simplification.py b/tests/test_fraction_simplification.py index ae674329..4b399e8f 100644 --- a/tests/test_fraction_simplification.py +++ b/tests/test_fraction_simplification.py @@ -1,6 +1,8 @@ -import pytest from math import gcd -from reasoning_gym.arithmetic import FractionSimplificationDataset, FractionSimplificationConfig + +import pytest + +from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset def test_fraction_config_validation(): @@ -8,15 +10,15 @@ def test_fraction_config_validation(): with pytest.raises(AssertionError): config = FractionSimplificationConfig(min_value=0) # Should be positive config.validate() - + with pytest.raises(AssertionError): config = FractionSimplificationConfig(min_value=100, max_value=50) # max should be > min config.validate() - + with pytest.raises(AssertionError): config = FractionSimplificationConfig(min_factor=0) # Should be >= 1 config.validate() - + with pytest.raises(AssertionError): config = FractionSimplificationConfig(min_factor=5, max_factor=3) # max should be >= min config.validate() @@ -27,30 +29,23 @@ def test_fraction_deterministic(): config = FractionSimplificationConfig(seed=42, size=10) dataset1 = FractionSimplificationDataset(config) dataset2 = FractionSimplificationDataset(config) - + for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] def test_fraction_items(): """Test basic properties of generated items""" - config = FractionSimplificationConfig( - min_value=1, - max_value=20, - min_factor=2, - max_factor=5, - size=50, - seed=42 - ) + config = FractionSimplificationConfig(min_value=1, max_value=20, min_factor=2, max_factor=5, size=50, seed=42) dataset = FractionSimplificationDataset(config) - + for i in range(len(dataset)): item = dataset[i] assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item - + # Verify the metadata contains all expected fields metadata = item["metadata"] assert "numerator" in metadata @@ -58,45 +53,38 @@ def test_fraction_items(): assert "simplified_numerator" in metadata assert "simplified_denominator" in metadata assert "reduction_factor" in metadata - + # Verify the numbers are within configured range assert config.min_value <= metadata["simplified_numerator"] <= config.max_value assert config.min_value <= metadata["simplified_denominator"] <= config.max_value - + # Verify the reduction is correct num = metadata["numerator"] den = metadata["denominator"] simple_num = metadata["simplified_numerator"] simple_den = metadata["simplified_denominator"] factor = metadata["reduction_factor"] - + assert num == simple_num * factor assert den == simple_den * factor - + # Verify the simplified fraction is actually in lowest terms assert gcd(simple_num, simple_den) == 1 def test_fraction_ranges(): """Test that generated numbers respect value constraints""" - config = FractionSimplificationConfig( - min_value=5, - max_value=15, - min_factor=3, - max_factor=4, - size=20, - seed=42 - ) + config = FractionSimplificationConfig(min_value=5, max_value=15, min_factor=3, max_factor=4, size=20, seed=42) dataset = FractionSimplificationDataset(config) - + for i in range(len(dataset)): item = dataset[i] metadata = item["metadata"] factor = metadata["reduction_factor"] - + # Check factor is within bounds assert 3 <= factor <= 4 - + # Check simplified values are within bounds assert 5 <= metadata["simplified_numerator"] <= 15 assert 5 <= metadata["simplified_denominator"] <= 15 @@ -106,17 +94,17 @@ def test_fraction_iteration(): """Test that iteration works correctly""" config = FractionSimplificationConfig(size=5, seed=42) dataset = FractionSimplificationDataset(config) - + # Test manual iteration items = [] for item in dataset: items.append(item) assert len(items) == config.size - + # Test list conversion items = list(dataset) assert len(items) == config.size - + # Test multiple iterations yield same results first_items = list(dataset) second_items = list(dataset) @@ -125,24 +113,19 @@ def test_fraction_iteration(): def test_fraction_numerator_smaller(): """Test that numerators are always smaller than denominators""" - config = FractionSimplificationConfig( - min_value=1, - max_value=100, - min_factor=2, - max_factor=5, - size=50, - seed=42 - ) + config = FractionSimplificationConfig(min_value=1, max_value=100, min_factor=2, max_factor=5, size=50, seed=42) dataset = FractionSimplificationDataset(config) - + for i in range(len(dataset)): item = dataset[i] metadata = item["metadata"] - + # Check original fraction - assert metadata["numerator"] <= metadata["denominator"], \ - f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}" - + assert ( + metadata["numerator"] <= metadata["denominator"] + ), f"Original numerator {metadata['numerator']} should be <= denominator {metadata['denominator']}" + # Check simplified fraction - assert metadata["simplified_numerator"] <= metadata["simplified_denominator"], \ - f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}" + assert ( + metadata["simplified_numerator"] <= metadata["simplified_denominator"] + ), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}" diff --git a/tests/test_gcd.py b/tests/test_gcd.py index 4c661e10..1ed90df6 100644 --- a/tests/test_gcd.py +++ b/tests/test_gcd.py @@ -1,7 +1,9 @@ -import pytest -from math import gcd from functools import reduce -from reasoning_gym.arithmetic import GCDDataset, GCDConfig +from math import gcd + +import pytest + +from reasoning_gym.arithmetic import GCDConfig, GCDDataset def test_gcd_config_validation(): @@ -9,15 +11,15 @@ def test_gcd_config_validation(): with pytest.raises(AssertionError): config = GCDConfig(min_numbers=1) # Should be >= 2 config.validate() - + with pytest.raises(AssertionError): config = GCDConfig(min_numbers=3, max_numbers=2) # max should be >= min config.validate() - + with pytest.raises(AssertionError): config = GCDConfig(min_value=0) # Should be positive config.validate() - + with pytest.raises(AssertionError): config = GCDConfig(min_value=100, max_value=50) # max should be > min config.validate() @@ -28,40 +30,33 @@ def test_gcd_deterministic(): config = GCDConfig(seed=42, size=10) dataset1 = GCDDataset(config) dataset2 = GCDDataset(config) - + for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] def test_gcd_items(): """Test basic properties of generated items""" - config = GCDConfig( - min_numbers=2, - max_numbers=4, - min_value=1, - max_value=100, - size=50, - seed=42 - ) + config = GCDConfig(min_numbers=2, max_numbers=4, min_value=1, max_value=100, size=50, seed=42) dataset = GCDDataset(config) - + for i in range(len(dataset)): item = dataset[i] assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item - + # Verify the numbers and result are in metadata metadata = item["metadata"] assert "numbers" in metadata assert "result" in metadata - + # Verify the numbers are within configured range numbers = metadata["numbers"] assert all(config.min_value <= n <= config.max_value for n in numbers) assert config.min_numbers <= len(numbers) <= config.max_numbers - + # Verify the GCD calculation is correct result = metadata["result"] assert str(result) == item["answer"] @@ -70,16 +65,9 @@ def test_gcd_items(): def test_gcd_number_ranges(): """Test that generated numbers respect value constraints""" - config = GCDConfig( - min_numbers=2, - max_numbers=2, - min_value=50, - max_value=100, - size=20, - seed=42 - ) + config = GCDConfig(min_numbers=2, max_numbers=2, min_value=50, max_value=100, size=20, seed=42) dataset = GCDDataset(config) - + for i in range(len(dataset)): item = dataset[i] numbers = item["metadata"]["numbers"] @@ -90,17 +78,17 @@ def test_gcd_iteration(): """Test that iteration works correctly""" config = GCDConfig(size=5, seed=42) dataset = GCDDataset(config) - + # Test manual iteration items = [] for item in dataset: items.append(item) assert len(items) == config.size - + # Test list conversion items = list(dataset) assert len(items) == config.size - + # Test multiple iterations yield same results first_items = list(dataset) second_items = list(dataset) @@ -109,20 +97,13 @@ def test_gcd_iteration(): def test_gcd_special_cases(): """Test some special GCD cases""" - config = GCDConfig( - min_numbers=2, - max_numbers=2, - min_value=1, - max_value=100, - size=100, - seed=42 - ) + config = GCDConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=100, size=100, seed=42) dataset = GCDDataset(config) - + # Track if we see some interesting GCD cases seen_gcd_1 = False # Coprime numbers seen_large_gcd = False # GCD > 1 - + for i in range(len(dataset)): item = dataset[i] result = int(item["answer"]) @@ -130,7 +111,7 @@ def test_gcd_special_cases(): seen_gcd_1 = True if result > 1: seen_large_gcd = True - + # With enough samples, we should see both coprime and non-coprime numbers assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)" assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)" diff --git a/tests/test_lcm.py b/tests/test_lcm.py index 029eea47..91cb9b5e 100644 --- a/tests/test_lcm.py +++ b/tests/test_lcm.py @@ -1,7 +1,9 @@ -import pytest -from math import lcm from functools import reduce -from reasoning_gym.arithmetic import LCMDataset, LCMConfig +from math import lcm + +import pytest + +from reasoning_gym.arithmetic import LCMConfig, LCMDataset def test_lcm_config_validation(): @@ -9,15 +11,15 @@ def test_lcm_config_validation(): with pytest.raises(AssertionError): config = LCMConfig(min_numbers=1) # Should be >= 2 config.validate() - + with pytest.raises(AssertionError): config = LCMConfig(min_numbers=3, max_numbers=2) # max should be >= min config.validate() - + with pytest.raises(AssertionError): config = LCMConfig(min_value=0) # Should be positive config.validate() - + with pytest.raises(AssertionError): config = LCMConfig(min_value=100, max_value=50) # max should be > min config.validate() @@ -28,7 +30,7 @@ def test_lcm_deterministic(): config = LCMConfig(seed=42, size=10) dataset1 = LCMDataset(config) dataset2 = LCMDataset(config) - + for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -36,32 +38,27 @@ def test_lcm_deterministic(): def test_lcm_items(): """Test basic properties of generated items""" config = LCMConfig( - min_numbers=2, - max_numbers=4, - min_value=1, - max_value=20, # Keep small for testing - size=50, - seed=42 + min_numbers=2, max_numbers=4, min_value=1, max_value=20, size=50, seed=42 # Keep small for testing ) dataset = LCMDataset(config) - + for i in range(len(dataset)): item = dataset[i] assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item - + # Verify the numbers and result are in metadata metadata = item["metadata"] assert "numbers" in metadata assert "result" in metadata - + # Verify the numbers are within configured range numbers = metadata["numbers"] assert all(config.min_value <= n <= config.max_value for n in numbers) assert config.min_numbers <= len(numbers) <= config.max_numbers - + # Verify the LCM calculation is correct result = metadata["result"] assert str(result) == item["answer"] @@ -70,16 +67,9 @@ def test_lcm_items(): def test_lcm_number_ranges(): """Test that generated numbers respect value constraints""" - config = LCMConfig( - min_numbers=2, - max_numbers=2, - min_value=5, - max_value=15, - size=20, - seed=42 - ) + config = LCMConfig(min_numbers=2, max_numbers=2, min_value=5, max_value=15, size=20, seed=42) dataset = LCMDataset(config) - + for i in range(len(dataset)): item = dataset[i] numbers = item["metadata"]["numbers"] @@ -90,17 +80,17 @@ def test_lcm_iteration(): """Test that iteration works correctly""" config = LCMConfig(size=5, seed=42) dataset = LCMDataset(config) - + # Test manual iteration items = [] for item in dataset: items.append(item) assert len(items) == config.size - + # Test list conversion items = list(dataset) assert len(items) == config.size - + # Test multiple iterations yield same results first_items = list(dataset) second_items = list(dataset) @@ -109,31 +99,24 @@ def test_lcm_iteration(): def test_lcm_special_cases(): """Test some special LCM cases""" - config = LCMConfig( - min_numbers=2, - max_numbers=2, - min_value=1, - max_value=20, - size=100, - seed=42 - ) + config = LCMConfig(min_numbers=2, max_numbers=2, min_value=1, max_value=20, size=100, seed=42) dataset = LCMDataset(config) - + # Track if we see some interesting LCM cases seen_equal_to_product = False # When numbers are coprime seen_less_than_product = False # When numbers share factors - + for i in range(len(dataset)): item = dataset[i] numbers = item["metadata"]["numbers"] result = int(item["answer"]) product = reduce(lambda x, y: x * y, numbers) - + if result == product: seen_equal_to_product = True if result < product: seen_less_than_product = True - + # With enough samples, we should see both cases assert seen_equal_to_product, "Expected to see some coprime numbers (LCM = product)" assert seen_less_than_product, "Expected to see some numbers with common factors (LCM < product)" diff --git a/tests/test_leg_counting.py b/tests/test_leg_counting.py index 4589e24f..31191bda 100644 --- a/tests/test_leg_counting.py +++ b/tests/test_leg_counting.py @@ -1,11 +1,8 @@ """Tests for leg counting task generation""" + import pytest -from reasoning_gym.arithmetic.leg_counting import ( - LegCountingConfig, - LegCountingDataset, - ANIMALS, -) +from reasoning_gym.arithmetic.leg_counting import ANIMALS, LegCountingConfig, LegCountingDataset def test_leg_counting_config_validation(): @@ -35,13 +32,7 @@ def test_leg_counting_dataset_deterministic(): def test_leg_counting_dataset_items(): """Test basic properties of generated items""" - config = LegCountingConfig( - min_animals=2, - max_animals=4, - max_instances=2, - size=10, - seed=42 - ) + config = LegCountingConfig(min_animals=2, max_animals=4, max_instances=2, size=10, seed=42) dataset = LegCountingDataset(config) for i in range(len(dataset)): @@ -51,19 +42,19 @@ def test_leg_counting_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "animals" in item["metadata"] assert "total_legs" in item["metadata"] - + # Verify animal count constraints animals = item["metadata"]["animals"] assert len(animals) >= config.min_animals assert len(animals) <= config.max_animals - + # Verify instance count constraints assert all(1 <= count <= config.max_instances for count in animals.values()) - + # Verify leg counting is correct total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items()) assert str(total_legs) == item["answer"] @@ -86,7 +77,7 @@ def test_leg_counting_animal_validation(): """Test that all animals have valid leg counts""" # Verify all animals have non-negative leg counts assert all(legs >= 0 for legs in ANIMALS.values()) - + # Verify common animals have expected leg counts assert ANIMALS["spider"] == 8 assert ANIMALS["insect"] == 6 diff --git a/tests/test_letter_counting.py b/tests/test_letter_counting.py index 7604cda9..7c6e9bd1 100644 --- a/tests/test_letter_counting.py +++ b/tests/test_letter_counting.py @@ -1,10 +1,8 @@ """Tests for letter counting task generation""" + import pytest -from reasoning_gym.algorithmic.letter_counting import ( - LetterCountingConfig, - LetterCountingDataset, -) +from reasoning_gym.algorithmic.letter_counting import LetterCountingConfig, LetterCountingDataset def test_letter_counting_config_validation(): @@ -30,12 +28,7 @@ def test_letter_counting_dataset_deterministic(): def test_letter_counting_dataset_items(): """Test basic properties of generated items""" - config = LetterCountingConfig( - min_words=3, - max_words=6, - size=10, - seed=42 - ) + config = LetterCountingConfig(min_words=3, max_words=6, size=10, seed=42) dataset = LetterCountingDataset(config) for i in range(len(dataset)): @@ -45,17 +38,17 @@ def test_letter_counting_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "span_length" in item["metadata"] assert "target_letter" in item["metadata"] assert "span" in item["metadata"] - + # Verify span length constraints span = item["metadata"]["span"] assert len(span) >= config.min_words assert len(span) <= config.max_words - + # Verify letter counting target_letter = item["metadata"]["target_letter"] count = sum(word.lower().count(target_letter) for word in span) @@ -78,7 +71,7 @@ def test_letter_counting_text_preprocessing(): """Test that text preprocessing handles edge cases""" config = LetterCountingConfig(size=1, seed=42) dataset = LetterCountingDataset(config) - + # Verify words were extracted from text assert len(dataset.words) > 0 # Verify words contain only word characters diff --git a/tests/test_mini_sudoku.py b/tests/test_mini_sudoku.py index 5bddfe2b..606a544a 100644 --- a/tests/test_mini_sudoku.py +++ b/tests/test_mini_sudoku.py @@ -1,10 +1,8 @@ """Tests for mini sudoku puzzle generation""" + import pytest -from reasoning_gym.games.mini_sudoku import ( - MiniSudokuConfig, - MiniSudokuDataset, -) +from reasoning_gym.games.mini_sudoku import MiniSudokuConfig, MiniSudokuDataset def test_mini_sudoku_config_validation(): @@ -34,12 +32,7 @@ def test_mini_sudoku_dataset_deterministic(): def test_mini_sudoku_dataset_items(): """Test basic properties of generated items""" - config = MiniSudokuConfig( - min_empty=8, - max_empty=12, - size=10, - seed=42 - ) + config = MiniSudokuConfig(min_empty=8, max_empty=12, size=10, seed=42) dataset = MiniSudokuDataset(config) for i in range(len(dataset)): @@ -49,30 +42,30 @@ def test_mini_sudoku_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "puzzle" in item["metadata"] assert "solution" in item["metadata"] assert "num_empty" in item["metadata"] - + puzzle = item["metadata"]["puzzle"] solution = item["metadata"]["solution"] num_empty = item["metadata"]["num_empty"] - + # Verify board dimensions assert len(puzzle) == 4 assert all(len(row) == 4 for row in puzzle) assert len(solution) == 4 assert all(len(row) == 4 for row in solution) - + # Verify empty cell count empty_count = sum(1 for row in puzzle for cell in row if cell == 0) assert config.min_empty <= empty_count <= config.max_empty assert empty_count == num_empty - + # Verify solution validity assert is_valid_solution(solution) - + # Verify puzzle matches solution where filled for i in range(4): for j in range(4): @@ -94,14 +87,9 @@ def test_mini_sudoku_dataset_iteration(): def test_mini_sudoku_board_generation(): """Test that generated boards are valid""" - config = MiniSudokuConfig( - min_empty=0, # Force complete board - max_empty=0, - size=5, - seed=42 - ) + config = MiniSudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board dataset = MiniSudokuDataset(config) - + for i in range(len(dataset)): item = dataset[i] board = item["metadata"]["solution"] @@ -114,21 +102,21 @@ def is_valid_solution(board: list[list[int]]) -> bool: for row in board: if set(row) != set(range(1, 5)): return False - + # Check columns for j in range(4): column = [board[i][j] for i in range(4)] if set(column) != set(range(1, 5)): return False - + # Check 2x2 boxes for box_i in range(2): for box_j in range(2): box = [] for i in range(2): for j in range(2): - box.append(board[box_i*2 + i][box_j*2 + j]) + box.append(board[box_i * 2 + i][box_j * 2 + j]) if set(box) != set(range(1, 5)): return False - + return True diff --git a/tests/test_number_filtering.py b/tests/test_number_filtering.py index 1e285e02..e70ec6b7 100644 --- a/tests/test_number_filtering.py +++ b/tests/test_number_filtering.py @@ -1,10 +1,8 @@ """Tests for number filtering task generation""" + import pytest -from reasoning_gym.algorithmic.number_filtering import ( - NumberFilteringConfig, - NumberFilteringDataset, -) +from reasoning_gym.algorithmic.number_filtering import NumberFilteringConfig, NumberFilteringDataset def test_number_filtering_config_validation(): @@ -16,11 +14,11 @@ def test_number_filtering_config_validation(): with pytest.raises(AssertionError): config = NumberFilteringConfig(min_numbers=10, max_numbers=5) config.validate() - + with pytest.raises(AssertionError): config = NumberFilteringConfig(min_decimals=-1) config.validate() - + with pytest.raises(AssertionError): config = NumberFilteringConfig(min_value=100, max_value=0) config.validate() @@ -39,14 +37,7 @@ def test_number_filtering_dataset_deterministic(): def test_number_filtering_dataset_items(): """Test basic properties of generated items""" config = NumberFilteringConfig( - min_numbers=3, - max_numbers=6, - min_decimals=1, - max_decimals=3, - min_value=-10.0, - max_value=10.0, - size=10, - seed=42 + min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42 ) dataset = NumberFilteringDataset(config) @@ -57,34 +48,34 @@ def test_number_filtering_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "original_numbers" in item["metadata"] assert "filter_value" in item["metadata"] assert "operation" in item["metadata"] assert "result" in item["metadata"] - + # Verify number count constraints numbers = item["metadata"]["original_numbers"] assert len(numbers) >= config.min_numbers assert len(numbers) <= config.max_numbers - + # Verify decimal places for num in numbers: - decimal_places = len(num.split('.')[-1]) if '.' in num else 0 + decimal_places = len(num.split(".")[-1]) if "." in num else 0 assert decimal_places >= config.min_decimals assert decimal_places <= config.max_decimals - + # Verify value range for num in numbers: value = float(num) assert config.min_value <= value <= config.max_value - + # Verify filtering operation operation = item["metadata"]["operation"] filter_value = float(item["metadata"]["filter_value"]) result = [float(x) for x in eval(item["answer"])] if item["answer"] != "[]" else [] - + if operation == "keep_larger": assert all(x > filter_value for x in result) elif operation == "keep_smaller": @@ -117,11 +108,11 @@ def test_number_filtering_precision(): min_value=0.0, max_value=1.0, size=1, - seed=42 + seed=42, ) dataset = NumberFilteringDataset(config) item = dataset[0] - + # Check that string representations maintain precision for num in item["metadata"]["original_numbers"]: - assert len(num.split('.')[-1]) == 2 + assert len(num.split(".")[-1]) == 2 diff --git a/tests/test_number_sorting.py b/tests/test_number_sorting.py index 3374a79c..88916076 100644 --- a/tests/test_number_sorting.py +++ b/tests/test_number_sorting.py @@ -1,10 +1,8 @@ """Tests for number sorting task generation""" + import pytest -from reasoning_gym.algorithmic.number_sorting import ( - NumberSortingConfig, - NumberSortingDataset, -) +from reasoning_gym.algorithmic.number_sorting import NumberSortingConfig, NumberSortingDataset def test_number_sorting_config_validation(): @@ -16,11 +14,11 @@ def test_number_sorting_config_validation(): with pytest.raises(AssertionError): config = NumberSortingConfig(min_numbers=10, max_numbers=5) config.validate() - + with pytest.raises(AssertionError): config = NumberSortingConfig(min_decimals=-1) config.validate() - + with pytest.raises(AssertionError): config = NumberSortingConfig(min_value=100, max_value=0) config.validate() @@ -39,14 +37,7 @@ def test_number_sorting_dataset_deterministic(): def test_number_sorting_dataset_items(): """Test basic properties of generated items""" config = NumberSortingConfig( - min_numbers=3, - max_numbers=6, - min_decimals=1, - max_decimals=3, - min_value=-10.0, - max_value=10.0, - size=10, - seed=42 + min_numbers=3, max_numbers=6, min_decimals=1, max_decimals=3, min_value=-10.0, max_value=10.0, size=10, seed=42 ) dataset = NumberSortingDataset(config) @@ -57,28 +48,28 @@ def test_number_sorting_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "original_numbers" in item["metadata"] assert "direction" in item["metadata"] assert "sorted_numbers" in item["metadata"] - + # Verify number count constraints numbers = item["metadata"]["original_numbers"] assert len(numbers) >= config.min_numbers assert len(numbers) <= config.max_numbers - + # Verify decimal places for num in numbers: - decimal_places = len(num.split('.')[-1]) if '.' in num else 0 + decimal_places = len(num.split(".")[-1]) if "." in num else 0 assert decimal_places >= config.min_decimals assert decimal_places <= config.max_decimals - + # Verify value range for num in numbers: value = float(num) assert config.min_value <= value <= config.max_value - + # Verify sorting direction = item["metadata"]["direction"] sorted_numbers = [float(x) for x in eval(item["answer"])] diff --git a/tests/test_prime_factorization.py b/tests/test_prime_factorization.py index bf3b8b2e..b3b8695a 100644 --- a/tests/test_prime_factorization.py +++ b/tests/test_prime_factorization.py @@ -1,10 +1,8 @@ """Tests for prime factorization task generation""" + import pytest -from reasoning_gym.arithmetic.prime_factorization import ( - PrimeFactorizationConfig, - PrimeFactorizationDataset, -) +from reasoning_gym.arithmetic.prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset def test_prime_factorization_config_validation(): @@ -30,12 +28,7 @@ def test_prime_factorization_dataset_deterministic(): def test_prime_factorization_dataset_items(): """Test basic properties of generated items""" - config = PrimeFactorizationConfig( - min_value=2, - max_value=100, - size=10, - seed=42 - ) + config = PrimeFactorizationConfig(min_value=2, max_value=100, size=10, seed=42) dataset = PrimeFactorizationDataset(config) for i in range(len(dataset)): @@ -45,26 +38,26 @@ def test_prime_factorization_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "number" in item["metadata"] assert "factors" in item["metadata"] - + # Verify value range number = item["metadata"]["number"] assert config.min_value <= number <= config.max_value - + # Verify factorization is correct factors = item["metadata"]["factors"] product = 1 for factor in factors: product *= factor assert product == number - + # Verify factors are prime for factor in factors: assert is_prime(factor), f"{factor} is not prime" - + # Verify answer format assert item["answer"] == " × ".join(map(str, factors)) @@ -83,15 +76,10 @@ def test_prime_factorization_dataset_iteration(): def test_prime_factorization_known_values(): """Test factorization of known values""" - config = PrimeFactorizationConfig( - min_value=12, - max_value=12, # Force specific number - size=1, - seed=42 - ) + config = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42) # Force specific number dataset = PrimeFactorizationDataset(config) item = dataset[0] - + assert item["metadata"]["number"] == 12 assert item["metadata"]["factors"] == [2, 2, 3] assert item["answer"] == "2 × 2 × 3" @@ -101,7 +89,7 @@ def is_prime(n: int) -> bool: """Helper function to check if a number is prime""" if n < 2: return False - for i in range(2, int(n ** 0.5) + 1): + for i in range(2, int(n**0.5) + 1): if n % i == 0: return False return True diff --git a/tests/test_sequences.py b/tests/test_sequences.py index 6d383284..883be954 100644 --- a/tests/test_sequences.py +++ b/tests/test_sequences.py @@ -23,14 +23,14 @@ def test_pattern_rule(): # Test simple addition rule = PatternRule([Operation.ADD], [2]) assert rule.apply([1, 3], 1) == 5 - + # Test composition rule = PatternRule([Operation.DOUBLE, Operation.ADD], [0, 3]) assert rule.apply([1, 4], 1) == 11 # (4 * 2) + 3 # Test rule composition rule1 = PatternRule([Operation.DOUBLE], [0]) # Double the number - rule2 = PatternRule([Operation.ADD], [3]) # Add 3 + rule2 = PatternRule([Operation.ADD], [3]) # Add 3 composed = PatternRule.compose([rule1, rule2]) assert composed.apply([1, 4], 1) == 11 # (4 * 2) + 3 diff --git a/tests/test_sudoku.py b/tests/test_sudoku.py index da7986bd..ce41ce3a 100644 --- a/tests/test_sudoku.py +++ b/tests/test_sudoku.py @@ -1,10 +1,8 @@ """Tests for sudoku puzzle generation""" + import pytest -from reasoning_gym.games.sudoku import ( - SudokuConfig, - SudokuDataset, -) +from reasoning_gym.games.sudoku import SudokuConfig, SudokuDataset def test_sudoku_config_validation(): @@ -34,12 +32,7 @@ def test_sudoku_dataset_deterministic(): def test_sudoku_dataset_items(): """Test basic properties of generated items""" - config = SudokuConfig( - min_empty=30, - max_empty=40, - size=10, - seed=42 - ) + config = SudokuConfig(min_empty=30, max_empty=40, size=10, seed=42) dataset = SudokuDataset(config) for i in range(len(dataset)): @@ -49,30 +42,30 @@ def test_sudoku_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "puzzle" in item["metadata"] assert "solution" in item["metadata"] assert "num_empty" in item["metadata"] - + puzzle = item["metadata"]["puzzle"] solution = item["metadata"]["solution"] num_empty = item["metadata"]["num_empty"] - + # Verify board dimensions assert len(puzzle) == 9 assert all(len(row) == 9 for row in puzzle) assert len(solution) == 9 assert all(len(row) == 9 for row in solution) - + # Verify empty cell count empty_count = sum(1 for row in puzzle for cell in row if cell == 0) assert config.min_empty <= empty_count <= config.max_empty assert empty_count == num_empty - + # Verify solution validity assert is_valid_solution(solution) - + # Verify puzzle matches solution where filled for i in range(9): for j in range(9): @@ -94,14 +87,9 @@ def test_sudoku_dataset_iteration(): def test_sudoku_board_generation(): """Test that generated boards are valid""" - config = SudokuConfig( - min_empty=0, # Force complete board - max_empty=0, - size=5, - seed=42 - ) + config = SudokuConfig(min_empty=0, max_empty=0, size=5, seed=42) # Force complete board dataset = SudokuDataset(config) - + for i in range(len(dataset)): item = dataset[i] board = item["metadata"]["solution"] @@ -114,21 +102,21 @@ def is_valid_solution(board: list[list[int]]) -> bool: for row in board: if set(row) != set(range(1, 10)): return False - + # Check columns for j in range(9): column = [board[i][j] for i in range(9)] if set(column) != set(range(1, 10)): return False - + # Check 3x3 boxes for box_i in range(3): for box_j in range(3): box = [] for i in range(3): for j in range(3): - box.append(board[box_i*3 + i][box_j*3 + j]) + box.append(board[box_i * 3 + i][box_j * 3 + j]) if set(box) != set(range(1, 10)): return False - + return True diff --git a/tests/test_word_reversal.py b/tests/test_word_reversal.py index 69b27f08..eec4bef1 100644 --- a/tests/test_word_reversal.py +++ b/tests/test_word_reversal.py @@ -1,10 +1,8 @@ """Tests for word reversal task generation""" + import pytest -from reasoning_gym.algorithmic.word_reversal import ( - WordReversalConfig, - WordReversalDataset, -) +from reasoning_gym.algorithmic.word_reversal import WordReversalConfig, WordReversalDataset def test_word_reversal_config_validation(): @@ -30,12 +28,7 @@ def test_word_reversal_dataset_deterministic(): def test_word_reversal_dataset_items(): """Test basic properties of generated items""" - config = WordReversalConfig( - min_words=3, - max_words=6, - size=10, - seed=42 - ) + config = WordReversalConfig(min_words=3, max_words=6, size=10, seed=42) dataset = WordReversalDataset(config) for i in range(len(dataset)): @@ -45,16 +38,16 @@ def test_word_reversal_dataset_items(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "num_words" in item["metadata"] assert "words" in item["metadata"] - + # Verify word count constraints words = item["metadata"]["words"] assert len(words) >= config.min_words assert len(words) <= config.max_words - + # Verify reversal is correct question_words = [w.strip() for w in item["question"].split(":")[1].strip().split(",")] answer_words = item["answer"].split(", ") @@ -77,7 +70,7 @@ def test_word_reversal_text_preprocessing(): """Test that text preprocessing handles edge cases""" config = WordReversalConfig(size=1, seed=42) dataset = WordReversalDataset(config) - + # Verify words were extracted from text assert len(dataset.words) > 0 # Verify words contain only alphanumeric characters