gcd curriculum (#331)

This commit is contained in:
vncntt 2025-03-11 00:25:24 -07:00 committed by GitHub
parent 126eecc798
commit c3c6cc8051
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 78 additions and 3 deletions

View file

@ -11,7 +11,7 @@ from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurric
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset
from .dice import DiceConfig, DiceCurriculum, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset
from .gcd import GCDConfig, GCDCurriculum, GCDDataset
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
from .lcm import LCMConfig, LCMDataset
from .leg_counting import LegCountingConfig, LegCountingCurriculum, LegCountingDataset
@ -34,6 +34,7 @@ __all__ = [
"FractionSimplificationDataset",
"GCDConfig",
"GCDDataset",
"GCDCurriculum",
"LCMConfig",
"LCMDataset",
"LegCountingConfig",

View file

@ -6,6 +6,7 @@ from math import gcd
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -54,13 +55,50 @@ class GCDDataset(ProceduralDataset):
rng = Random(self.seed + idx)
numbers, result = self._generate_numbers(rng)
num_terms = len(numbers)
numbers_str = ", ".join(str(n) for n in numbers)
return {
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.",
"answer": str(result),
"metadata": {"numbers": numbers, "result": result},
"metadata": {
"numbers": numbers,
"result": result,
"difficulty": {
"num_terms": num_terms,
"max_value": self.config.max_value,
},
},
}
class GCDCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(GCDCurriculum.__name__, GCDConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 5],
default_level=0,
description="number of terms",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_numbers",
upper_field_name="max_numbers",
),
RangeAttributeDefinition(
name="max_value",
levels=[100, 1000, 10000, 100000],
default_level=0,
description="maximum value",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_value",
upper_field_name="max_value",
),
)
register_dataset("gcd", GCDDataset, GCDConfig)

View file

@ -3,7 +3,7 @@ from math import gcd
import pytest
from reasoning_gym.arithmetic import GCDConfig, GCDDataset
from reasoning_gym.arithmetic import GCDConfig, GCDCurriculum, GCDDataset
def test_gcd_config_validation():
@ -115,3 +115,39 @@ def test_gcd_special_cases():
# With enough samples, we should see both coprime and non-coprime numbers
assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)"
assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)"
def test_gcd_curriculum():
"""Test that curriculum generates correct items"""
curriculum = GCDCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: GCDConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_numbers == 2 and base_cfg.max_numbers == 2
assert base_cfg.min_value == 100 and base_cfg.max_value == 100
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("max_value")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 3
assert increased_cfg.min_value == 100 and increased_cfg.max_value == 1000
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("max_value")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 4
assert increased_cfg.min_value == 100 and increased_cfg.max_value == 10000
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("max_value")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 5
assert increased_cfg.min_value == 100 and increased_cfg.max_value == 100000
curriculum.decrement_attr_level("num_terms")
curriculum.decrement_attr_level("max_value")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.min_numbers == 2 and decreased_cfg.max_numbers == 4
assert decreased_cfg.min_value == 100 and decreased_cfg.max_value == 10000