diff --git a/reasoning_gym/arithmetic/decimal_arithmetic.py b/reasoning_gym/arithmetic/decimal_arithmetic.py index 32335a50..da84f28e 100644 --- a/reasoning_gym/arithmetic/decimal_arithmetic.py +++ b/reasoning_gym/arithmetic/decimal_arithmetic.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from random import Random -from typing import Any, Literal, Optional, Dict +from typing import Any, Dict, Literal, Optional from ..factory import ProceduralDataset, register_dataset @@ -21,42 +21,42 @@ class DecimalArithmeticDatasetConfig: def generate_arithmetic_problem(rng, num_decimal_places, operations=None): """ Generates simple arithmetic problems with decimal numbers formatted to a specific number of decimal places. - + Parameters: rng num_problems (int): Number of problems to generate num_decimal_places (int): Number of decimal places for the numbers operations (list): List of operations to use (default: ['+', '-', '*', '/']) - + Returns: list: List of formatted arithmetic problem strings """ if operations is None: - operations = ['+', '-', '*', '/'] - + operations = ["+", "-", "*", "/"] + max_integer_part = 10 # Maximum whole number portion before decimal - max_value = max_integer_part * (10 ** num_decimal_places) - + max_value = max_integer_part * (10**num_decimal_places) + problem = None # Generate random numbers with exact decimal places - num1 = rng.randint(1, max_value) / (10 ** num_decimal_places) - num2 = rng.randint(1, max_value) / (10 ** num_decimal_places) - + num1 = rng.randint(1, max_value) / (10**num_decimal_places) + num2 = rng.randint(1, max_value) / (10**num_decimal_places) + # Select random operation op = rng.choice(operations) - + # Format numbers to ensure exact decimal places formatted_num1 = f"{num1:.{num_decimal_places}f}" formatted_num2 = f"{num2:.{num_decimal_places}f}" - + problem = f"{formatted_num1} {op} {formatted_num2} = ?" - + return problem def eval_floordiv(exp: str) -> int: - return eval(exp.replace("/", "//").replace(" = ?", '')) + return eval(exp.replace("/", "//").replace(" = ?", "")) class DecimalArithmeticDataset(ProceduralDataset): @@ -83,13 +83,7 @@ class DecimalArithmeticDataset(ProceduralDataset): decimal_problem = generate_arithmetic_problem(rng, self.config.num_decimal_places) answer = eval_floordiv(decimal_problem) - return { - "question": decimal_problem, - "answer": answer, - "metadata": { - - } - } + return {"question": decimal_problem, "answer": answer, "metadata": {}} def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: """Determine if the solution provided solves the Sokoban task. @@ -108,12 +102,13 @@ class DecimalArithmeticDataset(ProceduralDataset): return 0.0 try: - if float(answer) == entry['answer']: + if float(answer) == entry["answer"]: return 1.0 except Exception as e: return 0.01 return 0.01 + # Register the dataset register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticDatasetConfig) diff --git a/reasoning_gym/code/bf.py b/reasoning_gym/code/bf.py index c2697203..58b55f23 100644 --- a/reasoning_gym/code/bf.py +++ b/reasoning_gym/code/bf.py @@ -28,7 +28,7 @@ class BFDataset(ProceduralDataset): def __init__(self, config: BFConfig): self._prompt_templates = [ - "This is a BF (Brainf*ck) computer program. What is the output? \n\n{bf_program}", + "This is a BF (Brainf*ck) computer program. What is the output? Reply only with the program output, ex: 42. \n\n{bf_program}", ] super().__init__(config=config, seed=config.seed, size=config.size) diff --git a/tests/test_decimal_arithmetic.py b/tests/test_decimal_arithmetic.py index c499105c..5f303527 100644 --- a/tests/test_decimal_arithmetic.py +++ b/tests/test_decimal_arithmetic.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arithmetic.decimal_arithmetic import DecimalArithmeticDatasetConfig, DecimalArithmeticDataset +from reasoning_gym.arithmetic.decimal_arithmetic import DecimalArithmeticDataset, DecimalArithmeticDatasetConfig def test_decimal_arithmetic(): @@ -41,4 +41,4 @@ def test_decimal_arithmetic(): assert "answer" in item assert "metadata" in item - assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 \ No newline at end of file + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0