reasoning-gym/reasoning_gym/arithmetic/decimal_arithmetic.py
Rich Jones 46cdfc71cf lint
2025-02-18 19:22:03 +01:00

114 lines
3.6 KiB
Python

from dataclasses import dataclass
from random import Random
from typing import Any, Dict, Literal, Optional
from ..factory import ProceduralDataset, register_dataset
@dataclass
class DecimalArithmeticDatasetConfig:
"""Configuration for decimal arithmetic dataset generation"""
num_decimal_places: int = 6
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.num_decimal_places > 0, "num_decimal_places must be positive"
def generate_arithmetic_problem(rng, num_decimal_places, operations=None):
"""
Generates simple arithmetic problems with decimal numbers formatted to a specific number of decimal places.
Parameters:
rng
num_problems (int): Number of problems to generate
num_decimal_places (int): Number of decimal places for the numbers
operations (list): List of operations to use (default: ['+', '-', '*', '/'])
Returns:
list: List of formatted arithmetic problem strings
"""
if operations is None:
operations = ["+", "-", "*", "/"]
max_integer_part = 10 # Maximum whole number portion before decimal
max_value = max_integer_part * (10**num_decimal_places)
problem = None
# Generate random numbers with exact decimal places
num1 = rng.randint(1, max_value) / (10**num_decimal_places)
num2 = rng.randint(1, max_value) / (10**num_decimal_places)
# Select random operation
op = rng.choice(operations)
# Format numbers to ensure exact decimal places
formatted_num1 = f"{num1:.{num_decimal_places}f}"
formatted_num2 = f"{num2:.{num_decimal_places}f}"
problem = f"{formatted_num1} {op} {formatted_num2} = ?"
return problem
def eval_floordiv(exp: str) -> int:
return eval(exp.replace("/", "//").replace(" = ?", ""))
class DecimalArithmeticDataset(ProceduralDataset):
"""Dataset that generates basic arithmetic tasks with configurable complexity"""
def __init__(self, config: DecimalArithmeticDatasetConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict[str, Any]:
"""Generate a single arithmetic task
Args:
idx: Index of the item to generate
Returns:
dict with keys:
- question: str, the formatted arithmetic expression
- answer: str, the ground truth result
- metadata: dict with generation parameters
"""
# Create deterministic RNG from base seed and idx
rng = Random(self.seed + idx)
decimal_problem = generate_arithmetic_problem(rng, self.config.num_decimal_places)
answer = eval_floordiv(decimal_problem)
return {"question": decimal_problem, "answer": answer, "metadata": {}}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Determine if the solution provided solves the Sokoban task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
try:
if float(answer) == entry["answer"]:
return 1.0
except Exception as e:
return 0.01
return 0.01
# Register the dataset
register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticDatasetConfig)