diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 94de4880..cc94bf56 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSum, ChainSumConfig from .count_bits import CountBitsConfig, CountBitsDataset +from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig @@ -14,7 +15,6 @@ from .leg_counting import LegCountingConfig, LegCountingDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset -from .dice import DiceConfig, DiceDataset __all__ = [ "BasicArithmeticDataset", diff --git a/reasoning_gym/arithmetic/dice.py b/reasoning_gym/arithmetic/dice.py index a3203dd9..6fd464b9 100644 --- a/reasoning_gym/arithmetic/dice.py +++ b/reasoning_gym/arithmetic/dice.py @@ -1,8 +1,8 @@ from dataclasses import dataclass +from functools import reduce +from math import gcd from random import Random from typing import Dict, Optional -from math import gcd -from functools import reduce from ..factory import ProceduralDataset, register_dataset @@ -37,6 +37,7 @@ def compute_probability(dice, target): frac = simplify(ways, total_outcomes) return frac, ways / total_outcomes + def generate_puzzle(num_dice, max_dice_size, rng): """ Generates a puzzle: @@ -44,11 +45,11 @@ def generate_puzzle(num_dice, max_dice_size, rng): - The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1. - The dice are then shuffled. - The target total is chosen roughly in the middle (but you can adjust the method). - + It then computes the probability of rolling a total at least the target. Finally, it prints out the puzzle statement and the answer. """ - + # Guarantee one die is the maximum. dice = [max_dice_size] for _ in range(num_dice - 1): @@ -59,33 +60,28 @@ def generate_puzzle(num_dice, max_dice_size, rng): else: die = 2 dice.append(die) - + # Optionally, sort dice in descending order (as is common in puzzles) dice.sort(reverse=True) - + # Compute minimum and maximum possible totals. min_total = num_dice # each die gives at least 1 max_total = sum(dice) - + # Choose a target total. For an interesting puzzle, # we choose a target somewhere in the middle third of the range. low_target = min_total + (max_total - min_total) // 3 high_target = min_total + 2 * (max_total - min_total) // 3 target = rng.randint(low_target, high_target) - + # Compute probability. (num, den), prob = compute_probability(dice, target) - + # Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc. dice_str = ", ".join(f"1d{s}" for s in dice) - + # Return the puzzle. - return { - 'dice_str': dice_str, - 'target': target, - 'num': num, - 'den': den - } + return {"dice_str": dice_str, "target": target, "num": num, "den": den} @dataclass @@ -126,8 +122,7 @@ class DiceDataset(ProceduralDataset): return { "question": puzzle_str, "answer": answer_str, - "metadata": { - }, + "metadata": {}, } def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: