mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fraction simplification curriculum (#349)
This commit is contained in:
parent
db3868150f
commit
ee001f38a4
3 changed files with 80 additions and 7 deletions
|
|
@ -10,7 +10,11 @@ from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
|
||||||
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset
|
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset
|
||||||
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
|
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
|
||||||
from .dice import DiceConfig, DiceCurriculum, DiceDataset
|
from .dice import DiceConfig, DiceCurriculum, DiceDataset
|
||||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
from .fraction_simplification import (
|
||||||
|
FractionSimplificationConfig,
|
||||||
|
FractionSimplificationCurriculum,
|
||||||
|
FractionSimplificationDataset,
|
||||||
|
)
|
||||||
from .gcd import GCDConfig, GCDCurriculum, GCDDataset
|
from .gcd import GCDConfig, GCDCurriculum, GCDDataset
|
||||||
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
||||||
from .lcm import LCMConfig, LCMCurriculum, LCMDataset
|
from .lcm import LCMConfig, LCMCurriculum, LCMDataset
|
||||||
|
|
@ -32,6 +36,7 @@ __all__ = [
|
||||||
"CalendarArithmeticCurriculum",
|
"CalendarArithmeticCurriculum",
|
||||||
"FractionSimplificationConfig",
|
"FractionSimplificationConfig",
|
||||||
"FractionSimplificationDataset",
|
"FractionSimplificationDataset",
|
||||||
|
"FractionSimplificationCurriculum",
|
||||||
"GCDConfig",
|
"GCDConfig",
|
||||||
"GCDDataset",
|
"GCDDataset",
|
||||||
"GCDCurriculum",
|
"GCDCurriculum",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from math import gcd
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, Optional, Sequence
|
from typing import Any, Optional, Sequence
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
|
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
|
||||||
|
|
@ -42,7 +43,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
def __init__(self, config: FractionSimplificationConfig):
|
def __init__(self, config: FractionSimplificationConfig):
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int]:
|
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int, int]:
|
||||||
"""Generate a random fraction and its simplified form.
|
"""Generate a random fraction and its simplified form.
|
||||||
Returns (numerator, denominator, simplified_num, simplified_den)"""
|
Returns (numerator, denominator, simplified_num, simplified_den, factor)"""
|
||||||
# Try to generate valid fractions until we get one that meets our criteria
|
# Try to generate valid fractions until we get one that meets our criteria
|
||||||
|
|
@ -69,7 +70,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||||
numerator = simplified_num * factor
|
numerator = simplified_num * factor
|
||||||
denominator = simplified_den * factor
|
denominator = simplified_den * factor
|
||||||
return numerator, denominator, simplified_num, simplified_den
|
return numerator, denominator, simplified_num, simplified_den, factor
|
||||||
|
|
||||||
# If we failed to find a good fraction after max attempts,
|
# If we failed to find a good fraction after max attempts,
|
||||||
# generate one that's guaranteed to be within bounds
|
# generate one that's guaranteed to be within bounds
|
||||||
|
|
@ -81,7 +82,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
simplified_num, simplified_den = simplified_den, simplified_num
|
simplified_num, simplified_den = simplified_den, simplified_num
|
||||||
|
|
||||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||||
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
|
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den, factor)
|
||||||
|
|
||||||
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
||||||
"""Format a fraction in various styles"""
|
"""Format a fraction in various styles"""
|
||||||
|
|
@ -100,7 +101,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
"""Generate a single fraction simplification task"""
|
"""Generate a single fraction simplification task"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
num, den, simple_num, simple_den, factor = self._generate_fraction(rng)
|
||||||
|
|
||||||
# Choose a random style from configured styles
|
# Choose a random style from configured styles
|
||||||
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
|
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
|
||||||
|
|
@ -119,6 +120,10 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
"simplified_denominator": simple_den,
|
"simplified_denominator": simple_den,
|
||||||
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
||||||
"style": style,
|
"style": style,
|
||||||
|
"difficulty": {
|
||||||
|
"factor": factor,
|
||||||
|
"value": (simple_num, simple_den),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -152,4 +157,38 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
|
|
||||||
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
class FractionSimplificationCurriculum(BaseCurriculum):
    """Curriculum describing difficulty attributes for fraction simplification.

    Two range attributes are declared:

    * ``value``  -- bound to the ``min_value`` / ``max_value`` fields.
    * ``factor`` -- bound to the ``min_factor`` / ``max_factor`` fields.
    """

    def __init__(self):
        super().__init__(FractionSimplificationCurriculum.__name__, FractionSimplificationConfig)

        # Range used for the numerator and denominator values.
        value_attribute = RangeAttributeDefinition(
            name="value",
            levels=[1, 100, 1000, 10000],
            default_level=1,
            description="Value range for numerator and denominator",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_value",
            upper_field_name="max_value",
        )

        # Range used for the multiplier that produces the unsimplified fraction.
        factor_attribute = RangeAttributeDefinition(
            name="factor",
            levels=[1, 10, 100, 1000],
            default_level=1,
            description="Factor range for generating unsimplified fractions",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_factor",
            upper_field_name="max_factor",
        )

        self._define_attributes(value_attribute, factor_attribute)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the dataset under its public name together with its config and curriculum classes.
register_dataset(
    "fraction_simplification",
    FractionSimplificationDataset,
    FractionSimplificationConfig,
    FractionSimplificationCurriculum,
)
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,11 @@ from math import gcd
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset
|
from reasoning_gym.arithmetic import (
|
||||||
|
FractionSimplificationConfig,
|
||||||
|
FractionSimplificationCurriculum,
|
||||||
|
FractionSimplificationDataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_fraction_config_validation():
|
def test_fraction_config_validation():
|
||||||
|
|
@ -129,3 +133,28 @@ def test_fraction_numerator_smaller():
|
||||||
assert (
|
assert (
|
||||||
metadata["simplified_numerator"] <= metadata["simplified_denominator"]
|
metadata["simplified_numerator"] <= metadata["simplified_denominator"]
|
||||||
), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"
|
), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fraction_simplification_curriculum():
    """Check the curriculum's default configuration and its level transitions."""
    curriculum = FractionSimplificationCurriculum()
    base_value = {"size": 150, "seed": 1}

    # Default levels: size/seed are passed through, ranges come from the curriculum.
    base_cfg: FractionSimplificationConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert (base_cfg.min_value, base_cfg.max_value) == (1, 100)
    assert (base_cfg.min_factor, base_cfg.max_factor) == (1, 10)

    # Raising both attributes by one level widens both upper bounds.
    for attr_name in ("value", "factor"):
        curriculum.increment_attr_level(attr_name)
    increased_cfg = curriculum.generate_configuration(base_value)
    assert (increased_cfg.min_value, increased_cfg.max_value) == (1, 1000)
    assert (increased_cfg.min_factor, increased_cfg.max_factor) == (1, 100)

    # Lowering only "value" must leave the "factor" range at its raised level.
    curriculum.decrement_attr_level("value")
    partially_decreased_cfg = curriculum.generate_configuration(base_value)
    assert (partially_decreased_cfg.min_value, partially_decreased_cfg.max_value) == (1, 100)
    assert (partially_decreased_cfg.min_factor, partially_decreased_cfg.max_factor) == (1, 100)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue