reasoning-gym/reasoning_gym/arithmetic/dice.py
2025-03-11 00:44:48 +01:00

183 lines
6.4 KiB
Python

from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
def compute_probability(dice, target):
"""
Computes the probability of rolling a total of at least `target`
when rolling dice specified in the list `dice`. Each element in dice
is the number of sides on that die. The computation is done via dynamic programming.
Returns the probability as a fraction (numerator, denominator) and as a float.
"""
# dp[i][s] = number of ways to get sum s using the first i dice.
# We use only one dictionary for the current dp state.
dp = {0: 1}
for sides in dice:
new_dp = {}
for current_sum, count in dp.items():
# Each die gives a number from 1 to sides.
for face in range(1, sides + 1):
new_sum = current_sum + face
new_dp[new_sum] = new_dp.get(new_sum, 0) + count
dp = new_dp
total_outcomes = reduce(lambda a, b: a * b, dice, 1)
ways = sum(count for s, count in dp.items() if s >= target)
# Simplify the fraction (ways / total_outcomes)
def simplify(n, d):
common = gcd(n, d)
return n // common, d // common
frac = simplify(ways, total_outcomes)
return frac, ways / total_outcomes
def generate_puzzle(num_dice, max_dice_size, rng):
"""
Generates a puzzle:
- It forces one die to have max_dice_size.
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
- The dice are then shuffled.
- The target total is chosen roughly in the middle (but you can adjust the method).
It then computes the probability of rolling a total at least the target.
Finally, it prints out the puzzle statement and the answer.
"""
# Guarantee one die is the maximum.
dice = [max_dice_size]
for _ in range(num_dice - 1):
# Choose a die size randomly from 2 up to max_dice_size-1.
# (If max_dice_size == 2 then all dice are 2-sided.)
if max_dice_size > 2:
die = rng.randint(2, max_dice_size - 1)
else:
die = 2
dice.append(die)
# Optionally, sort dice in descending order (as is common in puzzles)
dice.sort(reverse=True)
# Compute minimum and maximum possible totals.
min_total = num_dice # each die gives at least 1
max_total = sum(dice)
# Choose a target total. For an interesting puzzle,
# we choose a target somewhere in the middle third of the range.
low_target = min_total + (max_total - min_total) // 3
high_target = min_total + 2 * (max_total - min_total) // 3
target = rng.randint(low_target, high_target)
# Compute probability. Don't change this.
(num, den), probability = compute_probability(dice, target)
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
dice_str = ", ".join(f"1d{s}" for s in dice)
# Return the puzzle.
return {"dice_str": dice_str, "target": target, "num": num, "den": den, "probability": probability}
@dataclass
class DiceConfig:
"""Configuration for dice puzzle generation"""
num_dice: int = 4
max_dice_size: int = 20
seed: Optional[int] = None
size: int = 500
def validate(self):
"""Validate configuration parameters"""
assert self.num_dice >= 1, "num_dice must be gte 1"
assert self.max_dice_size >= 2, "max_dice_size must be gte 2"
class DiceDataset(ProceduralDataset):
"""Generates Dice-based puzzles with configurable parameters"""
def __init__(self, config: DiceConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single Dice task
Returns:
dict with keys:
- question: str, the task description
- answer: str, a solution string
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)
puzzle = generate_puzzle(self.config.num_dice, self.config.max_dice_size, rng)
puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? (Assume that all dice are rolled at once, and that '1d6' represents one roll of a 6-sided dice.) Please respond with a reduced fraction representing the probability [ex., 1/60]."
answer_str = f"{puzzle['num']}/{puzzle['den']}"
return {
"question": puzzle_str,
"answer": answer_str,
"metadata": {
"difficulty": {
"num_dice": self.config.num_dice,
"max_dice_size": self.config.max_dice_size,
},
"puzzle": puzzle,
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the Dice task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (dict[str, Any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.
"""
if isinstance(answer, str):
if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""):
return 1.0 # Yay
return 0.0
class DiceCurriculum(BaseCurriculum):
"""Curriculum for dice puzzle generation"""
def __init__(self):
super().__init__(DiceCurriculum.__name__, DiceConfig)
self._define_attributes(
ScalarAttributeDefinition(
name="num_dice",
levels=[4, 5, 6, 7],
default_level=0,
description="Number of dice to roll",
attr_type=AttributeType.STATIC,
min_value=4,
field_name="num_dice",
),
ScalarAttributeDefinition(
name="max_dice_size",
levels=[20, 25, 30, 35],
default_level=0,
description="Maximum number of sides on any die",
attr_type=AttributeType.STATIC,
min_value=20,
field_name="max_dice_size",
),
)
register_dataset("dice", DiceDataset, DiceConfig, DiceCurriculum)