fraction simplification curriculum (#349)

This commit is contained in:
Zafir Stojanovski 2025-03-13 21:05:50 +01:00 committed by GitHub
parent db3868150f
commit ee001f38a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 80 additions and 7 deletions

View file

@ -10,7 +10,11 @@ from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
from .dice import DiceConfig, DiceCurriculum, DiceDataset from .dice import DiceConfig, DiceCurriculum, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .fraction_simplification import (
FractionSimplificationConfig,
FractionSimplificationCurriculum,
FractionSimplificationDataset,
)
from .gcd import GCDConfig, GCDCurriculum, GCDDataset from .gcd import GCDConfig, GCDCurriculum, GCDDataset
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
from .lcm import LCMConfig, LCMCurriculum, LCMDataset from .lcm import LCMConfig, LCMCurriculum, LCMDataset
@ -32,6 +36,7 @@ __all__ = [
"CalendarArithmeticCurriculum", "CalendarArithmeticCurriculum",
"FractionSimplificationConfig", "FractionSimplificationConfig",
"FractionSimplificationDataset", "FractionSimplificationDataset",
"FractionSimplificationCurriculum",
"GCDConfig", "GCDConfig",
"GCDDataset", "GCDDataset",
"GCDCurriculum", "GCDCurriculum",

View file

@ -6,6 +6,7 @@ from math import gcd
from random import Random from random import Random
from typing import Any, Optional, Sequence from typing import Any, Optional, Sequence
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer." QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
@ -42,7 +43,7 @@ class FractionSimplificationDataset(ProceduralDataset):
def __init__(self, config: FractionSimplificationConfig): def __init__(self, config: FractionSimplificationConfig):
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int]: def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int, int]:
"""Generate a random fraction and its simplified form. """Generate a random fraction and its simplified form.
Returns (numerator, denominator, simplified_num, simplified_den)""" Returns (numerator, denominator, simplified_num, simplified_den)"""
# Try to generate valid fractions until we get one that meets our criteria # Try to generate valid fractions until we get one that meets our criteria
@ -69,7 +70,7 @@ class FractionSimplificationDataset(ProceduralDataset):
factor = rng.randint(self.config.min_factor, self.config.max_factor) factor = rng.randint(self.config.min_factor, self.config.max_factor)
numerator = simplified_num * factor numerator = simplified_num * factor
denominator = simplified_den * factor denominator = simplified_den * factor
return numerator, denominator, simplified_num, simplified_den return numerator, denominator, simplified_num, simplified_den, factor
# If we failed to find a good fraction after max attempts, # If we failed to find a good fraction after max attempts,
# generate one that's guaranteed to be within bounds # generate one that's guaranteed to be within bounds
@ -81,7 +82,7 @@ class FractionSimplificationDataset(ProceduralDataset):
simplified_num, simplified_den = simplified_den, simplified_num simplified_num, simplified_den = simplified_den, simplified_num
factor = rng.randint(self.config.min_factor, self.config.max_factor) factor = rng.randint(self.config.min_factor, self.config.max_factor)
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den) return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den, factor)
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str: def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
"""Format a fraction in various styles""" """Format a fraction in various styles"""
@ -100,7 +101,7 @@ class FractionSimplificationDataset(ProceduralDataset):
"""Generate a single fraction simplification task""" """Generate a single fraction simplification task"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
num, den, simple_num, simple_den = self._generate_fraction(rng) num, den, simple_num, simple_den, factor = self._generate_fraction(rng)
# Choose a random style from configured styles # Choose a random style from configured styles
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)] style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
@ -119,6 +120,10 @@ class FractionSimplificationDataset(ProceduralDataset):
"simplified_denominator": simple_den, "simplified_denominator": simple_den,
"reduction_factor": num // simple_num, # Will be same as den // simple_den "reduction_factor": num // simple_num, # Will be same as den // simple_den
"style": style, "style": style,
"difficulty": {
"factor": factor,
"value": (simple_num, simple_den),
},
}, },
} }
@ -152,4 +157,38 @@ class FractionSimplificationDataset(ProceduralDataset):
return reward return reward
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig) class FractionSimplificationCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(FractionSimplificationCurriculum.__name__, FractionSimplificationConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="value",
levels=[1, 100, 1000, 10000],
default_level=1,
description="Value range for numerator and denominator",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_value",
upper_field_name="max_value",
),
RangeAttributeDefinition(
name="factor",
levels=[1, 10, 100, 1000],
default_level=1,
description="Factor range for generating unsimplified fractions",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_factor",
upper_field_name="max_factor",
),
)
register_dataset(
"fraction_simplification",
FractionSimplificationDataset,
FractionSimplificationConfig,
FractionSimplificationCurriculum,
)

View file

@ -2,7 +2,11 @@ from math import gcd
import pytest import pytest
from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset from reasoning_gym.arithmetic import (
FractionSimplificationConfig,
FractionSimplificationCurriculum,
FractionSimplificationDataset,
)
def test_fraction_config_validation(): def test_fraction_config_validation():
@ -129,3 +133,28 @@ def test_fraction_numerator_smaller():
assert ( assert (
metadata["simplified_numerator"] <= metadata["simplified_denominator"] metadata["simplified_numerator"] <= metadata["simplified_denominator"]
), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}" ), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}"
def test_fraction_simplification_curriculum():
curriculum = FractionSimplificationCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: FractionSimplificationConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_value == 1 and base_cfg.max_value == 100
assert base_cfg.min_factor == 1 and base_cfg.max_factor == 10
# test incrementing attribute levels
curriculum.increment_attr_level("value")
curriculum.increment_attr_level("factor")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_value == 1 and increased_cfg.max_value == 1000
assert increased_cfg.min_factor == 1 and increased_cfg.max_factor == 100
# test decrementing attribute level for value again
curriculum.decrement_attr_level("value")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_value == 1 and partially_decreased_cfg.max_value == 100
assert partially_decreased_cfg.min_factor == 1 and partially_decreased_cfg.max_factor == 100