diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 495a79c5..1d2a7ba8 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsDataset +from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumDataset from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset diff --git a/reasoning_gym/arithmetic/decimal_chain_sum.py b/reasoning_gym/arithmetic/decimal_chain_sum.py new file mode 100644 index 00000000..422c1de6 --- /dev/null +++ b/reasoning_gym/arithmetic/decimal_chain_sum.py @@ -0,0 +1,127 @@ +import random +from dataclasses import dataclass +from typing import Optional + +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class DecimalChainSumConfig: + """Configuration for decimal chain sum task generation""" + + min_terms: int = 2 + max_terms: int = 6 + min_digits: int = 1 + max_digits: int = 4 + min_decimal_places: int = 1 + max_decimal_places: int = 4 + allow_negation: bool = False + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.size > 0, "size must be positive" + assert self.min_terms > 0, "min_terms must be positive" + assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" + assert self.min_digits > 0, "min_digits must be positive" + assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" + assert self.min_decimal_places >= 0, "min_decimal_places must be non-negative" + assert self.max_decimal_places >= self.min_decimal_places, "max_decimal_places must be >= min_decimal_places" + + +class DecimalChainSumDataset(ProceduralDataset): + """Generates simple decimal arithmetic tasks using only + and - operators""" + + def __init__(self, config: DecimalChainSumConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single decimal chain sum task + + Args: + idx: Index of the item to generate + + Returns: + dict with keys: + - question: str, the formatted arithmetic expression + - answer: str, the ground truth result + - metadata: dict with generation parameters + """ + + rng = random.Random(self.seed + idx) + + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) + + # Calculate value ranges based on number of digits + min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit + max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits + + expression, result = self._generate_task(rng, num_terms, min_value, max_value) + + return { + "question": f"State the final answer to the following arithmetic problem: {expression} =", + "answer": str(result), + "metadata": { + "difficulty": { + "num_terms": num_terms, + "num_digits": num_digits, + }, + "expression": expression, + }, + } + + def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, float]: + """Generate a single decimal chain sum task + + Args: + rng: Random number generator + num_terms: Number of terms in the expression + min_value: Minimum value for generated numbers + max_value: Maximum value for generated numbers + min_decimal_places: Minimum number of decimal places + max_decimal_places: Maximum number of decimal places + + Returns: + Tuple of (expression string, result float) + """ + + if self.config.allow_negation: + # Allow both positive and negative numbers + constants = [rng.randint(-max_value, max_value) for _ in range(num_terms)] + else: + # Only positive numbers + constants = [rng.randint(min_value, max_value) for _ in range(num_terms)] + + # Generate decimal places for each term + decimal_places = [ + rng.randint(self.config.min_decimal_places, self.config.max_decimal_places) for _ in range(num_terms) + ] + + for i in range(num_terms): + min_val = 0 if decimal_places[i] == 0 else 10 ** (decimal_places[i] - 1) + max_val = (10 ** decimal_places[i]) - 1 + decimal = rng.randint(min_val, max_val) + constants[i] += decimal / 10 ** decimal_places[i] + + operators = [rng.choice(["+", "-"]) for _ in range(num_terms - 1)] + + expression_parts = [] + result = constants[0] + + expression_parts.append(f"{constants[0]:.{max(decimal_places)}f}") + for i, op in enumerate(operators): + c = constants[i + 1] + expression_parts.append(op) + expression_parts.append(f"{c:.{max(decimal_places)}f}") + + if op == "+": + result += c + else: # op == "-" + result -= c + + expression = " ".join(expression_parts) + result = round(result, max(decimal_places)) + return expression, result diff --git a/tests/test_decimal_chain_sum.py b/tests/test_decimal_chain_sum.py new file mode 100644 index 00000000..d9eeb79d --- /dev/null +++ b/tests/test_decimal_chain_sum.py @@ -0,0 +1,217 @@ +import pytest + +from reasoning_gym.arithmetic import DecimalChainSumConfig, DecimalChainSumDataset + + +def test_decimal_chain_sum_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = DecimalChainSumConfig(min_terms=0) + config.validate() + + with pytest.raises(AssertionError): + config = DecimalChainSumConfig(min_terms=3, max_terms=2) + config.validate() + + +def test_decimal_chain_sum_deterministic(): + """Test that dataset generates same items with same seed""" + config = DecimalChainSumConfig(seed=42, size=10) + dataset1 = DecimalChainSumDataset(config) + dataset2 = DecimalChainSumDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_decimal_chain_sum_items(): + """Test basic properties of generated items""" + config = DecimalChainSumConfig( + min_terms=2, + max_terms=4, + min_digits=1, + max_digits=2, + min_decimal_places=1, + max_decimal_places=2, + size=100, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify only + and - are used + expression = item["metadata"]["expression"] + assert all(op in ["+", "-", " "] or op.isdigit() for op in expression) + + # Check for floating point errors + numbers = [n for n in expression.split() if any(c.isdigit() for c in n)] + for num in numbers: + # Verify no numbers have more decimal places than max_decimal_places + if "." in num: + decimal_places = len(num.split(".")[-1]) + assert decimal_places <= config.max_decimal_places, f"Number {num} has more decimal places than allowed" + + # Verify answer has correct precision + answer_str = item["answer"] + if "." in answer_str: + decimal_places = len(answer_str.split(".")[-1]) + assert ( + decimal_places <= config.max_decimal_places + ), f"Answer {answer_str} has more decimal places than allowed" + + # Verify mathematical correctness within epsilon + expected = eval(expression) + assert ( + abs(float(item["answer"]) - expected) < 1e-10 + ), f"Answer {item['answer']} doesn't match expected {expected}" + + +def test_chain_sum_number_ranges(): + """Test that generated numbers respect digit constraints""" + # Test 3-digit numbers + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, # Fix to 2 terms for easier testing + min_digits=3, + max_digits=3, + min_decimal_places=1, + max_decimal_places=4, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" + + # Test 1-digit numbers + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=1, + min_decimal_places=1, + max_decimal_places=4, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" + + +def test_decimal_chain_sum_negation(): + """Test that negation is properly handled""" + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=1, + min_decimal_places=1, + max_decimal_places=4, + allow_negation=True, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + has_positive = False + has_negative = False + + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith("-") and n[1:].isdigit())] + for num in numbers: + if num > 0: + has_positive = True + if num < 0: + has_negative = True + + assert has_positive and has_negative, "Expected both positive and negative numbers with allow_negation=True" + + +def test_decimal_chain_sum_iteration(): + """Test that iteration respects dataset size""" + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=1, + min_decimal_places=1, + max_decimal_places=4, + size=5, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly size items" + + items = list(dataset) + assert len(items) == config.size, "Iterator should yield exactly size items" + + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_decimal_places_generation(): + """Test that generated decimal numbers have correct number of decimal places""" + # Test fixed decimal places + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=2, + min_decimal_places=2, + max_decimal_places=2, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for item in dataset: + expression = item["metadata"]["expression"] + # Extract numbers including decimals + numbers = [n for n in expression.split() if any(c.isdigit() for c in n)] + for num in numbers: + decimal_part = num.split(".")[-1] + assert len(decimal_part) == 2, f"Number {num} should have exactly 2 decimal places" + + # Test varying decimal places + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=2, + min_decimal_places=1, + max_decimal_places=3, + size=50, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + + for item in dataset: + expression = item["metadata"]["expression"] + numbers = [n for n in expression.split() if any(c.isdigit() for c in n)] + for num in numbers: + decimal_part = num.split(".")[-1] + assert 1 <= len(decimal_part) <= 3, f"Number {num} should have between 1 and 3 decimal places"