diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index bb25e09a..f886e4a0 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -10,7 +10,11 @@ from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurriculum, DecimalArithmeticDataset from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset from .dice import DiceConfig, DiceCurriculum, DiceDataset -from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset +from .fraction_simplification import ( + FractionSimplificationConfig, + FractionSimplificationCurriculum, + FractionSimplificationDataset, +) from .gcd import GCDConfig, GCDCurriculum, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .lcm import LCMConfig, LCMCurriculum, LCMDataset @@ -32,6 +36,7 @@ __all__ = [ "CalendarArithmeticCurriculum", "FractionSimplificationConfig", "FractionSimplificationDataset", + "FractionSimplificationCurriculum", "GCDConfig", "GCDDataset", "GCDCurriculum", diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index d0cc2cb8..e06d4519 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -6,6 +6,7 @@ from math import gcd from random import Random from typing import Any, Optional, Sequence +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer." @@ -42,7 +43,7 @@ class FractionSimplificationDataset(ProceduralDataset): def __init__(self, config: FractionSimplificationConfig): super().__init__(config=config, seed=config.seed, size=config.size) - def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int]: + def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int, int]: """Generate a random fraction and its simplified form. Returns (numerator, denominator, simplified_num, simplified_den)""" # Try to generate valid fractions until we get one that meets our criteria @@ -69,7 +70,7 @@ class FractionSimplificationDataset(ProceduralDataset): factor = rng.randint(self.config.min_factor, self.config.max_factor) numerator = simplified_num * factor denominator = simplified_den * factor - return numerator, denominator, simplified_num, simplified_den + return numerator, denominator, simplified_num, simplified_den, factor # If we failed to find a good fraction after max attempts, # generate one that's guaranteed to be within bounds @@ -81,7 +82,7 @@ class FractionSimplificationDataset(ProceduralDataset): simplified_num, simplified_den = simplified_den, simplified_num factor = rng.randint(self.config.min_factor, self.config.max_factor) - return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den) + return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den, factor) def _format_fraction(self, num: int, den: int, style: str = "plain") -> str: """Format a fraction in various styles""" @@ -100,7 +101,7 @@ class FractionSimplificationDataset(ProceduralDataset): """Generate a single fraction simplification task""" rng = Random(self.seed + idx) - num, den, simple_num, simple_den = self._generate_fraction(rng) + num, den, simple_num, simple_den, factor = self._generate_fraction(rng) # Choose a random style from configured styles style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)] @@ -119,6 +120,10 @@ class FractionSimplificationDataset(ProceduralDataset): "simplified_denominator": simple_den, "reduction_factor": num // simple_num, # Will be same as den // simple_den "style": style, + "difficulty": { + "factor": factor, + "value": (simple_num, simple_den), + }, }, } @@ -152,4 +157,38 @@ class FractionSimplificationDataset(ProceduralDataset): return reward -register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig) +class FractionSimplificationCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(FractionSimplificationCurriculum.__name__, FractionSimplificationConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="value", + levels=[1, 100, 1000, 10000], + default_level=1, + description="Value range for numerator and denominator", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_value", + upper_field_name="max_value", + ), + RangeAttributeDefinition( + name="factor", + levels=[1, 10, 100, 1000], + default_level=1, + description="Factor range for generating unsimplified fractions", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_factor", + upper_field_name="max_factor", + ), + ) + + +register_dataset( + "fraction_simplification", + FractionSimplificationDataset, + FractionSimplificationConfig, + FractionSimplificationCurriculum, +) diff --git a/tests/test_fraction_simplification.py b/tests/test_fraction_simplification.py index 4b399e8f..49437d1f 100644 --- a/tests/test_fraction_simplification.py +++ b/tests/test_fraction_simplification.py @@ -2,7 +2,11 @@ from math import gcd import pytest -from reasoning_gym.arithmetic import FractionSimplificationConfig, FractionSimplificationDataset +from reasoning_gym.arithmetic import ( + FractionSimplificationConfig, + FractionSimplificationCurriculum, + FractionSimplificationDataset, +) def test_fraction_config_validation(): @@ -129,3 +133,28 @@ def test_fraction_numerator_smaller(): assert ( metadata["simplified_numerator"] <= metadata["simplified_denominator"] ), f"Simplified numerator {metadata['simplified_numerator']} should be <= denominator {metadata['simplified_denominator']}" + + +def test_fraction_simplification_curriculum(): + curriculum = FractionSimplificationCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: FractionSimplificationConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_value == 1 and base_cfg.max_value == 100 + assert base_cfg.min_factor == 1 and base_cfg.max_factor == 10 + + # test incrementing attribute levels + curriculum.increment_attr_level("value") + curriculum.increment_attr_level("factor") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_value == 1 and increased_cfg.max_value == 1000 + assert increased_cfg.min_factor == 1 and increased_cfg.max_factor == 100 + + # test decrementing attribute level for value again + curriculum.decrement_attr_level("value") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_value == 1 and partially_decreased_cfg.max_value == 100 + assert partially_decreased_cfg.min_factor == 1 and partially_decreased_cfg.max_factor == 100