diff --git a/README.md b/README.md index 6af65067..f7855146 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ The goal is to generate virtually infinite data with adjustable complexity. - `ArithmeticDataset`: Generate arithmetic expressions with configurable complexity and operators (+, -, *) - `ChainSum`: Generate addition/subtraction chains with configurable length and digit counts - `GCDDataset`: Generate Greatest Common Divisor problems with configurable number of integers +- `LCMDataset`: Generate Least Common Multiple problems with configurable number of integers - `LegCountingDataset`: Generate animal leg counting word problems with various animals - `PrimeFactorizationDataset`: Generate prime factorization tasks with configurable number ranges diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 1aed0ae1..3d8b3329 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -9,6 +9,7 @@ Arithmetic tasks for training reasoning capabilities: from .basic_arithmetic import ArithmeticDataset, ArithmeticDatasetConfig, arithmetic_dataset from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset from .gcd import GCDConfig, GCDDataset, gcd_dataset +from .lcm import LCMConfig, LCMDataset, lcm_dataset from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset, prime_factorization_dataset @@ -22,6 +23,9 @@ __all__ = [ "GCDConfig", "GCDDataset", "gcd_dataset", + "LCMConfig", + "LCMDataset", + "lcm_dataset", "LegCountingConfig", "LegCountingDataset", "leg_counting_dataset", diff --git a/reasoning_gym/arithmetic/lcm.py b/reasoning_gym/arithmetic/lcm.py new file mode 100644 index 00000000..3a840f49 --- /dev/null +++ b/reasoning_gym/arithmetic/lcm.py @@ -0,0 +1,95 @@ +"""Least Common Multiple (LCM) task generator""" +from dataclasses import dataclass +from random import Random +from typing import List, Optional, Tuple +from math import lcm +from functools import reduce + + +@dataclass +class LCMConfig: + """Configuration for LCM task generation""" + min_numbers: int = 2 # Minimum numbers to find LCM of + max_numbers: int = 2 # Maximum numbers to find LCM of + min_value: int = 1 # Minimum value for each number + max_value: int = 100 # Maximum value for each number (kept smaller than GCD default since LCM grows fast) + seed: Optional[int] = None + size: int = 500 # Virtual dataset size + + def validate(self): + """Validate configuration parameters""" + assert self.min_numbers >= 2, "min_numbers must be at least 2" + assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers" + assert self.min_value >= 1, "min_value must be positive" + assert self.max_value > self.min_value, "max_value must be > min_value" + + +class LCMDataset: + """Generates Least Common Multiple (LCM) tasks""" + + def __init__(self, config: LCMConfig): + self.config = config + self.config.validate() + self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) + + def __len__(self) -> int: + return self.config.size + + def __iter__(self): + self._current_idx = 0 + return self + + def __next__(self): + if self._current_idx >= self.config.size: + raise StopIteration + item = self[self._current_idx] + self._current_idx += 1 + return item + + def _generate_numbers(self, rng: Random) -> List[int]: + """Generate a list of random positive integers""" + num_count = rng.randint(self.config.min_numbers, self.config.max_numbers) + return [rng.randint(self.config.min_value, self.config.max_value) + for _ in range(num_count)] + + def _calculate_lcm(self, numbers: List[int]) -> int: + """Calculate the LCM of a list of numbers""" + return reduce(lcm, numbers) + + def __getitem__(self, idx: int) -> dict: + """Generate a single LCM task""" + rng = Random(self.seed + idx) + + numbers = self._generate_numbers(rng) + result = self._calculate_lcm(numbers) + + numbers_str = ", ".join(str(n) for n in numbers) + + return { + "question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}", + "answer": str(result), + "metadata": { + "numbers": numbers, + "result": result + } + } + + +def lcm_dataset( + min_numbers: int = 2, + max_numbers: int = 2, + min_value: int = 1, + max_value: int = 100, + seed: Optional[int] = None, + size: int = 500, +) -> LCMDataset: + """Create a LCMDataset with the given configuration.""" + config = LCMConfig( + min_numbers=min_numbers, + max_numbers=max_numbers, + min_value=min_value, + max_value=max_value, + seed=seed, + size=size, + ) + return LCMDataset(config) diff --git a/tests/test_lcm.py b/tests/test_lcm.py new file mode 100644 index 00000000..029eea47 --- /dev/null +++ b/tests/test_lcm.py @@ -0,0 +1,139 @@ +import pytest +from math import lcm +from functools import reduce +from reasoning_gym.arithmetic import LCMDataset, LCMConfig + + +def test_lcm_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = LCMConfig(min_numbers=1) # Should be >= 2 + config.validate() + + with pytest.raises(AssertionError): + config = LCMConfig(min_numbers=3, max_numbers=2) # max should be >= min + config.validate() + + with pytest.raises(AssertionError): + config = LCMConfig(min_value=0) # Should be positive + config.validate() + + with pytest.raises(AssertionError): + config = LCMConfig(min_value=100, max_value=50) # max should be > min + config.validate() + + +def test_lcm_deterministic(): + """Test that dataset generates same items with same seed""" + config = LCMConfig(seed=42, size=10) + dataset1 = LCMDataset(config) + dataset2 = LCMDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_lcm_items(): + """Test basic properties of generated items""" + config = LCMConfig( + min_numbers=2, + max_numbers=4, + min_value=1, + max_value=20, # Keep small for testing + size=50, + seed=42 + ) + dataset = LCMDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify the numbers and result are in metadata + metadata = item["metadata"] + assert "numbers" in metadata + assert "result" in metadata + + # Verify the numbers are within configured range + numbers = metadata["numbers"] + assert all(config.min_value <= n <= config.max_value for n in numbers) + assert config.min_numbers <= len(numbers) <= config.max_numbers + + # Verify the LCM calculation is correct + result = metadata["result"] + assert str(result) == item["answer"] + assert result == reduce(lcm, numbers) + + +def test_lcm_number_ranges(): + """Test that generated numbers respect value constraints""" + config = LCMConfig( + min_numbers=2, + max_numbers=2, + min_value=5, + max_value=15, + size=20, + seed=42 + ) + dataset = LCMDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + numbers = item["metadata"]["numbers"] + assert all(5 <= n <= 15 for n in numbers) + + +def test_lcm_iteration(): + """Test that iteration works correctly""" + config = LCMConfig(size=5, seed=42) + dataset = LCMDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size + + # Test list conversion + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same results + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items + + +def test_lcm_special_cases(): + """Test some special LCM cases""" + config = LCMConfig( + min_numbers=2, + max_numbers=2, + min_value=1, + max_value=20, + size=100, + seed=42 + ) + dataset = LCMDataset(config) + + # Track if we see some interesting LCM cases + seen_equal_to_product = False # When numbers are coprime + seen_less_than_product = False # When numbers share factors + + for i in range(len(dataset)): + item = dataset[i] + numbers = item["metadata"]["numbers"] + result = int(item["answer"]) + product = reduce(lambda x, y: x * y, numbers) + + if result == product: + seen_equal_to_product = True + if result < product: + seen_less_than_product = True + + # With enough samples, we should see both cases + assert seen_equal_to_product, "Expected to see some coprime numbers (LCM = product)" + assert seen_less_than_product, "Expected to see some numbers with common factors (LCM < product)"