mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
fmt
This commit is contained in:
parent
93a7a58023
commit
9cd4e825d4
2 changed files with 14 additions and 19 deletions
|
|
@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf
|
||||||
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
||||||
from .chain_sum import ChainSum, ChainSumConfig
|
from .chain_sum import ChainSum, ChainSumConfig
|
||||||
from .count_bits import CountBitsConfig, CountBitsDataset
|
from .count_bits import CountBitsConfig, CountBitsDataset
|
||||||
|
from .dice import DiceConfig, DiceDataset
|
||||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||||
from .gcd import GCDConfig, GCDDataset
|
from .gcd import GCDConfig, GCDDataset
|
||||||
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
||||||
|
|
@ -14,7 +15,6 @@ from .leg_counting import LegCountingConfig, LegCountingDataset
|
||||||
from .power_function import PowerFunctionConfig, PowerFunctionDataset
|
from .power_function import PowerFunctionConfig, PowerFunctionDataset
|
||||||
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
||||||
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
|
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
|
||||||
from .dice import DiceConfig, DiceDataset
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BasicArithmeticDataset",
|
"BasicArithmeticDataset",
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import reduce
|
||||||
|
from math import gcd
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
from math import gcd
|
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
@ -37,6 +37,7 @@ def compute_probability(dice, target):
|
||||||
frac = simplify(ways, total_outcomes)
|
frac = simplify(ways, total_outcomes)
|
||||||
return frac, ways / total_outcomes
|
return frac, ways / total_outcomes
|
||||||
|
|
||||||
|
|
||||||
def generate_puzzle(num_dice, max_dice_size, rng):
|
def generate_puzzle(num_dice, max_dice_size, rng):
|
||||||
"""
|
"""
|
||||||
Generates a puzzle:
|
Generates a puzzle:
|
||||||
|
|
@ -44,11 +45,11 @@ def generate_puzzle(num_dice, max_dice_size, rng):
|
||||||
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
|
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
|
||||||
- The dice are then shuffled.
|
- The dice are then shuffled.
|
||||||
- The target total is chosen roughly in the middle (but you can adjust the method).
|
- The target total is chosen roughly in the middle (but you can adjust the method).
|
||||||
|
|
||||||
It then computes the probability of rolling a total at least the target.
|
It then computes the probability of rolling a total at least the target.
|
||||||
Finally, it prints out the puzzle statement and the answer.
|
Finally, it prints out the puzzle statement and the answer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Guarantee one die is the maximum.
|
# Guarantee one die is the maximum.
|
||||||
dice = [max_dice_size]
|
dice = [max_dice_size]
|
||||||
for _ in range(num_dice - 1):
|
for _ in range(num_dice - 1):
|
||||||
|
|
@ -59,33 +60,28 @@ def generate_puzzle(num_dice, max_dice_size, rng):
|
||||||
else:
|
else:
|
||||||
die = 2
|
die = 2
|
||||||
dice.append(die)
|
dice.append(die)
|
||||||
|
|
||||||
# Optionally, sort dice in descending order (as is common in puzzles)
|
# Optionally, sort dice in descending order (as is common in puzzles)
|
||||||
dice.sort(reverse=True)
|
dice.sort(reverse=True)
|
||||||
|
|
||||||
# Compute minimum and maximum possible totals.
|
# Compute minimum and maximum possible totals.
|
||||||
min_total = num_dice # each die gives at least 1
|
min_total = num_dice # each die gives at least 1
|
||||||
max_total = sum(dice)
|
max_total = sum(dice)
|
||||||
|
|
||||||
# Choose a target total. For an interesting puzzle,
|
# Choose a target total. For an interesting puzzle,
|
||||||
# we choose a target somewhere in the middle third of the range.
|
# we choose a target somewhere in the middle third of the range.
|
||||||
low_target = min_total + (max_total - min_total) // 3
|
low_target = min_total + (max_total - min_total) // 3
|
||||||
high_target = min_total + 2 * (max_total - min_total) // 3
|
high_target = min_total + 2 * (max_total - min_total) // 3
|
||||||
target = rng.randint(low_target, high_target)
|
target = rng.randint(low_target, high_target)
|
||||||
|
|
||||||
# Compute probability.
|
# Compute probability.
|
||||||
(num, den), prob = compute_probability(dice, target)
|
(num, den), prob = compute_probability(dice, target)
|
||||||
|
|
||||||
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
|
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
|
||||||
dice_str = ", ".join(f"1d{s}" for s in dice)
|
dice_str = ", ".join(f"1d{s}" for s in dice)
|
||||||
|
|
||||||
# Return the puzzle.
|
# Return the puzzle.
|
||||||
return {
|
return {"dice_str": dice_str, "target": target, "num": num, "den": den}
|
||||||
'dice_str': dice_str,
|
|
||||||
'target': target,
|
|
||||||
'num': num,
|
|
||||||
'den': den
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -126,8 +122,7 @@ class DiceDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": puzzle_str,
|
"question": puzzle_str,
|
||||||
"answer": answer_str,
|
"answer": answer_str,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue