mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
dice curriculum (#284)
* curriculum + unit tests
* add difficulty to metadata
---------
Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
parent
d1c06e9f98
commit
6d0b219412
3 changed files with 59 additions and 5 deletions
|
|
@ -4,6 +4,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -75,7 +76,7 @@ def generate_puzzle(num_dice, max_dice_size, rng):
|
|||
target = rng.randint(low_target, high_target)
|
||||
|
||||
# Compute probability.
|
||||
(num, den), prob = compute_probability(dice, target)
|
||||
(num, den) = compute_probability(dice, target)
|
||||
|
||||
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
|
||||
dice_str = ", ".join(f"1d{s}" for s in dice)
|
||||
|
|
@ -122,7 +123,12 @@ class DiceDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": puzzle_str,
|
||||
"answer": answer_str,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"difficulty": {
|
||||
"num_dice": self.config.num_dice,
|
||||
"max_dice_size": self.config.max_dice_size,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
|
|
@ -145,4 +151,32 @@ class DiceDataset(ProceduralDataset):
|
|||
return 0.0
|
||||
|
||||
|
||||
register_dataset("dice", DiceDataset, DiceConfig)
|
||||
class DiceCurriculum(BaseCurriculum):
    """Difficulty curriculum for the dice puzzle dataset.

    Declares two scalar attributes, ``num_dice`` and ``max_dice_size``, each
    offering four levels and starting at the first one. Each attribute names
    the config field it targets through ``field_name``.
    """

    def __init__(self):
        super().__init__(DiceCurriculum.__name__, DiceConfig)

        # How many dice are rolled in a puzzle.
        dice_count_attr = ScalarAttributeDefinition(
            name="num_dice",
            levels=[4, 5, 6, 7],
            default_level=0,
            description="Number of dice to roll",
            attr_type=AttributeType.STATIC,
            min_value=4,
            field_name="num_dice",
        )

        # Upper bound on the number of sides any single die may have.
        dice_size_attr = ScalarAttributeDefinition(
            name="max_dice_size",
            levels=[20, 25, 30, 35],
            default_level=0,
            description="Maximum number of sides on any die",
            attr_type=AttributeType.STATIC,
            min_value=20,
            field_name="max_dice_size",
        )

        self._define_attributes(dice_count_attr, dice_size_attr)
|
||||
|
||||
|
||||
# Register the dataset under the name "dice", passing its config and curriculum classes.
register_dataset("dice", DiceDataset, DiceConfig, DiceCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue