added fraction simplifications score answer impl

This commit is contained in:
joesharratt1229 2025-02-16 15:08:24 +00:00
parent 1a33fba608
commit 3f731029dd

View file

@ -1,12 +1,16 @@
"""Fraction simplification task generator"""
import re
from dataclasses import dataclass
from math import gcd
from random import Random
from typing import Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction
as your final answer."""
@dataclass
class FractionSimplificationConfig:
@ -107,7 +111,7 @@ class FractionSimplificationDataset(ProceduralDataset):
answer_fraction = self._format_fraction(simple_num, simple_den, style)
return {
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
"question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
"answer": answer_fraction,
"metadata": {
"numerator": num,
@ -119,5 +123,34 @@ class FractionSimplificationDataset(ProceduralDataset):
},
}
def _extract_fraction(self, answer: Optional[str]):
try:
cleaned = answer.strip().strip("$").strip()
latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE)
if latex_match:
return int(latex_match.group(1)), int(latex_match.group(2))
if "/" in cleaned:
numerator, denominator = map(str.strip, cleaned.split("/", 1))
return int(numerator), int(denominator)
except:
return None
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]):
reward = 0.0
metadata = entry["metadata"]
try:
numerator, denominator = self._extract_fraction(answer)
if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]:
reward = 1.0
elif numerator == metadata["numerator"] or denominator == metadata["denominator"]:
reward = 0.1
elif len(answer.strip()) > 0:
reward = 0.05
else:
reward = 0.01
except:
reward = 0.01
return reward
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)