Merge pull request #144 from joesharratt1229/fix/arithmetic

Added fixes for arithmetic environments
This commit is contained in:
joesharratt1229 2025-02-16 16:34:09 +00:00 committed by GitHub
commit 95f179f34e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 44 additions and 7 deletions

View file

@ -63,7 +63,7 @@ class ChainSumDataset(ProceduralDataset):
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
return {
"question": f"{expression} =",
"question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result),
"metadata": {
"difficulty": {

View file

@ -1,12 +1,16 @@
"""Fraction simplification task generator"""
import re
from dataclasses import dataclass
from math import gcd
from random import Random
from typing import Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction
as your final answer."""
@dataclass
class FractionSimplificationConfig:
@ -107,7 +111,7 @@ class FractionSimplificationDataset(ProceduralDataset):
answer_fraction = self._format_fraction(simple_num, simple_den, style)
return {
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
"question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
"answer": answer_fraction,
"metadata": {
"numerator": num,
@ -119,5 +123,34 @@ class FractionSimplificationDataset(ProceduralDataset):
},
}
def _extract_fraction(self, answer: Optional[str]) -> Optional[Tuple[int, int]]:
    """Parse a model answer into a ``(numerator, denominator)`` pair.

    Two notations are recognised after surrounding whitespace and ``$``
    delimiters are stripped:

    * LaTeX ``\\frac{a}{b}`` / ``\\dfrac{a}{b}`` (case-insensitive,
      non-negative integers only), matched at the start of the answer;
    * plain ``a/b``, split on the first ``/``; each side must be accepted
      by ``int()``.

    Args:
        answer: Raw answer text; may be ``None``.

    Returns:
        ``(numerator, denominator)`` as ints, or ``None`` when the answer
        is missing, is not a string, or cannot be parsed.
    """
    # Guard explicitly instead of relying on an AttributeError being swallowed.
    if not isinstance(answer, str):
        return None

    cleaned = answer.strip().strip("$").strip()

    latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE)
    if latex_match:
        return int(latex_match.group(1)), int(latex_match.group(2))

    if "/" in cleaned:
        # int() tolerates surrounding whitespace, so no per-part strip is needed.
        numerator, denominator = cleaned.split("/", 1)
        try:
            return int(numerator), int(denominator)
        except ValueError:
            # e.g. "a/b" or "1/2/3" -- not a plain integer fraction.
            return None

    return None


def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """Score a model answer against the expected simplified fraction.

    Rewards:
        * ``1.0``  -- parsed fraction equals the simplified numerator and
          denominator stored in ``entry["metadata"]``;
        * ``0.1``  -- parsed numerator OR denominator equals the original
          (unsimplified) value;
        * ``0.05`` -- a fraction was parsed but matches none of the above;
        * ``0.01`` -- the answer is missing or could not be parsed as a
          fraction, or the metadata lacks the expected keys.

    Args:
        answer: Raw answer text; may be ``None``.
        entry: Dataset item; ``entry["metadata"]`` is read for the keys
            ``numerator``, ``denominator``, ``simplified_numerator`` and
            ``simplified_denominator``.

    Returns:
        The reward as a float.
    """
    metadata = entry["metadata"]

    parsed = self._extract_fraction(answer)
    if parsed is None:
        return 0.01
    numerator, denominator = parsed

    try:
        if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]:
            return 1.0
        if numerator == metadata["numerator"] or denominator == metadata["denominator"]:
            return 0.1
    except (KeyError, TypeError):
        # Keep scoring best-effort when metadata is incomplete or malformed,
        # without also swallowing KeyboardInterrupt/SystemExit.
        return 0.01

    # A successfully parsed answer is necessarily non-empty, so every
    # remaining case is "well-formed but wrong".
    return 0.05
# Pass the fraction-simplification dataset class and its config class to
# register_dataset under the name "fraction_simplification".
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)

View file

@ -57,7 +57,9 @@ class GCDDataset(ProceduralDataset):
numbers_str = ", ".join(str(n) for n in numbers)
return {
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
"question": f"""Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the
GCD as your final answer.
""",
"answer": str(result),
"metadata": {"numbers": numbers, "result": result},
}

View file

@ -148,7 +148,9 @@ class GSMSymbolicDataset(ProceduralDataset):
rng = Random(self.seed + idx)
generator_idx = self.task_indices[idx]
generator = self.generators[generator_idx]
return generator(rng, self.config.difficulty)
example = generator(rng, self.config.difficulty)
example["question"] += " Give only the result as your final answer."
return example
# Pass the GSM symbolic dataset class and its config class to
# register_dataset under the name "gsm_symbolic".
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)

View file

@ -57,7 +57,7 @@ class ProductsDataset(ProceduralDataset):
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
return {
"question": f"{expression} =",
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
"answer": str(result),
"metadata": {
"difficulty": {