This commit is contained in:
Rich Jones 2025-02-11 12:54:23 +01:00
parent 93a7a58023
commit 9cd4e825d4
2 changed files with 14 additions and 19 deletions

View file

@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSum, ChainSumConfig from .chain_sum import ChainSum, ChainSumConfig
from .count_bits import CountBitsConfig, CountBitsDataset from .count_bits import CountBitsConfig, CountBitsDataset
from .dice import DiceConfig, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset from .gcd import GCDConfig, GCDDataset
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
@ -14,7 +15,6 @@ from .leg_counting import LegCountingConfig, LegCountingDataset
from .power_function import PowerFunctionConfig, PowerFunctionDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
from .dice import DiceConfig, DiceDataset
__all__ = [ __all__ = [
"BasicArithmeticDataset", "BasicArithmeticDataset",

View file

@ -1,8 +1,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random from random import Random
from typing import Dict, Optional from typing import Dict, Optional
from math import gcd
from functools import reduce
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -37,6 +37,7 @@ def compute_probability(dice, target):
frac = simplify(ways, total_outcomes) frac = simplify(ways, total_outcomes)
return frac, ways / total_outcomes return frac, ways / total_outcomes
def generate_puzzle(num_dice, max_dice_size, rng): def generate_puzzle(num_dice, max_dice_size, rng):
""" """
Generates a puzzle: Generates a puzzle:
@ -80,12 +81,7 @@ def generate_puzzle(num_dice, max_dice_size, rng):
dice_str = ", ".join(f"1d{s}" for s in dice) dice_str = ", ".join(f"1d{s}" for s in dice)
# Return the puzzle. # Return the puzzle.
return { return {"dice_str": dice_str, "target": target, "num": num, "den": den}
'dice_str': dice_str,
'target': target,
'num': num,
'den': den
}
@dataclass @dataclass
@ -126,8 +122,7 @@ class DiceDataset(ProceduralDataset):
return { return {
"question": puzzle_str, "question": puzzle_str,
"answer": answer_str, "answer": answer_str,
"metadata": { "metadata": {},
},
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: