mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
decimal math
This commit is contained in:
parent
f8f25f45a1
commit
80d8dbcc36
2 changed files with 163 additions and 0 deletions
119
reasoning_gym/arithmetic/decimal_arithmetic.py
Normal file
119
reasoning_gym/arithmetic/decimal_arithmetic.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Literal, Optional, Dict
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecimalArithmeticDatasetConfig:
|
||||
"""Configuration for decimal arithmetic dataset generation"""
|
||||
|
||||
num_decimal_places: int = 6
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.num_decimal_places > 0, "num_decimal_places must be positive"
|
||||
|
||||
|
||||
def generate_arithmetic_problem(rng, num_decimal_places, operations=None):
|
||||
"""
|
||||
Generates simple arithmetic problems with decimal numbers formatted to a specific number of decimal places.
|
||||
|
||||
Parameters:
|
||||
rng
|
||||
num_problems (int): Number of problems to generate
|
||||
num_decimal_places (int): Number of decimal places for the numbers
|
||||
operations (list): List of operations to use (default: ['+', '-', '*', '/'])
|
||||
|
||||
Returns:
|
||||
list: List of formatted arithmetic problem strings
|
||||
"""
|
||||
if operations is None:
|
||||
operations = ['+', '-', '*', '/']
|
||||
|
||||
max_integer_part = 10 # Maximum whole number portion before decimal
|
||||
max_value = max_integer_part * (10 ** num_decimal_places)
|
||||
|
||||
problem = None
|
||||
|
||||
# Generate random numbers with exact decimal places
|
||||
num1 = rng.randint(1, max_value) / (10 ** num_decimal_places)
|
||||
num2 = rng.randint(1, max_value) / (10 ** num_decimal_places)
|
||||
|
||||
# Select random operation
|
||||
op = rng.choice(operations)
|
||||
|
||||
# Format numbers to ensure exact decimal places
|
||||
formatted_num1 = f"{num1:.{num_decimal_places}f}"
|
||||
formatted_num2 = f"{num2:.{num_decimal_places}f}"
|
||||
|
||||
problem = f"{formatted_num1} {op} {formatted_num2} = ?"
|
||||
|
||||
return problem
|
||||
|
||||
|
||||
def eval_floordiv(exp: str) -> int:
|
||||
return eval(exp.replace("/", "//").replace(" = ?", ''))
|
||||
|
||||
|
||||
class DecimalArithmeticDataset(ProceduralDataset):
|
||||
"""Dataset that generates basic arithmetic tasks with configurable complexity"""
|
||||
|
||||
def __init__(self, config: DecimalArithmeticDatasetConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
"""Generate a single arithmetic task
|
||||
|
||||
Args:
|
||||
idx: Index of the item to generate
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- question: str, the formatted arithmetic expression
|
||||
- answer: str, the ground truth result
|
||||
- metadata: dict with generation parameters
|
||||
"""
|
||||
# Create deterministic RNG from base seed and idx
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
decimal_problem = generate_arithmetic_problem(rng, self.config.num_decimal_places)
|
||||
answer = eval_floordiv(decimal_problem)
|
||||
|
||||
return {
|
||||
"question": decimal_problem,
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
"""Determine if the solution provided solves the Sokoban task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
if float(answer) == entry['answer']:
|
||||
return 1.0
|
||||
except Exception as e:
|
||||
return 0.01
|
||||
|
||||
return 0.01
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticDatasetConfig)
|
||||
Loading…
Add table
Add a link
Reference in a new issue