diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 05d321da..cc94bf56 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSum, ChainSumConfig from .count_bits import CountBitsConfig, CountBitsDataset +from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig @@ -38,4 +39,6 @@ __all__ = [ "TimeIntervalsDataset", "CountBitsConfig", "CountBitsDataset", + "DiceConfig", + "DiceDataset", ] diff --git a/reasoning_gym/arithmetic/dice.py b/reasoning_gym/arithmetic/dice.py new file mode 100644 index 00000000..f4ad97e9 --- /dev/null +++ b/reasoning_gym/arithmetic/dice.py @@ -0,0 +1,149 @@ +from dataclasses import dataclass +from functools import reduce +from math import gcd +from random import Random +from typing import Dict, Optional + +from ..factory import ProceduralDataset, register_dataset + + +def compute_probability(dice, target): + """ + Computes the probability of rolling a total of at least `target` + when rolling dice specified in the list `dice`. Each element in dice + is the number of sides on that die. The computation is done via dynamic programming. + Returns the probability as a fraction (numerator, denominator) and as a float. + """ + # dp[i][s] = number of ways to get sum s using the first i dice. + # We use only one dictionary for the current dp state. + dp = {0: 1} + for sides in dice: + new_dp = {} + for current_sum, count in dp.items(): + # Each die gives a number from 1 to sides. + for face in range(1, sides + 1): + new_sum = current_sum + face + new_dp[new_sum] = new_dp.get(new_sum, 0) + count + dp = new_dp + + total_outcomes = reduce(lambda a, b: a * b, dice, 1) + ways = sum(count for s, count in dp.items() if s >= target) + + # Simplify the fraction (ways / total_outcomes) + def simplify(n, d): + common = gcd(n, d) + return n // common, d // common + + frac = simplify(ways, total_outcomes) + return frac, ways / total_outcomes + + +def generate_puzzle(num_dice, max_dice_size, rng): + """ + Generates a puzzle: + - It forces one die to have max_dice_size. + - The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1. + - The dice are then shuffled. + - The target total is chosen roughly in the middle (but you can adjust the method). + + It then computes the probability of rolling a total at least the target. + Finally, it prints out the puzzle statement and the answer. + """ + + # Guarantee one die is the maximum. + dice = [max_dice_size] + for _ in range(num_dice - 1): + # Choose a die size randomly from 2 up to max_dice_size-1. + # (If max_dice_size == 2 then all dice are 2-sided.) + if max_dice_size > 2: + die = rng.randint(2, max_dice_size - 1) + else: + die = 2 + dice.append(die) + + # Optionally, sort dice in descending order (as is common in puzzles) + dice.sort(reverse=True) + + # Compute minimum and maximum possible totals. + min_total = num_dice # each die gives at least 1 + max_total = sum(dice) + + # Choose a target total. For an interesting puzzle, + # we choose a target somewhere in the middle third of the range. + low_target = min_total + (max_total - min_total) // 3 + high_target = min_total + 2 * (max_total - min_total) // 3 + target = rng.randint(low_target, high_target) + + # Compute probability. + (num, den), prob = compute_probability(dice, target) + + # Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc. + dice_str = ", ".join(f"1d{s}" for s in dice) + + # Return the puzzle. + return {"dice_str": dice_str, "target": target, "num": num, "den": den} + + +@dataclass +class DiceConfig: + """Configuration for dice puzzle generation""" + + num_dice: int = 4 + max_dice_size: int = 20 + seed: Optional[int] = None + size: int = 500 + + def validate(self): + """Validate configuration parameters""" + assert self.num_dice >= 1, "num_dice must be gte 1" + assert self.max_dice_size >= 2, "max_dice_size must be gte 2" + + +class DiceDataset(ProceduralDataset): + """Generates Dice-based puzzles with configurable parameters""" + + def __init__(self, config: DiceConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single Dice task + + Returns: + dict with keys: + - question: str, the task description + - answer: str, a solution string + - metadata: dict with generation parameters + """ + rng = Random(self.seed + idx) + puzzle = generate_puzzle(self.config.num_dice, self.config.max_dice_size, rng) + puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? (Assume that all dice are rolled at once, and that '1d6' represents one roll of a 6-sided dice.) Please respond with a reduced fraction representing the probability [ex., 1/60]." + answer_str = f"{puzzle['num']}/{puzzle['den']}" + + return { + "question": puzzle_str, + "answer": answer_str, + "metadata": {}, + } + + def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: + """Determine if the solution provided solves the Dice task. + + The function awards 1.0 for a correct answer. + + Args: + answer (Optional[str]): The user's answer. + entry (Dict[str, any]): The original dataset entry containing the correct answer. + + Returns: + float: The computed score between 0.0 and 1.0. + """ + + if answer == None: + return 0.0 + if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""): + return 0.01 + else: + return 1.0 # Yay + + +register_dataset("dice", DiceDataset, DiceConfig) diff --git a/tests/test_dice.py b/tests/test_dice.py new file mode 100644 index 00000000..8a3bd991 --- /dev/null +++ b/tests/test_dice.py @@ -0,0 +1,35 @@ +import pytest + +from reasoning_gym.arithmetic.dice import DiceConfig, DiceDataset + + +def test_dice(): + """Test basic properties and solution of generated items""" + config = DiceConfig(seed=42, size=50, num_dice=8, max_dice_size=24) + dataset = DiceDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + + # Easy + config = DiceConfig(seed=42, size=1, num_dice=1, max_dice_size=2) + dataset = DiceDataset(config) + + for item in dataset: + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + + # Hard + config = DiceConfig(seed=42, size=1, num_dice=40, max_dice_size=40) + dataset = DiceDataset(config) + + for item in dataset: + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=None, entry=item) == 0.0