diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py
index 4c10e3b2..cbfee4b1 100644
--- a/reasoning_gym/arithmetic/__init__.py
+++ b/reasoning_gym/arithmetic/__init__.py
@@ -14,6 +14,7 @@ from .lcm import LCMConfig, LCMDataset
 from .leg_counting import LegCountingConfig, LegCountingDataset
 from .power_function import PowerFunctionConfig, PowerFunctionDataset
 from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
+from .products import Products, ProductsConfig
 from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
 
 __all__ = [
@@ -35,6 +36,8 @@ __all__ = [
     "PowerFunctionDataset",
     "PrimeFactorizationConfig",
     "PrimeFactorizationDataset",
+    "Products",
+    "ProductsConfig",
     "GSMSymbolicDatasetConfig",
     "GSMSymbolicDataset",
     "TimeIntervalsConfig",
diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py
new file mode 100644
--- /dev/null
+++ b/reasoning_gym/arithmetic/products.py
@@ -0,0 +1,129 @@
+import math
+import random
+from dataclasses import dataclass
+from typing import Optional
+
+from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
+from ..factory import ProceduralDataset, register_dataset
+
+
+@dataclass
+class ProductsConfig:
+    """Configuration for products task generation"""
+
+    min_terms: int = 2  # fewest operands per expression
+    max_terms: int = 2  # most operands per expression
+    min_digits: int = 1  # fewest digits per operand
+    max_digits: int = 5  # most digits per operand
+    seed: Optional[int] = None
+    size: int = 500
+
+    def validate(self) -> None:
+        """Validate configuration parameters
+
+        Raises:
+            AssertionError: if a bound is not positive or a range is inverted
+        """
+        assert self.size > 0, "size must be positive"
+        assert self.min_terms > 0, "min_terms must be positive"
+        assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
+        assert self.min_digits > 0, "min_digits must be positive"
+        assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"
+
+
+class Products(ProceduralDataset):
+    """Generates multiplication tasks with configurable number of terms"""
+
+    def __init__(self, config: ProductsConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single multiplication task
+
+        Args:
+            idx: Index of the item to generate
+
+        Returns:
+            dict with keys:
+                - question: str, the formatted multiplication expression
+                - answer: str, the ground truth result
+                - metadata: dict with generation parameters
+        """
+        # Create deterministic RNG from base seed and idx
+        item_rng = random.Random(self.seed + idx)
+
+        # Draw order (terms, digits, then operands) is part of the seeded output; keep it stable
+        num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
+        num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
+
+        # Calculate value ranges based on number of digits
+        min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1)  # Special case for 1 digit
+        max_value = (10**num_digits) - 1  # e.g., 999 for 3 digits
+
+        expression, result = self._generate_task(item_rng, num_terms, min_value, max_value)
+
+        return {
+            "question": f"{expression} =",
+            "answer": str(result),
+            "metadata": {
+                "difficulty": {
+                    "num_terms": num_terms,
+                    "num_digits": num_digits,
+                },
+                "expression": expression,
+            },
+        }
+
+    def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
+        """Generate a multiplication task
+
+        Args:
+            rng: Random number generator
+            num_terms: Number of terms in the expression
+            min_value: Minimum value for generated numbers
+            max_value: Maximum value for generated numbers
+
+        Returns:
+            Tuple of (expression string, result integer)
+        """
+        # Generate random numbers within the specified range
+        constants = [rng.randint(min_value, max_value) for _ in range(num_terms)]
+
+        # Operands joined by " * ", e.g. "12 * 7 * 30"
+        expression = " * ".join(map(str, constants))
+        return expression, math.prod(constants)
+
+
+class ProductsCurriculum(BaseCurriculum):
+    """Curriculum exposing num_terms and num_digits range attributes for ProductsConfig"""
+
+    def __init__(self):
+        super().__init__(ProductsCurriculum.__name__, ProductsConfig)
+
+        # Define attributes
+        self._define_attributes(
+            RangeAttributeDefinition(
+                name="num_terms",
+                levels=[2, 3, 4, 5],
+                default_level=0,  # Start with 2 terms
+                description="Maximum number of terms in the expression",
+                attr_type=AttributeType.APPEND,
+                min_value=2,  # Ensure at least 2 terms
+                lower_field_name="min_terms",
+                upper_field_name="max_terms",
+            ),
+            RangeAttributeDefinition(
+                name="num_digits",
+                levels=[1, 2, 3, 4],
+                default_level=0,  # Start with 1-digit numbers
+                description="Number of digits in each operand",
+                attr_type=AttributeType.APPEND,
+                min_value=1,  # Ensure numbers are at least 1 digit
+                lower_field_name="min_digits",
+                upper_field_name="max_digits",
+            ),
+        )
+
+
+# Register the dataset
+register_dataset("products", Products, ProductsConfig)
diff --git a/tests/test_products.py b/tests/test_products.py
new file mode 100644
index 00000000..a569209c
--- /dev/null
+++ b/tests/test_products.py
@@ -0,0 +1,125 @@
+import pytest
+
+from reasoning_gym.arithmetic import Products, ProductsConfig
+from reasoning_gym.arithmetic.products import ProductsCurriculum
+
+
+def test_products_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+    with pytest.raises(AssertionError):
+        config = ProductsConfig(min_terms=0)
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = ProductsConfig(min_terms=3, max_terms=2)
+        config.validate()
+
+
+def test_products_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = ProductsConfig(seed=42, size=10)
+    dataset1 = Products(config)
+    dataset2 = Products(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_products_items():
+    """Test basic properties of generated items"""
+    config = ProductsConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
+    dataset = Products(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Verify only * is used
+        expression = item["metadata"]["expression"]
+        assert all(op in ["*", " "] or op.isdigit() for op in expression)
+
+        # Verify the answer matches the expression
+        answer = eval(expression)  # Safe here as we control the expression
+        assert str(answer) == item["answer"]
+
+
+def test_products_number_ranges():
+    """Test that generated numbers respect digit constraints"""
+    # Test 3-digit numbers
+    config = ProductsConfig(
+        min_terms=2,
+        max_terms=2,  # Fix to 2 terms for easier testing
+        min_digits=3,  # Should generate numbers >= 100
+        max_digits=3,  # Should generate numbers <= 999
+        size=50,
+        seed=42,
+    )
+    dataset = Products(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        expression = item["metadata"]["expression"]
+        numbers = [int(n) for n in expression.split() if n.isdigit()]
+        for num in numbers:
+            assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
+
+    # Test 1-digit numbers
+    config = ProductsConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
+    dataset = Products(config)
+    for i in range(len(dataset)):
+        item = dataset[i]
+        expression = item["metadata"]["expression"]
+        numbers = [int(n) for n in expression.split() if n.isdigit()]
+        for num in numbers:
+            assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
+
+
+def test_products_iteration():
+    """Test that iteration respects dataset size"""
+    config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42)  # Small size for testing
+    dataset = Products(config)
+
+    # Test manual iteration
+    items = []
+    for item in dataset:
+        items.append(item)
+    assert len(items) == config.size, "Iterator should yield exactly size items"
+
+    # Test list conversion
+    items = list(dataset)
+    assert len(items) == config.size, "Iterator should yield exactly size items"
+
+    # Test multiple iterations
+    first_items = list(dataset)
+    second_items = list(dataset)
+    assert first_items == second_items, "Multiple iterations should yield same items"
+
+
+def test_products_curriculum():
+    curriculum = ProductsCurriculum()
+
+    base_value = {"size": 150, "seed": 1}
+
+    base_cfg: ProductsConfig = curriculum.generate_configuration(base_value)
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1
+    assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2
+
+    # test incrementing attribute levels for num_terms & num_digits attributes
+    curriculum.increment_attr_level("num_terms")
+    curriculum.increment_attr_level("num_digits")
+
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2
+    assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3
+
+    # test decrementing attribute level for num_digits again
+    curriculum.decrement_attr_level("num_digits")
+
+    partially_decreased_cfg = curriculum.generate_configuration(base_value)
+    assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1
+    assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3