diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py
index a4766ebc..453601fa 100644
--- a/reasoning_gym/arithmetic/fraction_simplification.py
+++ b/reasoning_gym/arithmetic/fraction_simplification.py
@@ -1,12 +1,18 @@
 """Fraction simplification task generator"""
 
+import re
 from dataclasses import dataclass
 from math import gcd
 from random import Random
-from typing import Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 from ..factory import ProceduralDataset, register_dataset
 
+QUESTION_TEMPLATE = (
+    "Simplify the fraction {question_fraction} to its lowest terms. "
+    "Give only the simplified fraction as your final answer."
+)
+
 
 @dataclass
 class FractionSimplificationConfig:
@@ -107,7 +113,7 @@ class FractionSimplificationDataset(ProceduralDataset):
         answer_fraction = self._format_fraction(simple_num, simple_den, style)
 
         return {
-            "question": f"Simplify the fraction {question_fraction} to its lowest terms",
+            "question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
             "answer": answer_fraction,
             "metadata": {
                 "numerator": num,
@@ -119,5 +125,45 @@ class FractionSimplificationDataset(ProceduralDataset):
             },
         }
 
+    def _extract_fraction(self, answer: Optional[str]) -> Optional[Tuple[int, int]]:
+        """Parse ``answer`` into a ``(numerator, denominator)`` pair.
+
+        Accepts plain ``a/b`` and LaTeX ``frac``/``dfrac`` notation, optionally
+        wrapped in ``$`` delimiters. Returns ``None`` when the answer is missing
+        or is not a recognisable fraction of integers.
+        """
+        if not isinstance(answer, str):
+            return None
+        cleaned = answer.strip().strip("$").strip()
+        latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE)
+        if latex_match:
+            return int(latex_match.group(1)), int(latex_match.group(2))
+        if "/" not in cleaned:
+            return None
+        numerator, denominator = map(str.strip, cleaned.split("/", 1))
+        try:
+            return int(numerator), int(denominator)
+        except ValueError:
+            # Non-integer parts such as "a/b" or "1/2/3".
+            return None
+
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
+        """Score ``answer`` against the expected simplified fraction.
+
+        Returns 1.0 for the fully simplified fraction, 0.1 when the numerator
+        or denominator still matches the unsimplified question, 0.05 for any
+        other parseable fraction and 0.01 when no fraction can be parsed.
+        """
+        metadata = entry["metadata"]
+        parsed = self._extract_fraction(answer)
+        if parsed is None:
+            return 0.01
+        numerator, denominator = parsed
+        if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]:
+            return 1.0
+        if numerator == metadata["numerator"] or denominator == metadata["denominator"]:
+            return 0.1
+        return 0.05
+
 
 register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)