dice curriculum (#284)

* curriculum + unit tests
* add difficulty to metadata

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
vncntt 2025-03-07 16:43:45 -08:00 committed by GitHub
parent d1c06e9f98
commit 6d0b219412
3 changed files with 59 additions and 5 deletions

View file

@ -9,7 +9,7 @@ from .chain_sum import ChainSumConfig, ChainSumDataset
from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
from .dice import DiceConfig, DiceDataset
from .dice import DiceConfig, DiceCurriculum, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
@ -54,6 +54,7 @@ __all__ = [
"CountBitsCurriculum",
"DiceConfig",
"DiceDataset",
"DiceCurriculum",
"NumberFormatConfig",
"NumberFormatDataset",
"DecimalArithmeticConfig",

View file

@ -4,6 +4,7 @@ from math import gcd
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -75,7 +76,7 @@ def generate_puzzle(num_dice, max_dice_size, rng):
target = rng.randint(low_target, high_target)
# Compute probability.
(num, den), prob = compute_probability(dice, target)
(num, den) = compute_probability(dice, target)
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
dice_str = ", ".join(f"1d{s}" for s in dice)
@ -122,7 +123,12 @@ class DiceDataset(ProceduralDataset):
return {
"question": puzzle_str,
"answer": answer_str,
"metadata": {},
"metadata": {
"difficulty": {
"num_dice": self.config.num_dice,
"max_dice_size": self.config.max_dice_size,
}
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -145,4 +151,32 @@ class DiceDataset(ProceduralDataset):
return 0.0
register_dataset("dice", DiceDataset, DiceConfig)
class DiceCurriculum(BaseCurriculum):
"""Curriculum for dice puzzle generation"""
def __init__(self):
super().__init__(DiceCurriculum.__name__, DiceConfig)
self._define_attributes(
ScalarAttributeDefinition(
name="num_dice",
levels=[4, 5, 6, 7],
default_level=0,
description="Number of dice to roll",
attr_type=AttributeType.STATIC,
min_value=4,
field_name="num_dice",
),
ScalarAttributeDefinition(
name="max_dice_size",
levels=[20, 25, 30, 35],
default_level=0,
description="Maximum number of sides on any die",
attr_type=AttributeType.STATIC,
min_value=20,
field_name="max_dice_size",
),
)
register_dataset("dice", DiceDataset, DiceConfig, DiceCurriculum)