mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Merge pull request #144 from joesharratt1229/fix/arithmetic
Added fixes for arithmetic environments
This commit is contained in:
commit
95f179f34e
6 changed files with 44 additions and 7 deletions
|
|
@ -63,7 +63,7 @@ class ChainSumDataset(ProceduralDataset):
|
|||
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||
|
||||
return {
|
||||
"question": f"{expression} =",
|
||||
"question": f"State the final answer to the following arithmetic problem: {expression} =",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"difficulty": {
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
"""Fraction simplification task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction
|
||||
as your final answer."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FractionSimplificationConfig:
|
||||
|
|
@ -107,7 +111,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
answer_fraction = self._format_fraction(simple_num, simple_den, style)
|
||||
|
||||
return {
|
||||
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
|
||||
"question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
|
||||
"answer": answer_fraction,
|
||||
"metadata": {
|
||||
"numerator": num,
|
||||
|
|
@ -119,5 +123,34 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def _extract_fraction(self, answer: Optional[str]):
|
||||
try:
|
||||
cleaned = answer.strip().strip("$").strip()
|
||||
latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE)
|
||||
if latex_match:
|
||||
return int(latex_match.group(1)), int(latex_match.group(2))
|
||||
if "/" in cleaned:
|
||||
numerator, denominator = map(str.strip, cleaned.split("/", 1))
|
||||
return int(numerator), int(denominator)
|
||||
except:
|
||||
return None
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]):
|
||||
reward = 0.0
|
||||
metadata = entry["metadata"]
|
||||
try:
|
||||
numerator, denominator = self._extract_fraction(answer)
|
||||
if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]:
|
||||
reward = 1.0
|
||||
elif numerator == metadata["numerator"] or denominator == metadata["denominator"]:
|
||||
reward = 0.1
|
||||
elif len(answer.strip()) > 0:
|
||||
reward = 0.05
|
||||
else:
|
||||
reward = 0.01
|
||||
except:
|
||||
reward = 0.01
|
||||
return reward
|
||||
|
||||
|
||||
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
||||
|
|
|
|||
|
|
@ -57,7 +57,9 @@ class GCDDataset(ProceduralDataset):
|
|||
numbers_str = ", ".join(str(n) for n in numbers)
|
||||
|
||||
return {
|
||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
||||
"question": f"""Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the
|
||||
GCD as your final answer.
|
||||
""",
|
||||
"answer": str(result),
|
||||
"metadata": {"numbers": numbers, "result": result},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -148,7 +148,9 @@ class GSMSymbolicDataset(ProceduralDataset):
|
|||
rng = Random(self.seed + idx)
|
||||
generator_idx = self.task_indices[idx]
|
||||
generator = self.generators[generator_idx]
|
||||
return generator(rng, self.config.difficulty)
|
||||
example = generator(rng, self.config.difficulty)
|
||||
example["question"] += " Give only the result as your final answer."
|
||||
return example
|
||||
|
||||
|
||||
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class ProductsDataset(ProceduralDataset):
|
|||
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||
|
||||
return {
|
||||
"question": f"{expression} =",
|
||||
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"difficulty": {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue