diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index c984c71f..016d252f 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -9,7 +9,7 @@ from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset -from .dice import DiceConfig, DiceDataset +from .dice import DiceConfig, DiceCurriculum, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig @@ -54,6 +54,7 @@ __all__ = [ "CountBitsCurriculum", "DiceConfig", "DiceDataset", + "DiceCurriculum", "NumberFormatConfig", "NumberFormatDataset", "DecimalArithmeticConfig", diff --git a/reasoning_gym/arithmetic/dice.py b/reasoning_gym/arithmetic/dice.py index 00cc6b7d..82c430d3 100644 --- a/reasoning_gym/arithmetic/dice.py +++ b/reasoning_gym/arithmetic/dice.py @@ -4,6 +4,7 @@ from math import gcd from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -75,7 +76,7 @@ def generate_puzzle(num_dice, max_dice_size, rng): target = rng.randint(low_target, high_target) # Compute probability. - (num, den), prob = compute_probability(dice, target) + (num, den) = compute_probability(dice, target) # Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc. dice_str = ", ".join(f"1d{s}" for s in dice) @@ -122,7 +123,12 @@ class DiceDataset(ProceduralDataset): return { "question": puzzle_str, "answer": answer_str, - "metadata": {}, + "metadata": { + "difficulty": { + "num_dice": self.config.num_dice, + "max_dice_size": self.config.max_dice_size, + } + }, } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -145,4 +151,32 @@ class DiceDataset(ProceduralDataset): return 0.0 -register_dataset("dice", DiceDataset, DiceConfig) +class DiceCurriculum(BaseCurriculum): + """Curriculum for dice puzzle generation""" + + def __init__(self): + super().__init__(DiceCurriculum.__name__, DiceConfig) + + self._define_attributes( + ScalarAttributeDefinition( + name="num_dice", + levels=[4, 5, 6, 7], + default_level=0, + description="Number of dice to roll", + attr_type=AttributeType.STATIC, + min_value=4, + field_name="num_dice", + ), + ScalarAttributeDefinition( + name="max_dice_size", + levels=[20, 25, 30, 35], + default_level=0, + description="Maximum number of sides on any die", + attr_type=AttributeType.STATIC, + min_value=20, + field_name="max_dice_size", + ), + ) + + +register_dataset("dice", DiceDataset, DiceConfig, DiceCurriculum) diff --git a/tests/test_dice.py b/tests/test_dice.py index 8a3bd991..7656072a 100644 --- a/tests/test_dice.py +++ b/tests/test_dice.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arithmetic.dice import DiceConfig, DiceDataset +from reasoning_gym.arithmetic.dice import DiceConfig, DiceCurriculum, DiceDataset def test_dice(): @@ -33,3 +33,22 @@ def test_dice(): for item in dataset: assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 + + +def test_dice_curriculum(): + """Test that the curriculum generates correct configurations""" + + curriculum = DiceCurriculum() + + base_value = {"size": 150, "seed": 1} + base_cfg: DiceConfig = curriculum.generate_configuration(base_value) + assert base_cfg.size == 150 + assert base_cfg.seed == 1 + assert base_cfg.num_dice == 4 + assert base_cfg.max_dice_size == 20 + + curriculum.increment_attr_level("num_dice") + curriculum.increment_attr_level("max_dice_size") + increased_cfg: DiceConfig = curriculum.generate_configuration() + assert increased_cfg.num_dice == 5 + assert increased_cfg.max_dice_size == 25