diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 6d2e43e6..05c19779 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -63,7 +63,7 @@ class ChainSumDataset(ProceduralDataset): expression, result = self._generate_task(rng, num_terms, min_value, max_value) return { - "question": f"{expression} =", + "question": f"State the final answer to the following arithmetic problem: {expression} =", "answer": str(result), "metadata": { "difficulty": { diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index a4766ebc..453601fa 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -1,12 +1,16 @@ """Fraction simplification task generator""" +import re from dataclasses import dataclass from math import gcd from random import Random -from typing import Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple from ..factory import ProceduralDataset, register_dataset +QUESTION_TEMPLATE = """Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction + as your final answer.""" + @dataclass class FractionSimplificationConfig: @@ -107,7 +111,7 @@ class FractionSimplificationDataset(ProceduralDataset): answer_fraction = self._format_fraction(simple_num, simple_den, style) return { - "question": f"Simplify the fraction {question_fraction} to its lowest terms", + "question": QUESTION_TEMPLATE.format(question_fraction=question_fraction), "answer": answer_fraction, "metadata": { "numerator": num, @@ -119,5 +123,34 @@ class FractionSimplificationDataset(ProceduralDataset): }, } + def _extract_fraction(self, answer: Optional[str]): + try: + cleaned = answer.strip().strip("$").strip() + latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE) + if latex_match: + return int(latex_match.group(1)), int(latex_match.group(2)) + if "/" in cleaned: + numerator, denominator = map(str.strip, cleaned.split("/", 1)) + return int(numerator), int(denominator) + except: + return None + + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]): + reward = 0.0 + metadata = entry["metadata"] + try: + numerator, denominator = self._extract_fraction(answer) + if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]: + reward = 1.0 + elif numerator == metadata["numerator"] or denominator == metadata["denominator"]: + reward = 0.1 + elif len(answer.strip()) > 0: + reward = 0.05 + else: + reward = 0.01 + except: + reward = 0.01 + return reward + register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig) diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index ce30a127..57c36296 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -57,7 +57,9 @@ class GCDDataset(ProceduralDataset): numbers_str = ", ".join(str(n) for n in numbers) return { - "question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}", + "question": f"""Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the + GCD as your final answer. + """, "answer": str(result), "metadata": {"numbers": numbers, "result": result}, } diff --git a/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py b/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py index 2ac45d9b..13b0b372 100644 --- a/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py +++ b/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py @@ -148,7 +148,9 @@ class GSMSymbolicDataset(ProceduralDataset): rng = Random(self.seed + idx) generator_idx = self.task_indices[idx] generator = self.generators[generator_idx] - return generator(rng, self.config.difficulty) + example = generator(rng, self.config.difficulty) + example["question"] += " Give only the result as your final answer." + return example register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig) diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index 742696ec..3a25077f 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -57,7 +57,7 @@ class ProductsDataset(ProceduralDataset): expression, result = self._generate_task(rng, num_terms, min_value, max_value) return { - "question": f"{expression} =", + "question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.", "answer": str(result), "metadata": { "difficulty": { diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index 7130a11d..00448280 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -187,7 +187,7 @@ class AliceInWonderlandDataset(ProceduralDataset): num_female_colleagues_bob_circle=num_female_colleagues_bob_circle, ) - return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}} + return {"question": question, "answer": str(answer), "metadata": {"task_type": task_type.value}} def __getitem__(self, idx: int) -> dict: rng = Random(self.seed + idx)