def _normalize_answer(self, answer: str) -> List[int]:
    """Parse a "×"-separated factor string into a sorted list of integers.

    Whitespace around each factor is stripped, so "2×3" and "2 × 3" parse
    identically, and sorting makes the result independent of factor order.

    Args:
        answer: Factor string such as "2 × 2 × 3".

    Returns:
        The factors as integers in ascending order.

    Raises:
        ValueError: If any piece between separators is not an integer
            literal (including an empty piece).
    """
    return sorted(int(factor.strip()) for factor in answer.split("×"))


# NOTE(review): ``any`` in the annotation below is the builtin function, not
# ``typing.Any``; switch to ``Any`` once it is added to the file's typing import.
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
    """Score a candidate factorization against the entry's reference answer.

    Reward tiers:
        1.0   same multiset of factors (order and whitespace are ignored)
        0.5   different factors whose product equals the reference product,
              e.g. an answer that is not fully factorized
        0.01  any other non-None answer, including one that cannot be parsed
        0.0   ``answer`` is None

    Args:
        answer: Candidate answer, factors separated by "×", or None.
        entry: Dataset item; its "answer" key holds the reference string.

    Returns:
        The reward as a float.

    Raises:
        ValueError: If the reference answer in ``entry`` cannot be parsed.
    """
    if answer is None:
        return 0.0

    # Parsed outside the try block on purpose: a malformed reference answer
    # is a dataset bug and should surface instead of being scored.
    expected_factors = self._normalize_answer(entry["answer"])

    try:
        given_factors = self._normalize_answer(answer)
    except ValueError:
        # A malformed candidate is simply a wrong answer, not a crash.
        return 0.01

    if given_factors == expected_factors:
        return 1.0
    if math.prod(given_factors) == math.prod(expected_factors):
        return 0.5
    return 0.01
def test_prime_factorization_score_answer():
    """Check each reward tier returned by score_answer for a single item."""
    # min_value == max_value forces a specific number (12)
    cfg = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42)
    ds = PrimeFactorizationDataset(cfg)
    entry = ds[0]

    expectations = [
        ("2 × 2 × 3", 1.0),  # perfectly ordered answer
        ("2×2×3", 1.0),  # no white spaces (still correct)
        ("2 × 3 × 2", 1.0),  # shuffled factors (still correct)
        ("2 × 6", 0.5),  # partially correct: not all numbers fully factorized
        ("2 × 5", 0.01),  # incorrect answer
        (None, 0.0),  # answer is None
    ]
    for candidate, expected_reward in expectations:
        assert ds.score_answer(candidate, entry) == expected_reward