Refactor Curriculum Attributes (#335)

* remove min_value from AttributeDefinition
* remove type from AttributeDefinition
* Add CurriculumContext
* add ensure_interval option for RangeAttributes
* docs: Add legend explaining curriculum indicators in dataset gallery
* update GALLERY.md
This commit is contained in:
Andreas Köpf 2025-03-16 15:40:28 +01:00 committed by GitHub
parent 4e7d9296ee
commit d2c895f1d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
101 changed files with 286 additions and 677 deletions

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Literal, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -240,20 +240,14 @@ class BasicArithmeticCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="num_terms",
levels=[2, 5, 10, 20],
default_level=0,
description="Number of terms in the expression",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_terms",
upper_field_name="max_terms",
),
RangeAttributeDefinition(
name="num_digits",
levels=[1, 2, 5, 10],
default_level=0,
description="Number of digits in the numbers",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_digits",
upper_field_name="max_digits",
),

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -186,10 +186,7 @@ class BitwiseArithmeticCurriculum(BaseCurriculum):
ScalarAttributeDefinition(
name="difficulty",
levels=[1, 2, 3, 4],
default_level=0,
description="Range of difficulty levels",
attr_type=AttributeType.STATIC,
min_value=1,
field_name="difficulty",
),
)

View file

@ -6,7 +6,7 @@ from datetime import date, timedelta
from enum import Enum, StrEnum, auto
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -511,17 +511,13 @@ class CalendarArithmeticCurriculum(BaseCurriculum):
"recurring_event_day",
],
],
default_level=0,
description="Controls which calendar tasks are included",
attr_type=AttributeType.STATIC,
field_name="tasks",
),
ScalarAttributeDefinition(
name="date_range",
levels=[30, 100, 250, 365],
default_level=0,
description="Maximum day range for offset and counting tasks",
attr_type=AttributeType.STATIC,
field_name="offset_upper_bound",
),
)

View file

@ -4,7 +4,7 @@ from typing import Any, Optional
from reasoning_gym import utils
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -126,8 +126,6 @@ class ChainSumCurriculum(BaseCurriculum):
levels=list(range(2, 13)),
default_level=0, # Start with 2 terms
description="Maximum number of terms in the expression",
attr_type=AttributeType.APPEND,
min_value=2, # Ensure at least 2 terms
lower_field_name="min_terms",
upper_field_name="max_terms",
),
@ -136,8 +134,6 @@ class ChainSumCurriculum(BaseCurriculum):
levels=list(range(1, 11)),
default_level=0, # Start with 1-digit numbers
description="Number of digits in each operand",
attr_type=AttributeType.APPEND,
min_value=1, # Ensure numbers are at least 1 digit
lower_field_name="min_digits",
upper_field_name="max_digits",
),

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?"""
@ -60,10 +60,7 @@ class CountBitsCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="n",
levels=[1_000, 1_000_000, 100_000_000, 2**31 - 1],
default_level=0,
description="Number to count bits in",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_n",
upper_field_name="max_n",
),

View file

@ -4,7 +4,7 @@ from decimal import ROUND_HALF_UP, Decimal, getcontext
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -232,20 +232,14 @@ class DecimalArithmeticCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="decimal_places",
levels=[3, 5, 8, 10],
default_level=0,
description="Number of decimal places of the numbers in problem",
attr_type=AttributeType.APPEND,
min_value=3,
lower_field_name="min_num_decimal_places",
upper_field_name="max_num_decimal_places",
),
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 6],
default_level=0,
description="Number of terms in the arithmetic expression",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_terms",
upper_field_name="max_terms",
),

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -170,10 +170,7 @@ class DecimalChainSumCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 5],
default_level=0,
description="Maximum number of terms in the expression",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_terms",
upper_field_name="max_terms",
),
@ -182,18 +179,13 @@ class DecimalChainSumCurriculum(BaseCurriculum):
levels=[1, 2, 4, 10],
default_level=0, # Start with 1-digit numbers
description="Number of digits in each operand",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_digits",
upper_field_name="max_digits",
),
RangeAttributeDefinition(
name="decimal_places",
levels=[1, 2, 3, 4],
default_level=0,
description="Number of decimal places in each operand",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_decimal_places",
upper_field_name="max_decimal_places",
),

View file

@ -4,7 +4,7 @@ from math import gcd
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -162,19 +162,13 @@ class DiceCurriculum(BaseCurriculum):
ScalarAttributeDefinition(
name="num_dice",
levels=[4, 5, 6, 7],
default_level=0,
description="Number of dice to roll",
attr_type=AttributeType.STATIC,
min_value=4,
field_name="num_dice",
),
ScalarAttributeDefinition(
name="max_dice_size",
levels=[20, 25, 30, 35],
default_level=0,
description="Maximum number of sides on any die",
attr_type=AttributeType.STATIC,
min_value=20,
field_name="max_dice_size",
),
)

View file

@ -6,7 +6,7 @@ from math import gcd
from random import Random
from typing import Any, Optional, Sequence
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
@ -166,22 +166,18 @@ class FractionSimplificationCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="value",
levels=[1, 100, 1000, 10000],
default_level=1,
description="Value range for numerator and denominator",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_value",
upper_field_name="max_value",
ensure_interval=True,
),
RangeAttributeDefinition(
name="factor",
levels=[1, 10, 100, 1000],
default_level=1,
description="Factor range for generating unsimplified fractions",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_factor",
upper_field_name="max_factor",
ensure_interval=True,
),
)

View file

@ -6,7 +6,7 @@ from math import gcd
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -81,20 +81,14 @@ class GCDCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 5],
default_level=0,
description="number of terms",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_numbers",
upper_field_name="max_numbers",
),
RangeAttributeDefinition(
name="max_value",
levels=[100, 1000, 10000, 100000],
default_level=0,
description="maximum value",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_value",
upper_field_name="max_value",
),

View file

@ -6,7 +6,7 @@ from math import lcm
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -83,22 +83,17 @@ class LCMCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="numbers",
levels=[2, 4, 6, 8, 10],
default_level=0,
description="Number of integers to find LCM of",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_numbers",
upper_field_name="max_numbers",
),
RangeAttributeDefinition(
name="value",
levels=[1, 100, 500, 1000, 5000],
default_level=1,
description="Range of values for each integer",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_value",
upper_field_name="max_value",
ensure_interval=True,
),
)

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
ANIMALS = {
@ -136,20 +136,14 @@ class LegCountingCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="num_animals",
levels=list(range(1, 20)),
default_level=0,
description="Number of animals in question",
attr_type=AttributeType.APPEND,
min_value=1, # Ensure at least 1 animal
lower_field_name="min_animals",
upper_field_name="max_animals",
),
RangeAttributeDefinition(
name="num_instances",
levels=[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024],
default_level=0,
description="Number of instances of each animal",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_instances",
upper_field_name="max_instances",
),

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your task is to pick the largest/smallest number out of several options.
@ -115,31 +115,24 @@ class NumberFormatCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="num_candidates",
levels=[5, 25, 100, 500],
default_level=1,
description="Number of candidates",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_num_candidates",
upper_field_name="max_num_candidates",
ensure_interval=True,
),
RangeAttributeDefinition(
name="n",
levels=[10, 1_000, 1_000_000, 1_000_000_000],
default_level=1,
description="Magnitude of the values",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_n",
upper_field_name="max_n",
ensure_interval=True,
),
ScalarAttributeDefinition(
name="max_delta",
field_name="max_delta",
levels=[1e1, 1e0, 1e-3, 1e-6],
default_level=0,
description="Max delta",
attr_type=AttributeType.STATIC,
min_value=1e-6,
),
)

View file

@ -6,7 +6,7 @@ from math import pow
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
@ -84,9 +84,6 @@ class PowerFunctionCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="exponent",
levels=[2, 4, 6, 10],
default_level=0,
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_exponent",
upper_field_name="max_exponent",
),

View file

@ -5,7 +5,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -96,12 +96,10 @@ class PrimeFactorizationCurriculum(BaseCurriculum):
RangeAttributeDefinition(
name="value",
levels=[10, 1_000, 10_000, 50_000],
default_level=1,
description="Number to factorize",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_value",
upper_field_name="max_value",
ensure_interval=True,
)
)

View file

@ -4,7 +4,7 @@ from typing import Any, Optional
from reasoning_gym import utils
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -118,8 +118,6 @@ class ProductsCurriculum(BaseCurriculum):
levels=list(range(2, 13)),
default_level=0, # Start with 2 terms
description="Maximum number of terms in the expression",
attr_type=AttributeType.APPEND,
min_value=2, # Ensure at least 2 terms
lower_field_name="min_terms",
upper_field_name="max_terms",
),
@ -128,8 +126,6 @@ class ProductsCurriculum(BaseCurriculum):
levels=list(range(1, 11)),
default_level=0, # Start with 1-digit numbers
description="Number of digits in each operand",
attr_type=AttributeType.APPEND,
min_value=1, # Ensure numbers are at least 1 digit
lower_field_name="min_digits",
upper_field_name="max_digits",
),

View file

@ -6,7 +6,7 @@ from typing import Optional
import pytz
from dateutil import parser
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -334,19 +334,13 @@ class TimeIntervalsCurriculum(BaseCurriculum):
name="max_time_difference_seconds",
field_name="max_time_difference_seconds",
levels=[60, 24 * 60 * 60, 7 * 24 * 60 * 60, 30 * 24 * 60 * 60, 365 * 24 * 60 * 60],
default_level=0,
description="Maximum time difference in seconds",
attr_type=AttributeType.STATIC,
min_value=1,
),
ScalarAttributeDefinition(
name="max_date_difference_days",
field_name="max_date_difference_days",
levels=[1, 7, 30, 365, 5 * 365],
default_level=0,
description="Maximum date difference in days",
attr_type=AttributeType.STATIC,
min_value=1,
),
)