mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Merge branch 'rich/decimalmath' of github.com:open-thought/reasoning-gym into rich/decimalmath
This commit is contained in:
commit
59229bd2d2
62 changed files with 4012 additions and 478 deletions
|
|
@ -64,6 +64,9 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
|
||||
def __init__(self, config: BasicArithmeticDatasetConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.added_instruction = (
|
||||
" Ensure to report the answer as an integer. Do not add commas to the integer answers reported."
|
||||
)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
"""Generate a single arithmetic task
|
||||
|
|
@ -88,7 +91,7 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
else:
|
||||
expression, result = self._generate_simple_task(rng, num_terms, num_digits)
|
||||
|
||||
question = self._format_question(rng, expression)
|
||||
question = self._format_question(rng, expression) + self.added_instruction
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
|
|
@ -223,12 +226,14 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
return expression, result
|
||||
|
||||
def _format_question(self, rng: Random, expression: str) -> str:
|
||||
"""Format the expression according to config style"""
|
||||
"""Format the the question with the arithmetic expression"""
|
||||
|
||||
if self.config.format_style == "simple":
|
||||
return f"{expression} ="
|
||||
return f"Calculate {expression}."
|
||||
else:
|
||||
templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"]
|
||||
return rng.choice(templates).format(expression)
|
||||
templates = ["What is {0}?", "Solve {0}.", "Compute {0}.", "Evaluate: {0}."]
|
||||
template = rng.choice(templates)
|
||||
return template.format(expression)
|
||||
|
||||
|
||||
# Register the dataset
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class ChainSumDataset(ProceduralDataset):
|
|||
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||
|
||||
return {
|
||||
"question": f"{expression} =",
|
||||
"question": f"State the final answer to the following arithmetic problem: {expression} =",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"difficulty": {
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
"""Fraction simplification task generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
|
||||
|
||||
|
||||
@dataclass
|
||||
class FractionSimplificationConfig:
|
||||
|
|
@ -107,7 +110,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
answer_fraction = self._format_fraction(simple_num, simple_den, style)
|
||||
|
||||
return {
|
||||
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
|
||||
"question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
|
||||
"answer": answer_fraction,
|
||||
"metadata": {
|
||||
"numerator": num,
|
||||
|
|
@ -119,5 +122,34 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def _extract_fraction(self, answer: Optional[str]):
|
||||
try:
|
||||
cleaned = answer.strip().strip("$").strip()
|
||||
latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE)
|
||||
if latex_match:
|
||||
return int(latex_match.group(1)), int(latex_match.group(2))
|
||||
if "/" in cleaned:
|
||||
numerator, denominator = map(str.strip, cleaned.split("/", 1))
|
||||
return int(numerator), int(denominator)
|
||||
except:
|
||||
return None
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """Score a fraction-simplification answer against the entry metadata.

    Reward scheme:

    * 1.0  - parsed fraction equals the simplified numerator/denominator;
    * 0.1  - numerator or denominator equals the original (unsimplified)
      one stored in the metadata;
    * 0.05 - a fraction was parsed but matches neither of the above;
    * 0.01 - the answer is missing or unparseable, or the metadata lacks
      the keys needed for the comparison.

    Args:
        answer: Raw answer text, or None when no answer was given.
        entry: Dataset item; ``entry["metadata"]`` is read for the keys
            ``simplified_numerator``, ``simplified_denominator``,
            ``numerator`` and ``denominator``.

    Returns:
        The reward as a float.
    """
    metadata = entry["metadata"]

    parsed = self._extract_fraction(answer)
    if parsed is None:
        # Missing or unparseable answer.
        return 0.01
    numerator, denominator = parsed

    try:
        if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]:
            return 1.0
        if numerator == metadata["numerator"] or denominator == metadata["denominator"]:
            return 0.1
    except (KeyError, TypeError):
        # Incomplete metadata: keep the minimal reward rather than raising.
        return 0.01

    # A fraction was parsed, so the answer is necessarily non-empty.
    return 0.05
|
||||
|
||||
|
||||
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class GCDDataset(ProceduralDataset):
|
|||
numbers_str = ", ".join(str(n) for n in numbers)
|
||||
|
||||
return {
|
||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.",
|
||||
"answer": str(result),
|
||||
"metadata": {"numbers": numbers, "result": result},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -148,7 +148,9 @@ class GSMSymbolicDataset(ProceduralDataset):
|
|||
rng = Random(self.seed + idx)
|
||||
generator_idx = self.task_indices[idx]
|
||||
generator = self.generators[generator_idx]
|
||||
return generator(rng, self.config.difficulty)
|
||||
example = generator(rng, self.config.difficulty)
|
||||
example["question"] += " Give only the result as your final answer."
|
||||
return example
|
||||
|
||||
|
||||
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||
|
|
|
|||
|
|
@ -54,14 +54,29 @@ ANIMALS = {
|
|||
"woodlouse": 14,
|
||||
}
|
||||
|
||||
QUESTION_TEMPLATE = """Your task is to count how many legs there are in total when given a list of animals.
|
||||
|
||||
Example:
|
||||
- Input: How many legs are there in total if you have 1 duck, 2 deers, 1 spider, 3 cows?
|
||||
- Output: 30
|
||||
- Explanation:
|
||||
- Ducks have 2 legs each, so 1 duck has 2 legs.
|
||||
- Deers have 4 legs each, so 2 deers have 8 legs.
|
||||
- Spiders have 8 legs each, so 1 spider has 8 legs.
|
||||
- Cows have 4 legs each, so 3 cows have 12 legs.
|
||||
- Therefore, the total number of legs is 2 + 8 + 8 + 12 = 30
|
||||
|
||||
Now, how many legs are there in total if you have {animals}?
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegCountingConfig:
|
||||
"""Configuration for leg counting task generation"""
|
||||
|
||||
min_animals: int = 2 # Minimum number of animals in problem
|
||||
max_animals: int = 5 # Maximum number of animals
|
||||
max_instances: int = 3 # Maximum instances of each animal
|
||||
min_animals: int = 3 # Minimum number of animals in problem
|
||||
max_animals: int = 10 # Maximum number of animals
|
||||
max_instances: int = 15 # Maximum instances of each animal
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
|
|
@ -106,10 +121,8 @@ class LegCountingDataset(ProceduralDataset):
|
|||
for animal, count in animals.items():
|
||||
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
|
||||
|
||||
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
|
||||
"answer": str(total_legs),
|
||||
"metadata": {
|
||||
"difficulty": {
|
||||
|
|
|
|||
|
|
@ -7,7 +7,24 @@ from typing import Dict, Optional
|
|||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Compute {base}^{exponent}"""
|
||||
QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
|
||||
|
||||
Example:
|
||||
- Input: Compute 2^3
|
||||
- Output: 8
|
||||
- Explanation:
|
||||
- 2^3 = 2 * 2 * 2 = 8
|
||||
- Therefore, the final answer is 8
|
||||
|
||||
Example:
|
||||
- Input: Compute 412.5^3
|
||||
- Output: 70189453.125
|
||||
- Explanation:
|
||||
- 412.5^3 = 412.5 * 412.5 * 412.5 = 70189453.125
|
||||
- Therefore, the final answer is 70189453.125
|
||||
|
||||
Compute {base}^{exponent}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -32,28 +49,31 @@ class PowerFunctionDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"]
|
||||
reward = 0.0
|
||||
if answer is not None:
|
||||
difference = abs(float(answer) - float(oracle_answer))
|
||||
if difference < 1e-6:
|
||||
reward = 1.0
|
||||
elif difference < 1e-1:
|
||||
reward = 0.5
|
||||
else:
|
||||
reward = 0.01
|
||||
|
||||
return reward
|
||||
try:
|
||||
answer = round(float(answer), 4)
|
||||
oracle_answer = round(float(oracle_answer), 4)
|
||||
difference = abs(float(answer) - float(oracle_answer))
|
||||
if difference < 1e-4:
|
||||
return 1.0
|
||||
elif difference < 1e-1:
|
||||
return 0.5
|
||||
else:
|
||||
return 0.01
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single Power Function question"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
base = rng.uniform(self.config.min_base, self.config.max_base)
|
||||
base = round(rng.uniform(self.config.min_base, self.config.max_base), 4)
|
||||
exponent = rng.randint(self.config.min_exponent, self.config.max_exponent)
|
||||
answer = pow(base, exponent)
|
||||
|
||||
return {
|
||||
"question": f"Compute {base}^{exponent}",
|
||||
"question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
|
||||
"answer": str(answer),
|
||||
"metadata": {"base": base, "exponent": exponent, "solution": answer},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class ProductsDataset(ProceduralDataset):
|
|||
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||
|
||||
return {
|
||||
"question": f"{expression} =",
|
||||
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"difficulty": {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue