diff --git a/reasoning_gym/arithmetic/decimal_chain_sum.py b/reasoning_gym/arithmetic/decimal_chain_sum.py index 422c1de6..eaf5c7ea 100644 --- a/reasoning_gym/arithmetic/decimal_chain_sum.py +++ b/reasoning_gym/arithmetic/decimal_chain_sum.py @@ -1,6 +1,7 @@ import random from dataclasses import dataclass -from typing import Optional +from decimal import Decimal +from typing import Any, Dict, Optional from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -73,7 +74,7 @@ class DecimalChainSumDataset(ProceduralDataset): }, } - def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, float]: + def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, Decimal]: """Generate a single decimal chain sum task Args: @@ -85,37 +86,43 @@ class DecimalChainSumDataset(ProceduralDataset): max_decimal_places: Maximum number of decimal places Returns: - Tuple of (expression string, result float) + Tuple of (expression string, result Decimal) """ - if self.config.allow_negation: - # Allow both positive and negative numbers - constants = [rng.randint(-max_value, max_value) for _ in range(num_terms)] - else: - # Only positive numbers - constants = [rng.randint(min_value, max_value) for _ in range(num_terms)] + # Draw each term's integer part and store it as a Decimal + constants = [ + Decimal( + str( + rng.randint(-max_value, max_value) + if self.config.allow_negation + else rng.randint(min_value, max_value) + ) + ) + for _ in range(num_terms) + ] # Generate decimal places for each term decimal_places = [ rng.randint(self.config.min_decimal_places, self.config.max_decimal_places) for _ in range(num_terms) ] + # Add decimal parts using Decimal for precise arithmetic for i in range(num_terms): min_val = 0 if decimal_places[i] == 0 else 10 ** (decimal_places[i] - 1) max_val = (10 ** decimal_places[i]) - 1 - decimal = 
rng.randint(min_val, max_val) - constants[i] += decimal / 10 ** decimal_places[i] + decimal_part = Decimal(str(rng.randint(min_val, max_val))) / Decimal(str(10 ** decimal_places[i])) + constants[i] += decimal_part operators = [rng.choice(["+", "-"]) for _ in range(num_terms - 1)] expression_parts = [] result = constants[0] - expression_parts.append(f"{constants[0]:.{max(decimal_places)}f}") + expression_parts.append(f"{constants[0]:.{decimal_places[0]}f}") for i, op in enumerate(operators): c = constants[i + 1] expression_parts.append(op) - expression_parts.append(f"{c:.{max(decimal_places)}f}") + expression_parts.append(f"{c:.{decimal_places[i+1]}f}") if op == "+": result += c @@ -123,5 +130,25 @@ class DecimalChainSumDataset(ProceduralDataset): result -= c expression = " ".join(expression_parts) - result = round(result, max(decimal_places)) + result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}")) return expression, result + + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + """Score the answer by comparing decimal values instead of strings. 
+ Args: + answer: The answer to score + entry: The entry containing the oracle answer + + Returns: + 1.0 for exact numerical match, 0.0 for a missing or empty answer, 0.01 otherwise + """ + if answer is None or len(answer.strip()) == 0: + return 0.0 + + try: + student_answer = Decimal(answer.strip()) + oracle_answer = Decimal(entry["answer"]) + + return 1.0 if student_answer == oracle_answer else 0.01 + except (ValueError, TypeError, ArithmeticError): + return 0.01 diff --git a/tests/test_decimal_chain_sum.py b/tests/test_decimal_chain_sum.py index be4859d7..5114a7c7 100644 --- a/tests/test_decimal_chain_sum.py +++ b/tests/test_decimal_chain_sum.py @@ -136,8 +136,7 @@ def test_decimal_chain_sum_negation(): for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] - numbers = [int(n) for n in expression.split() if n.isdigit() or (n.startswith("-") and n[1:].isdigit())] - # numbers = [float(n) for n in expression.split() if n.replace(".", "").replace("-", "").isdigit()] + numbers = [float(n) for n in expression.split() if n.replace(".", "").replace("-", "").isdigit()] for num in numbers: if num > 0: has_positive = True @@ -216,3 +215,38 @@ def test_decimal_places_generation(): for num in numbers: decimal_part = num.split(".")[-1] assert 1 <= len(decimal_part) <= 3, f"Number {num} should have between 1 and 3 decimal places" + + +def test_decimal_precision_scoring(): + """Test that scoring handles decimal precision correctly""" + config = DecimalChainSumConfig( + min_terms=2, + max_terms=2, + min_digits=1, + max_digits=2, + min_decimal_places=2, + max_decimal_places=3, + size=1, + seed=42, + ) + dataset = DecimalChainSumDataset(config) + item = dataset[0] + + # Test exact matches with different representations + assert dataset.score_answer("1.200", {"answer": "1.2"}) == 1.0 + assert dataset.score_answer("1.20", {"answer": "1.200"}) == 1.0 + assert dataset.score_answer("-0.5", {"answer": "-0.500"}) == 1.0 + + # Test floating point precision edge cases + assert 
dataset.score_answer("0.1", {"answer": "0.100"}) == 1.0 + assert dataset.score_answer("0.3", {"answer": "0.300"}) == 1.0 + + # Test incorrect answers + assert dataset.score_answer("1.200000001", {"answer": "1.200"}) == 0.01 + assert dataset.score_answer("1.199999999", {"answer": "1.200"}) == 0.01 + + # Test invalid inputs + assert dataset.score_answer(None, {"answer": "1.200"}) == 0.0 + assert dataset.score_answer("", {"answer": "1.200"}) == 0.0 + assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.01 + assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.01