mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fmt
This commit is contained in:
parent
93a7a58023
commit
9cd4e825d4
2 changed files with 14 additions and 19 deletions
|
|
@ -1,8 +1,8 @@
|
|||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from math import gcd
|
||||
from functools import reduce
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -37,6 +37,7 @@ def compute_probability(dice, target):
|
|||
frac = simplify(ways, total_outcomes)
|
||||
return frac, ways / total_outcomes
|
||||
|
||||
|
||||
def generate_puzzle(num_dice, max_dice_size, rng):
|
||||
"""
|
||||
Generates a puzzle:
|
||||
|
|
@ -44,11 +45,11 @@ def generate_puzzle(num_dice, max_dice_size, rng):
|
|||
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
|
||||
- The dice are then shuffled.
|
||||
- The target total is chosen roughly in the middle (but you can adjust the method).
|
||||
|
||||
|
||||
It then computes the probability of rolling a total at least the target.
|
||||
Finally, it prints out the puzzle statement and the answer.
|
||||
"""
|
||||
|
||||
|
||||
# Guarantee one die is the maximum.
|
||||
dice = [max_dice_size]
|
||||
for _ in range(num_dice - 1):
|
||||
|
|
@ -59,33 +60,28 @@ def generate_puzzle(num_dice, max_dice_size, rng):
|
|||
else:
|
||||
die = 2
|
||||
dice.append(die)
|
||||
|
||||
|
||||
# Optionally, sort dice in descending order (as is common in puzzles)
|
||||
dice.sort(reverse=True)
|
||||
|
||||
|
||||
# Compute minimum and maximum possible totals.
|
||||
min_total = num_dice # each die gives at least 1
|
||||
max_total = sum(dice)
|
||||
|
||||
|
||||
# Choose a target total. For an interesting puzzle,
|
||||
# we choose a target somewhere in the middle third of the range.
|
||||
low_target = min_total + (max_total - min_total) // 3
|
||||
high_target = min_total + 2 * (max_total - min_total) // 3
|
||||
target = rng.randint(low_target, high_target)
|
||||
|
||||
|
||||
# Compute probability.
|
||||
(num, den), prob = compute_probability(dice, target)
|
||||
|
||||
|
||||
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
|
||||
dice_str = ", ".join(f"1d{s}" for s in dice)
|
||||
|
||||
|
||||
# Return the puzzle.
|
||||
return {
|
||||
'dice_str': dice_str,
|
||||
'target': target,
|
||||
'num': num,
|
||||
'den': den
|
||||
}
|
||||
return {"dice_str": dice_str, "target": target, "num": num, "den": den}
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -126,8 +122,7 @@ class DiceDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": puzzle_str,
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
},
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue