mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
Refactor Curriculum Attributes (#335)
* remove min_value from AttributeDefinition * remove type from AttributeDefinition * Add CurriculumContext * add ensure_interval option for RangeAttributes * docs: Add legend explaining curriculum indicators in dataset gallery * update GALLERY.md
This commit is contained in:
parent
4e7d9296ee
commit
d2c895f1d3
101 changed files with 286 additions and 677 deletions
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -240,20 +240,14 @@ class BasicArithmeticCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 5, 10, 20],
|
||||
default_level=0,
|
||||
description="Number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_terms",
|
||||
upper_field_name="max_terms",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="num_digits",
|
||||
levels=[1, 2, 5, 10],
|
||||
default_level=0,
|
||||
description="Number of digits in the numbers",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_digits",
|
||||
upper_field_name="max_digits",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -186,10 +186,7 @@ class BitwiseArithmeticCurriculum(BaseCurriculum):
|
|||
ScalarAttributeDefinition(
|
||||
name="difficulty",
|
||||
levels=[1, 2, 3, 4],
|
||||
default_level=0,
|
||||
description="Range of difficulty levels",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=1,
|
||||
field_name="difficulty",
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from datetime import date, timedelta
|
|||
from enum import Enum, StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -511,17 +511,13 @@ class CalendarArithmeticCurriculum(BaseCurriculum):
|
|||
"recurring_event_day",
|
||||
],
|
||||
],
|
||||
default_level=0,
|
||||
description="Controls which calendar tasks are included",
|
||||
attr_type=AttributeType.STATIC,
|
||||
field_name="tasks",
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="date_range",
|
||||
levels=[30, 100, 250, 365],
|
||||
default_level=0,
|
||||
description="Maximum day range for offset and counting tasks",
|
||||
attr_type=AttributeType.STATIC,
|
||||
field_name="offset_upper_bound",
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any, Optional
|
|||
|
||||
from reasoning_gym import utils
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -126,8 +126,6 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
levels=list(range(2, 13)),
|
||||
default_level=0, # Start with 2 terms
|
||||
description="Maximum number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2, # Ensure at least 2 terms
|
||||
lower_field_name="min_terms",
|
||||
upper_field_name="max_terms",
|
||||
),
|
||||
|
|
@ -136,8 +134,6 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
levels=list(range(1, 11)),
|
||||
default_level=0, # Start with 1-digit numbers
|
||||
description="Number of digits in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1, # Ensure numbers are at least 1 digit
|
||||
lower_field_name="min_digits",
|
||||
upper_field_name="max_digits",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?"""
|
||||
|
|
@ -60,10 +60,7 @@ class CountBitsCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="n",
|
||||
levels=[1_000, 1_000_000, 100_000_000, 2**31 - 1],
|
||||
default_level=0,
|
||||
description="Number to count bits in",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_n",
|
||||
upper_field_name="max_n",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from decimal import ROUND_HALF_UP, Decimal, getcontext
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -232,20 +232,14 @@ class DecimalArithmeticCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="decimal_places",
|
||||
levels=[3, 5, 8, 10],
|
||||
default_level=0,
|
||||
description="Number of decimal places of the numbers in problem",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=3,
|
||||
lower_field_name="min_num_decimal_places",
|
||||
upper_field_name="max_num_decimal_places",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 6],
|
||||
default_level=0,
|
||||
description="Number of terms in the arithmetic expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_terms",
|
||||
upper_field_name="max_terms",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -170,10 +170,7 @@ class DecimalChainSumCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 5],
|
||||
default_level=0,
|
||||
description="Maximum number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_terms",
|
||||
upper_field_name="max_terms",
|
||||
),
|
||||
|
|
@ -182,18 +179,13 @@ class DecimalChainSumCurriculum(BaseCurriculum):
|
|||
levels=[1, 2, 4, 10],
|
||||
default_level=0, # Start with 1-digit numbers
|
||||
description="Number of digits in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_digits",
|
||||
upper_field_name="max_digits",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="decimal_places",
|
||||
levels=[1, 2, 3, 4],
|
||||
default_level=0,
|
||||
description="Number of decimal places in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_decimal_places",
|
||||
upper_field_name="max_decimal_places",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -162,19 +162,13 @@ class DiceCurriculum(BaseCurriculum):
|
|||
ScalarAttributeDefinition(
|
||||
name="num_dice",
|
||||
levels=[4, 5, 6, 7],
|
||||
default_level=0,
|
||||
description="Number of dice to roll",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=4,
|
||||
field_name="num_dice",
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="max_dice_size",
|
||||
levels=[20, 25, 30, 35],
|
||||
default_level=0,
|
||||
description="Maximum number of sides on any die",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=20,
|
||||
field_name="max_dice_size",
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
|
||||
|
|
@ -166,22 +166,18 @@ class FractionSimplificationCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="value",
|
||||
levels=[1, 100, 1000, 10000],
|
||||
default_level=1,
|
||||
description="Value range for numerator and denominator",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_value",
|
||||
upper_field_name="max_value",
|
||||
ensure_interval=True,
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="factor",
|
||||
levels=[1, 10, 100, 1000],
|
||||
default_level=1,
|
||||
description="Factor range for generating unsimplified fractions",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_factor",
|
||||
upper_field_name="max_factor",
|
||||
ensure_interval=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -81,20 +81,14 @@ class GCDCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 5],
|
||||
default_level=0,
|
||||
description="number of terms",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_numbers",
|
||||
upper_field_name="max_numbers",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="max_value",
|
||||
levels=[100, 1000, 10000, 100000],
|
||||
default_level=0,
|
||||
description="maximum value",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_value",
|
||||
upper_field_name="max_value",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from math import lcm
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -83,22 +83,17 @@ class LCMCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="numbers",
|
||||
levels=[2, 4, 6, 8, 10],
|
||||
default_level=0,
|
||||
description="Number of integers to find LCM of",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_numbers",
|
||||
upper_field_name="max_numbers",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="value",
|
||||
levels=[1, 100, 500, 1000, 5000],
|
||||
default_level=1,
|
||||
description="Range of values for each integer",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_value",
|
||||
upper_field_name="max_value",
|
||||
ensure_interval=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
ANIMALS = {
|
||||
|
|
@ -136,20 +136,14 @@ class LegCountingCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="num_animals",
|
||||
levels=list(range(1, 20)),
|
||||
default_level=0,
|
||||
description="Number of animals in question",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1, # Ensure at least 1 animal
|
||||
lower_field_name="min_animals",
|
||||
upper_field_name="max_animals",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="num_instances",
|
||||
levels=[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024],
|
||||
default_level=0,
|
||||
description="Number of instances of each animal",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_instances",
|
||||
upper_field_name="max_instances",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Your task is to pick the largest/smallest number out of several options.
|
||||
|
|
@ -115,31 +115,24 @@ class NumberFormatCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="num_candidates",
|
||||
levels=[5, 25, 100, 500],
|
||||
default_level=1,
|
||||
description="Number of candidates",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_num_candidates",
|
||||
upper_field_name="max_num_candidates",
|
||||
ensure_interval=True,
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="n",
|
||||
levels=[10, 1_000, 1_000_000, 1_000_000_000],
|
||||
default_level=1,
|
||||
description="Magnitude of the values",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_n",
|
||||
upper_field_name="max_n",
|
||||
ensure_interval=True,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="max_delta",
|
||||
field_name="max_delta",
|
||||
levels=[1e1, 1e0, 1e-3, 1e-6],
|
||||
default_level=0,
|
||||
description="Max delta",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=1e-6,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from math import pow
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
|
||||
|
|
@ -84,9 +84,6 @@ class PowerFunctionCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="exponent",
|
||||
levels=[2, 4, 6, 10],
|
||||
default_level=0,
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_exponent",
|
||||
upper_field_name="max_exponent",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -96,12 +96,10 @@ class PrimeFactorizationCurriculum(BaseCurriculum):
|
|||
RangeAttributeDefinition(
|
||||
name="value",
|
||||
levels=[10, 1_000, 10_000, 50_000],
|
||||
default_level=1,
|
||||
description="Number to factorize",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_value",
|
||||
upper_field_name="max_value",
|
||||
ensure_interval=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any, Optional
|
|||
|
||||
from reasoning_gym import utils
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -118,8 +118,6 @@ class ProductsCurriculum(BaseCurriculum):
|
|||
levels=list(range(2, 13)),
|
||||
default_level=0, # Start with 2 terms
|
||||
description="Maximum number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2, # Ensure at least 2 terms
|
||||
lower_field_name="min_terms",
|
||||
upper_field_name="max_terms",
|
||||
),
|
||||
|
|
@ -128,8 +126,6 @@ class ProductsCurriculum(BaseCurriculum):
|
|||
levels=list(range(1, 11)),
|
||||
default_level=0, # Start with 1-digit numbers
|
||||
description="Number of digits in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1, # Ensure numbers are at least 1 digit
|
||||
lower_field_name="min_digits",
|
||||
upper_field_name="max_digits",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||
import pytz
|
||||
from dateutil import parser
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -334,19 +334,13 @@ class TimeIntervalsCurriculum(BaseCurriculum):
|
|||
name="max_time_difference_seconds",
|
||||
field_name="max_time_difference_seconds",
|
||||
levels=[60, 24 * 60 * 60, 7 * 24 * 60 * 60, 30 * 24 * 60 * 60, 365 * 24 * 60 * 60],
|
||||
default_level=0,
|
||||
description="Maximum time difference in seconds",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=1,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="max_date_difference_days",
|
||||
field_name="max_date_difference_days",
|
||||
levels=[1, 7, 30, 365, 5 * 365],
|
||||
default_level=0,
|
||||
description="Maximum date difference in days",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue