Merge branch 'rich/decimalmath' of github.com:open-thought/reasoning-gym into rich/decimalmath

This commit is contained in:
Rich Jones 2025-02-19 03:34:57 +01:00
commit 59229bd2d2
62 changed files with 4012 additions and 478 deletions

View file

@ -64,6 +64,9 @@ class BasicArithmeticDataset(ProceduralDataset):
def __init__(self, config: BasicArithmeticDatasetConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.added_instruction = (
" Ensure to report the answer as an integer. Do not add commas to the integer answers reported."
)
def __getitem__(self, idx: int) -> dict[str, Any]:
"""Generate a single arithmetic task
@ -88,7 +91,7 @@ class BasicArithmeticDataset(ProceduralDataset):
else:
expression, result = self._generate_simple_task(rng, num_terms, num_digits)
question = self._format_question(rng, expression)
question = self._format_question(rng, expression) + self.added_instruction
return {
"question": question,
@ -223,12 +226,14 @@ class BasicArithmeticDataset(ProceduralDataset):
return expression, result
def _format_question(self, rng: Random, expression: str) -> str:
"""Format the expression according to config style"""
"""Format the the question with the arithmetic expression"""
if self.config.format_style == "simple":
return f"{expression} ="
return f"Calculate {expression}."
else:
templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"]
return rng.choice(templates).format(expression)
templates = ["What is {0}?", "Solve {0}.", "Compute {0}.", "Evaluate: {0}."]
template = rng.choice(templates)
return template.format(expression)
# Register the dataset

View file

@ -63,7 +63,7 @@ class ChainSumDataset(ProceduralDataset):
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
return {
"question": f"{expression} =",
"question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result),
"metadata": {
"difficulty": {

View file

@ -1,12 +1,15 @@
"""Fraction simplification task generator"""
import re
from dataclasses import dataclass
from math import gcd
from random import Random
from typing import Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
@dataclass
class FractionSimplificationConfig:
@ -107,7 +110,7 @@ class FractionSimplificationDataset(ProceduralDataset):
answer_fraction = self._format_fraction(simple_num, simple_den, style)
return {
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
"question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
"answer": answer_fraction,
"metadata": {
"numerator": num,
@ -119,5 +122,34 @@ class FractionSimplificationDataset(ProceduralDataset):
},
}
def _extract_fraction(self, answer: Optional[str]):
try:
cleaned = answer.strip().strip("$").strip()
latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE)
if latex_match:
return int(latex_match.group(1)), int(latex_match.group(2))
if "/" in cleaned:
numerator, denominator = map(str.strip, cleaned.split("/", 1))
return int(numerator), int(denominator)
except:
return None
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """Score a model answer against the fraction stored in ``entry["metadata"]``.

    Args:
        answer: Raw answer text, or ``None`` when no answer was given.
        entry: Dataset item; its ``"metadata"`` mapping is read for the
            ``simplified_numerator``, ``simplified_denominator``,
            ``numerator`` and ``denominator`` keys.

    Returns:
        1.0 when the parsed fraction equals the simplified fraction;
        0.1 when its numerator or its denominator equals that of the
        original fraction; 0.05 for any other parsed fraction; 0.01 when
        no fraction could be parsed or the metadata cannot be read.

    Raises:
        KeyError: If ``entry`` has no ``"metadata"`` key.
    """
    metadata = entry["metadata"]

    parsed = self._extract_fraction(answer)
    if parsed is None:
        # Nothing parseable: minimal reward, no exception-driven control flow.
        return 0.01
    numerator, denominator = parsed

    try:
        if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]:
            return 1.0
        if numerator == metadata["numerator"] or denominator == metadata["denominator"]:
            return 0.1
    except (KeyError, TypeError):
        # Missing or malformed metadata is scored like an unparseable answer.
        return 0.01

    # A fraction was parsed (so the answer is non-empty) but it is wrong.
    return 0.05
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)

View file

@ -57,7 +57,7 @@ class GCDDataset(ProceduralDataset):
numbers_str = ", ".join(str(n) for n in numbers)
return {
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.",
"answer": str(result),
"metadata": {"numbers": numbers, "result": result},
}

View file

@ -148,7 +148,9 @@ class GSMSymbolicDataset(ProceduralDataset):
rng = Random(self.seed + idx)
generator_idx = self.task_indices[idx]
generator = self.generators[generator_idx]
return generator(rng, self.config.difficulty)
example = generator(rng, self.config.difficulty)
example["question"] += " Give only the result as your final answer."
return example
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)

View file

@ -54,14 +54,29 @@ ANIMALS = {
"woodlouse": 14,
}
QUESTION_TEMPLATE = """Your task is to count how many legs there are in total when given a list of animals.
Example:
- Input: How many legs are there in total if you have 1 duck, 2 deers, 1 spider, 3 cows?
- Output: 30
- Explanation:
- Ducks have 2 legs each, so 1 duck has 2 legs.
- Deers have 4 legs each, so 2 deers have 8 legs.
- Spiders have 8 legs each, so 1 spider has 8 legs.
- Cows have 4 legs each, so 3 cows have 12 legs.
- Therefore, the total number of legs is 2 + 8 + 8 + 12 = 30
Now, how many legs are there in total if you have {animals}?
"""
@dataclass
class LegCountingConfig:
"""Configuration for leg counting task generation"""
min_animals: int = 2 # Minimum number of animals in problem
max_animals: int = 5 # Maximum number of animals
max_instances: int = 3 # Maximum instances of each animal
min_animals: int = 3 # Minimum number of animals in problem
max_animals: int = 10 # Maximum number of animals
max_instances: int = 15 # Maximum instances of each animal
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
@ -106,10 +121,8 @@ class LegCountingDataset(ProceduralDataset):
for animal, count in animals.items():
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
return {
"question": question,
"question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
"answer": str(total_legs),
"metadata": {
"difficulty": {

View file

@ -7,7 +7,24 @@ from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Compute {base}^{exponent}"""
QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
Example:
- Input: Compute 2^3
- Output: 8
- Explanation:
- 2^3 = 2 * 2 * 2 = 8
- Therefore, the final answer is 8
Example:
- Input: Compute 412.5^3
- Output: 70189453.125
- Explanation:
- 412.5^3 = 412.5 * 412.5 * 412.5 = 70189453.125
- Therefore, the final answer is 70189453.125
Compute {base}^{exponent}
"""
@dataclass
@ -32,28 +49,31 @@ class PowerFunctionDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
reward = 0.0
if answer is not None:
difference = abs(float(answer) - float(oracle_answer))
if difference < 1e-6:
reward = 1.0
elif difference < 1e-1:
reward = 0.5
else:
reward = 0.01
return reward
try:
answer = round(float(answer), 4)
oracle_answer = round(float(oracle_answer), 4)
difference = abs(float(answer) - float(oracle_answer))
if difference < 1e-4:
return 1.0
elif difference < 1e-1:
return 0.5
else:
return 0.01
except Exception as e:
return 0.01
return 0.0
def __getitem__(self, idx: int) -> dict:
"""Generate a single Power Function question"""
rng = Random(self.seed + idx)
base = rng.uniform(self.config.min_base, self.config.max_base)
base = round(rng.uniform(self.config.min_base, self.config.max_base), 4)
exponent = rng.randint(self.config.min_exponent, self.config.max_exponent)
answer = pow(base, exponent)
return {
"question": f"Compute {base}^{exponent}",
"question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
"answer": str(answer),
"metadata": {"base": base, "exponent": exponent, "solution": answer},
}

View file

@ -57,7 +57,7 @@ class ProductsDataset(ProceduralDataset):
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
return {
"question": f"{expression} =",
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
"answer": str(result),
"metadata": {
"difficulty": {