diff --git a/README.md b/README.md index f7855146..ed15e319 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ The goal is to generate virtually infinite data with adjustable complexity. #### Arithmetic Tasks - `ArithmeticDataset`: Generate arithmetic expressions with configurable complexity and operators (+, -, *) - `ChainSum`: Generate addition/subtraction chains with configurable length and digit counts +- `FractionSimplificationDataset`: Generate fraction simplification tasks with configurable complexity - `GCDDataset`: Generate Greatest Common Divisor problems with configurable number of integers - `LCMDataset`: Generate Least Common Multiple problems with configurable number of integers - `LegCountingDataset`: Generate animal leg counting word problems with various animals diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 3d8b3329..c1614f99 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -8,6 +8,7 @@ Arithmetic tasks for training reasoning capabilities: from .basic_arithmetic import ArithmeticDataset, ArithmeticDatasetConfig, arithmetic_dataset from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset +from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset, fraction_simplification_dataset from .gcd import GCDConfig, GCDDataset, gcd_dataset from .lcm import LCMConfig, LCMDataset, lcm_dataset from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset @@ -20,6 +21,9 @@ __all__ = [ "ChainSum", "ChainSumConfig", "chain_sum_dataset", + "FractionSimplificationConfig", + "FractionSimplificationDataset", + "fraction_simplification_dataset", "GCDConfig", "GCDDataset", "gcd_dataset", diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py new file mode 100644 index 00000000..756961f6 --- /dev/null +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -0,0 +1,103 @@ +"""Fraction simplification task generator""" +from dataclasses import dataclass +from random import Random +from typing import List, Optional, Tuple +from math import gcd + + +@dataclass +class FractionSimplificationConfig: + """Configuration for fraction simplification task generation""" + min_value: int = 1 # Minimum value for numerator/denominator + max_value: int = 100 # Maximum value for numerator/denominator + min_factor: int = 2 # Minimum multiplication factor + max_factor: int = 10 # Maximum multiplication factor + seed: Optional[int] = None + size: int = 500 # Virtual dataset size + + def validate(self): + """Validate configuration parameters""" + assert self.min_value >= 1, "min_value must be positive" + assert self.max_value > self.min_value, "max_value must be > min_value" + assert self.min_factor >= 2, "min_factor must be at least 2" + assert self.max_factor >= self.min_factor, "max_factor must be >= min_factor" + + +class FractionSimplificationDataset: + """Generates fraction simplification tasks""" + + def __init__(self, config: FractionSimplificationConfig): + self.config = config + self.config.validate() + self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) + + def __len__(self) -> int: + return self.config.size + + def __iter__(self): + self._current_idx = 0 + return self + + def __next__(self): + if self._current_idx >= self.config.size: + raise StopIteration + item = self[self._current_idx] + self._current_idx += 1 + return item + + def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]: + """Generate a random fraction and its simplified form. + Returns (numerator, denominator, simplified_num, simplified_den)""" + # Generate the simplified fraction first + simplified_num = rng.randint(self.config.min_value, self.config.max_value) + simplified_den = rng.randint(self.config.min_value, self.config.max_value) + + # Make sure they're coprime by dividing by their GCD + common = gcd(simplified_num, simplified_den) + simplified_num //= common + simplified_den //= common + + # Multiply both by a random factor to create the unsimplified version + factor = rng.randint(self.config.min_factor, self.config.max_factor) + numerator = simplified_num * factor + denominator = simplified_den * factor + + return numerator, denominator, simplified_num, simplified_den + + def __getitem__(self, idx: int) -> dict: + """Generate a single fraction simplification task""" + rng = Random(self.seed + idx) + + num, den, simple_num, simple_den = self._generate_fraction(rng) + + return { + "question": f"Simplify the fraction {num}/{den} to its lowest terms", + "answer": f"{simple_num}/{simple_den}", + "metadata": { + "numerator": num, + "denominator": den, + "simplified_numerator": simple_num, + "simplified_denominator": simple_den, + "reduction_factor": num // simple_num # Will be same as den // simple_den + } + } + + +def fraction_simplification_dataset( + min_value: int = 1, + max_value: int = 100, + min_factor: int = 2, + max_factor: int = 10, + seed: Optional[int] = None, + size: int = 500, +) -> FractionSimplificationDataset: + """Create a FractionSimplificationDataset with the given configuration.""" + config = FractionSimplificationConfig( + min_value=min_value, + max_value=max_value, + min_factor=min_factor, + max_factor=max_factor, + seed=seed, + size=size, + ) + return FractionSimplificationDataset(config) diff --git a/tests/test_fraction_simplification.py b/tests/test_fraction_simplification.py new file mode 100644 index 00000000..97dcb01d --- /dev/null +++ b/tests/test_fraction_simplification.py @@ -0,0 +1,123 @@ +import pytest +from math import gcd +from reasoning_gym.arithmetic import FractionSimplificationDataset, FractionSimplificationConfig + + +def test_fraction_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = FractionSimplificationConfig(min_value=0) # Should be positive + config.validate() + + with pytest.raises(AssertionError): + config = FractionSimplificationConfig(min_value=100, max_value=50) # max should be > min + config.validate() + + with pytest.raises(AssertionError): + config = FractionSimplificationConfig(min_factor=1) # Should be >= 2 + config.validate() + + with pytest.raises(AssertionError): + config = FractionSimplificationConfig(min_factor=5, max_factor=3) # max should be >= min + config.validate() + + +def test_fraction_deterministic(): + """Test that dataset generates same items with same seed""" + config = FractionSimplificationConfig(seed=42, size=10) + dataset1 = FractionSimplificationDataset(config) + dataset2 = FractionSimplificationDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_fraction_items(): + """Test basic properties of generated items""" + config = FractionSimplificationConfig( + min_value=1, + max_value=20, + min_factor=2, + max_factor=5, + size=50, + seed=42 + ) + dataset = FractionSimplificationDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify the metadata contains all expected fields + metadata = item["metadata"] + assert "numerator" in metadata + assert "denominator" in metadata + assert "simplified_numerator" in metadata + assert "simplified_denominator" in metadata + assert "reduction_factor" in metadata + + # Verify the numbers are within configured range + assert config.min_value <= metadata["simplified_numerator"] <= config.max_value + assert config.min_value <= metadata["simplified_denominator"] <= config.max_value + + # Verify the reduction is correct + num = metadata["numerator"] + den = metadata["denominator"] + simple_num = metadata["simplified_numerator"] + simple_den = metadata["simplified_denominator"] + factor = metadata["reduction_factor"] + + assert num == simple_num * factor + assert den == simple_den * factor + + # Verify the simplified fraction is actually in lowest terms + assert gcd(simple_num, simple_den) == 1 + + +def test_fraction_ranges(): + """Test that generated numbers respect value constraints""" + config = FractionSimplificationConfig( + min_value=5, + max_value=15, + min_factor=3, + max_factor=4, + size=20, + seed=42 + ) + dataset = FractionSimplificationDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + metadata = item["metadata"] + factor = metadata["reduction_factor"] + + # Check factor is within bounds + assert 3 <= factor <= 4 + + # Check simplified values are within bounds + assert 5 <= metadata["simplified_numerator"] <= 15 + assert 5 <= metadata["simplified_denominator"] <= 15 + + +def test_fraction_iteration(): + """Test that iteration works correctly""" + config = FractionSimplificationConfig(size=5, seed=42) + dataset = FractionSimplificationDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size + + # Test list conversion + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same results + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items