mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
fraction simplification curriculum (#349)
This commit is contained in:
parent
db3868150f
commit
ee001f38a4
3 changed files with 80 additions and 7 deletions
|
|
@ -6,6 +6,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
|
||||
|
|
@ -42,7 +43,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
def __init__(self, config: FractionSimplificationConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int]:
|
||||
def _generate_fraction(self, rng: Random) -> tuple[int, int, int, int, int]:
|
||||
"""Generate a random fraction and its simplified form.
|
||||
Returns (numerator, denominator, simplified_num, simplified_den)"""
|
||||
# Try to generate valid fractions until we get one that meets our criteria
|
||||
|
|
@ -69,7 +70,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||
numerator = simplified_num * factor
|
||||
denominator = simplified_den * factor
|
||||
return numerator, denominator, simplified_num, simplified_den
|
||||
return numerator, denominator, simplified_num, simplified_den, factor
|
||||
|
||||
# If we failed to find a good fraction after max attempts,
|
||||
# generate one that's guaranteed to be within bounds
|
||||
|
|
@ -81,7 +82,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
simplified_num, simplified_den = simplified_den, simplified_num
|
||||
|
||||
factor = rng.randint(self.config.min_factor, self.config.max_factor)
|
||||
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den)
|
||||
return (simplified_num * factor, simplified_den * factor, simplified_num, simplified_den, factor)
|
||||
|
||||
def _format_fraction(self, num: int, den: int, style: str = "plain") -> str:
|
||||
"""Format a fraction in various styles"""
|
||||
|
|
@ -100,7 +101,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
"""Generate a single fraction simplification task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
num, den, simple_num, simple_den = self._generate_fraction(rng)
|
||||
num, den, simple_num, simple_den, factor = self._generate_fraction(rng)
|
||||
|
||||
# Choose a random style from configured styles
|
||||
style = self.config.styles[rng.randint(0, len(self.config.styles) - 1)]
|
||||
|
|
@ -119,6 +120,10 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
"simplified_denominator": simple_den,
|
||||
"reduction_factor": num // simple_num, # Will be same as den // simple_den
|
||||
"style": style,
|
||||
"difficulty": {
|
||||
"factor": factor,
|
||||
"value": (simple_num, simple_den),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -152,4 +157,38 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
return reward
|
||||
|
||||
|
||||
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
||||
class FractionSimplificationCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(FractionSimplificationCurriculum.__name__, FractionSimplificationConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="value",
|
||||
levels=[1, 100, 1000, 10000],
|
||||
default_level=1,
|
||||
description="Value range for numerator and denominator",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_value",
|
||||
upper_field_name="max_value",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="factor",
|
||||
levels=[1, 10, 100, 1000],
|
||||
default_level=1,
|
||||
description="Factor range for generating unsimplified fractions",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_factor",
|
||||
upper_field_name="max_factor",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"fraction_simplification",
|
||||
FractionSimplificationDataset,
|
||||
FractionSimplificationConfig,
|
||||
FractionSimplificationCurriculum,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue