This commit is contained in:
Rich Jones 2025-02-11 12:54:23 +01:00
parent 93a7a58023
commit 9cd4e825d4
2 changed files with 14 additions and 19 deletions

View file

@ -1,8 +1,8 @@
from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random
from typing import Dict, Optional
from math import gcd
from functools import reduce
from ..factory import ProceduralDataset, register_dataset
@ -37,6 +37,7 @@ def compute_probability(dice, target):
frac = simplify(ways, total_outcomes)
return frac, ways / total_outcomes
def generate_puzzle(num_dice, max_dice_size, rng):
"""
Generates a puzzle:
@ -44,11 +45,11 @@ def generate_puzzle(num_dice, max_dice_size, rng):
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
- The dice are then shuffled.
- The target total is chosen roughly in the middle (but you can adjust the method).
It then computes the probability of rolling a total at least the target.
Finally, it prints out the puzzle statement and the answer.
"""
# Guarantee one die is the maximum.
dice = [max_dice_size]
for _ in range(num_dice - 1):
@ -59,33 +60,28 @@ def generate_puzzle(num_dice, max_dice_size, rng):
else:
die = 2
dice.append(die)
# Optionally, sort dice in descending order (as is common in puzzles)
dice.sort(reverse=True)
# Compute minimum and maximum possible totals.
min_total = num_dice # each die gives at least 1
max_total = sum(dice)
# Choose a target total. For an interesting puzzle,
# we choose a target somewhere in the middle third of the range.
low_target = min_total + (max_total - min_total) // 3
high_target = min_total + 2 * (max_total - min_total) // 3
target = rng.randint(low_target, high_target)
# Compute probability.
(num, den), prob = compute_probability(dice, target)
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
dice_str = ", ".join(f"1d{s}" for s in dice)
# Return the puzzle.
return {
'dice_str': dice_str,
'target': target,
'num': num,
'den': den
}
return {"dice_str": dice_str, "target": target, "num": num, "den": den}
@dataclass
@ -126,8 +122,7 @@ class DiceDataset(ProceduralDataset):
return {
"question": puzzle_str,
"answer": answer_str,
"metadata": {
},
"metadata": {},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: