mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fraction simplification curriculum (#349)
This commit is contained in:
parent
db3868150f
commit
ee001f38a4
3 changed files with 80 additions and 7 deletions
|
|
@ -10,7 +10,11 @@ from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
|
|||
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset
|
||||
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
|
||||
from .dice import DiceConfig, DiceCurriculum, DiceDataset
|
||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||
from .fraction_simplification import (
|
||||
FractionSimplificationConfig,
|
||||
FractionSimplificationCurriculum,
|
||||
FractionSimplificationDataset,
|
||||
)
|
||||
from .gcd import GCDConfig, GCDCurriculum, GCDDataset
|
||||
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
||||
from .lcm import LCMConfig, LCMCurriculum, LCMDataset
|
||||
|
|
@ -32,6 +36,7 @@ __all__ = [
|
|||
"CalendarArithmeticCurriculum",
|
||||
"FractionSimplificationConfig",
|
||||
"FractionSimplificationDataset",
|
||||
"FractionSimplificationCurriculum",
|
||||
"GCDConfig",
|
||||
"GCDDataset",
|
||||
"GCDCurriculum",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
|
||||
|
|
@ -42,7 +43,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
def __init__(self, config: FractionSimplificationConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int]:
|
||||
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int, int]:
|
||||
"""Generate a random fraction and its simplified form.
|
||||
Returns (numerator, denominator, simplified_num, simplified_den, factor)"""
|
||||
# Try to generate valid fractions until we get one that meets our criteria
|
||||
|
|
@ -69,7 +70,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||
numerator = simplified_num * factor
|
||||
denominator = simplified_den * factor
|
||||
return numerator, denominator, simplified_num, simplified_den
|
||||
return numerator, denominator, simplified_num, simplified_den, factor
|
||||
|
||||
# If we failed to find a good fraction after max attempts,
|
||||
# generate one that's guaranteed to be within bounds
|
||||
|
|
@ -81,7 +82,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
simplified_num, simplified_den = simplified_den, simplified_num
|
||||
|
||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
|
||||
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den, factor)
|
||||
|
||||
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
||||
"""Format a fraction in various styles"""
|
||||
|
|
@ -100,7 +101,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
"""Generate a single fraction simplification task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
||||
num, den, simple_num, simple_den, factor = self._generate_fraction(rng)
|
||||
|
||||
# Choose a random style from configured styles
|
||||
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
|
||||
|
|
@ -119,6 +120,10 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
"simplified_denominator": simple_den,
|
||||
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
||||
"style": style,
|
||||
"difficulty": {
|
||||
"factor": factor,
|
||||
"value": (simple_num, simple_den),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -152,4 +157,38 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
return reward
|
||||
|
||||
|
||||
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
||||
class FractionSimplificationCurriculum(BaseCurriculum):
    """Curriculum for the fraction simplification dataset.

    Declares two range attributes that are written into
    ``FractionSimplificationConfig``:

    * ``value``  -- levels 1, 100, 1000, 10000; fills ``min_value`` / ``max_value``.
    * ``factor`` -- levels 1, 10, 100, 1000; fills ``min_factor`` / ``max_factor``.
    """

    def __init__(self):
        super().__init__(FractionSimplificationCurriculum.__name__, FractionSimplificationConfig)

        # Range used for the numerator and denominator of the simplified fraction.
        value_attribute = RangeAttributeDefinition(
            name="value",
            levels=[1, 100, 1000, 10000],
            default_level=1,
            description="Value range for numerator and denominator",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_value",
            upper_field_name="max_value",
        )

        # Range of the common factor multiplied into both parts of the fraction.
        factor_attribute = RangeAttributeDefinition(
            name="factor",
            levels=[1, 10, 100, 1000],
            default_level=1,
            description="Factor range for generating unsimplified fractions",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_factor",
            upper_field_name="max_factor",
        )

        self._define_attributes(value_attribute, factor_attribute)
|
||||
|
||||
|
||||
# Register the dataset under its public name together with its config and
# curriculum classes.
_FRACTION_SIMPLIFICATION_NAME = "fraction_simplification"
register_dataset(
    _FRACTION_SIMPLIFICATION_NAME,
    FractionSimplificationDataset,
    FractionSimplificationConfig,
    FractionSimplificationCurriculum,
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,11 @@ from math import gcd
|
|||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset
|
||||
from reasoning_gym.arithmetic import (
|
||||
FractionSimplificationConfig,
|
||||
FractionSimplificationCurriculum,
|
||||
FractionSimplificationDataset,
|
||||
)
|
||||
|
||||
|
||||
def test_fraction_config_validation():
|
||||
|
|
@ -129,3 +133,28 @@ def test_fraction_numerator_smaller():
|
|||
assert (
|
||||
metadata["simplified_numerator"] <= metadata["simplified_denominator"]
|
||||
), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"
|
||||
|
||||
|
||||
def test_fraction_simplification_curriculum():
    """Check the curriculum's default configuration and its level changes."""
    curriculum = FractionSimplificationCurriculum()

    base_value = {"size": 150, "seed": 1}

    # Default levels: base values are passed through and both ranges start at 1.
    base_cfg: FractionSimplificationConfig = curriculum.generate_configuration(base_value)
    assert (base_cfg.seed, base_cfg.size) == (1, 150)
    assert (base_cfg.min_value, base_cfg.max_value) == (1, 100)
    assert (base_cfg.min_factor, base_cfg.max_factor) == (1, 10)

    # Raising both attributes by one level widens both upper bounds.
    for attr_name in ("value", "factor"):
        curriculum.increment_attr_level(attr_name)
    increased_cfg = curriculum.generate_configuration(base_value)
    assert (increased_cfg.min_value, increased_cfg.max_value) == (1, 1000)
    assert (increased_cfg.min_factor, increased_cfg.max_factor) == (1, 100)

    # Lowering only "value" must leave the "factor" range at its raised level.
    curriculum.decrement_attr_level("value")
    partially_decreased_cfg = curriculum.generate_configuration(base_value)
    assert (partially_decreased_cfg.min_value, partially_decreased_cfg.max_value) == (1, 100)
    assert (partially_decreased_cfg.min_factor, partially_decreased_cfg.max_factor) == (1, 100)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue