This commit is contained in:
Rich Jones 2025-02-11 12:54:23 +01:00
parent 93a7a58023
commit 9cd4e825d4
2 changed files with 14 additions and 19 deletions

View file

@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSum, ChainSumConfig from .chain_sum import ChainSum, ChainSumConfig
from .count_bits import CountBitsConfig, CountBitsDataset from .count_bits import CountBitsConfig, CountBitsDataset
from .dice import DiceConfig, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset from .gcd import GCDConfig, GCDDataset
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
@ -14,7 +15,6 @@ from .leg_counting import LegCountingConfig, LegCountingDataset
from .power_function import PowerFunctionConfig, PowerFunctionDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
from .dice import DiceConfig, DiceDataset
__all__ = [ __all__ = [
"BasicArithmeticDataset", "BasicArithmeticDataset",

View file

@ -1,8 +1,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random from random import Random
from typing import Dict, Optional from typing import Dict, Optional
from math import gcd
from functools import reduce
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -37,6 +37,7 @@ def compute_probability(dice, target):
frac = simplify(ways, total_outcomes) frac = simplify(ways, total_outcomes)
return frac, ways / total_outcomes return frac, ways / total_outcomes
def generate_puzzle(num_dice, max_dice_size, rng): def generate_puzzle(num_dice, max_dice_size, rng):
""" """
Generates a puzzle: Generates a puzzle:
@ -44,11 +45,11 @@ def generate_puzzle(num_dice, max_dice_size, rng):
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1. - The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
- The dice are then shuffled. - The dice are then shuffled.
- The target total is chosen roughly in the middle (but you can adjust the method). - The target total is chosen roughly in the middle (but you can adjust the method).
It then computes the probability of rolling a total at least the target. It then computes the probability of rolling a total at least the target.
Finally, it prints out the puzzle statement and the answer. Finally, it prints out the puzzle statement and the answer.
""" """
# Guarantee one die is the maximum. # Guarantee one die is the maximum.
dice = [max_dice_size] dice = [max_dice_size]
for _ in range(num_dice - 1): for _ in range(num_dice - 1):
@ -59,33 +60,28 @@ def generate_puzzle(num_dice, max_dice_size, rng):
else: else:
die = 2 die = 2
dice.append(die) dice.append(die)
# Optionally, sort dice in descending order (as is common in puzzles) # Optionally, sort dice in descending order (as is common in puzzles)
dice.sort(reverse=True) dice.sort(reverse=True)
# Compute minimum and maximum possible totals. # Compute minimum and maximum possible totals.
min_total = num_dice # each die gives at least 1 min_total = num_dice # each die gives at least 1
max_total = sum(dice) max_total = sum(dice)
# Choose a target total. For an interesting puzzle, # Choose a target total. For an interesting puzzle,
# we choose a target somewhere in the middle third of the range. # we choose a target somewhere in the middle third of the range.
low_target = min_total + (max_total - min_total) // 3 low_target = min_total + (max_total - min_total) // 3
high_target = min_total + 2 * (max_total - min_total) // 3 high_target = min_total + 2 * (max_total - min_total) // 3
target = rng.randint(low_target, high_target) target = rng.randint(low_target, high_target)
# Compute probability. # Compute probability.
(num, den), prob = compute_probability(dice, target) (num, den), prob = compute_probability(dice, target)
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc. # Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
dice_str = ", ".join(f"1d{s}" for s in dice) dice_str = ", ".join(f"1d{s}" for s in dice)
# Return the puzzle. # Return the puzzle.
return { return {"dice_str": dice_str, "target": target, "num": num, "den": den}
'dice_str': dice_str,
'target': target,
'num': num,
'den': den
}
@dataclass @dataclass
@ -126,8 +122,7 @@ class DiceDataset(ProceduralDataset):
return { return {
"question": puzzle_str, "question": puzzle_str,
"answer": answer_str, "answer": answer_str,
"metadata": { "metadata": {},
},
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: