mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
gcd curriculum (#331)
This commit is contained in:
parent
126eecc798
commit
c3c6cc8051
3 changed files with 78 additions and 3 deletions
|
|
@ -11,7 +11,7 @@ from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurric
|
|||
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
|
||||
from .dice import DiceConfig, DiceCurriculum, DiceDataset
|
||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||
from .gcd import GCDConfig, GCDDataset
|
||||
from .gcd import GCDConfig, GCDCurriculum, GCDDataset
|
||||
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
||||
from .lcm import LCMConfig, LCMDataset
|
||||
from .leg_counting import LegCountingConfig, LegCountingCurriculum, LegCountingDataset
|
||||
|
|
@ -34,6 +34,7 @@ __all__ = [
|
|||
"FractionSimplificationDataset",
|
||||
"GCDConfig",
|
||||
"GCDDataset",
|
||||
"GCDCurriculum",
|
||||
"LCMConfig",
|
||||
"LCMDataset",
|
||||
"LegCountingConfig",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -54,13 +55,50 @@ class GCDDataset(ProceduralDataset):
|
|||
rng = Random(self.seed + idx)
|
||||
|
||||
numbers, result = self._generate_numbers(rng)
|
||||
num_terms = len(numbers)
|
||||
numbers_str = ", ".join(str(n) for n in numbers)
|
||||
|
||||
return {
|
||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.",
|
||||
"answer": str(result),
|
||||
"metadata": {"numbers": numbers, "result": result},
|
||||
"metadata": {
|
||||
"numbers": numbers,
|
||||
"result": result,
|
||||
"difficulty": {
|
||||
"num_terms": num_terms,
|
||||
"max_value": self.config.max_value,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class GCDCurriculum(BaseCurriculum):
    """Curriculum that produces GCDConfig objects from two levelled attributes.

    ``num_terms`` is tied to the ``min_numbers``/``max_numbers`` config fields
    and ``max_value`` to the ``min_value``/``max_value`` config fields; both
    attributes start at level 0.
    """

    def __init__(self):
        super().__init__(GCDCurriculum.__name__, GCDConfig)

        # Number of integers per problem, levels 2 through 5.
        term_count = RangeAttributeDefinition(
            name="num_terms",
            levels=[2, 3, 4, 5],
            default_level=0,
            description="number of terms",
            attr_type=AttributeType.APPEND,
            min_value=2,
            lower_field_name="min_numbers",
            upper_field_name="max_numbers",
        )

        # Size of the integers, growing by a factor of ten per level.
        value_bound = RangeAttributeDefinition(
            name="max_value",
            levels=[100, 1000, 10000, 100000],
            default_level=0,
            description="maximum value",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_value",
            upper_field_name="max_value",
        )

        # Registration order matches declaration order: num_terms, then max_value.
        self._define_attributes(term_count, value_bound)
|
||||
|
||||
|
||||
# Register the GCD dataset class together with its config class under the
# key "gcd" (presumably the name used to look the dataset up via the factory
# module -- confirm against ..factory).
register_dataset("gcd", GCDDataset, GCDConfig)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from math import gcd
|
|||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic import GCDConfig, GCDDataset
|
||||
from reasoning_gym.arithmetic import GCDConfig, GCDCurriculum, GCDDataset
|
||||
|
||||
|
||||
def test_gcd_config_validation():
|
||||
|
|
@ -115,3 +115,39 @@ def test_gcd_special_cases():
|
|||
# With enough samples, we should see both coprime and non-coprime numbers
|
||||
assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)"
|
||||
assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)"
|
||||
|
||||
|
||||
def test_gcd_curriculum():
    """Check the configs GCDCurriculum yields as both attribute levels move."""
    curriculum = GCDCurriculum()
    base_value = {"size": 150, "seed": 1}

    # Default levels: the base values pass through and both ranges collapse
    # to their first level.
    initial: GCDConfig = curriculum.generate_configuration(base_value)
    assert initial.seed == 1
    assert initial.size == 150
    assert initial.min_numbers == 2
    assert initial.max_numbers == 2
    assert initial.min_value == 100
    assert initial.max_value == 100

    # Raising both attributes together widens only the upper bounds; the
    # lower bounds stay at the first level.
    expected_upper_bounds = [(3, 1000), (4, 10000), (5, 100000)]
    for top_terms, top_value in expected_upper_bounds:
        curriculum.increment_attr_level("num_terms")
        curriculum.increment_attr_level("max_value")
        raised = curriculum.generate_configuration(base_value)
        assert raised.min_numbers == 2
        assert raised.max_numbers == top_terms
        assert raised.min_value == 100
        assert raised.max_value == top_value

    # One step back down returns to the second-highest upper bounds.
    curriculum.decrement_attr_level("num_terms")
    curriculum.decrement_attr_level("max_value")
    lowered = curriculum.generate_configuration(base_value)
    assert lowered.min_numbers == 2
    assert lowered.max_numbers == 4
    assert lowered.min_value == 100
    assert lowered.max_value == 10000
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue