Add curriculum to ab dataset (#345)

* Add curriculum to ab dataset

* Add difficulty to metadata
This commit is contained in:
Adefioye 2025-03-13 15:03:02 -05:00 committed by GitHub
parent 4f45c8d655
commit 454250a4ea
3 changed files with 66 additions and 4 deletions

View file

@ -6,7 +6,7 @@ Algorithmic tasks for training reasoning capabilities:
- Pattern matching - Pattern matching
""" """
from .ab import ABConfig, ABDataset from .ab import ABConfig, ABCurriculum, ABDataset
from .base_conversion import BaseConversionConfig, BaseConversionCurriculum, BaseConversionDataset from .base_conversion import BaseConversionConfig, BaseConversionCurriculum, BaseConversionDataset
from .binary_alternation import BinaryAlternationConfig, BinaryAlternationCurriculum, BinaryAlternationDataset from .binary_alternation import BinaryAlternationConfig, BinaryAlternationCurriculum, BinaryAlternationDataset
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixCurriculum, BinaryMatrixDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixCurriculum, BinaryMatrixDataset
@ -121,6 +121,7 @@ __all__ = [
"PoolMatrixCurriculum", "PoolMatrixCurriculum",
"ABConfig", "ABConfig",
"ABDataset", "ABDataset",
"ABCurriculum",
"CountPrimesConfig", "CountPrimesConfig",
"CountPrimesDataset", "CountPrimesDataset",
"CountPrimesCurriculum", "CountPrimesCurriculum",

View file

@ -2,6 +2,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -114,7 +115,11 @@ Return the final state of the program.
return { return {
"question": prompt, "question": prompt,
"answer": " ".join(steps[-1]), "answer": " ".join(steps[-1]),
"metadata": {}, "metadata": {
"difficulty": {
"length": self.config.length,
}
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -135,5 +140,25 @@ Return the final state of the program.
return 0.0 return 0.0
class ABCurriculum(BaseCurriculum):
"""Curriculum for A::B dataset"""
def __init__(self):
super().__init__(ABCurriculum.__name__, ABConfig)
# Define attributes
self._define_attributes(
ScalarAttributeDefinition(
name="length",
field_name="length",
levels=[1, 10, 50, 100],
default_level=0,
description="Length of the A::B program",
attr_type=AttributeType.STATIC,
min_value=1,
)
)
# Register the dataset # Register the dataset
register_dataset("ab", ABDataset, ABConfig) register_dataset("ab", ABDataset, ABConfig, ABCurriculum)

View file

@ -2,7 +2,7 @@ import random
import pytest import pytest
from reasoning_gym.algorithmic.ab import ABConfig, ABDataset, compute_steps, generate_program from reasoning_gym.algorithmic.ab import ABConfig, ABCurriculum, ABDataset, compute_steps, generate_program
def test_ab_config_validation(): def test_ab_config_validation():
@ -98,3 +98,39 @@ def test_ab_item_structure():
# Test answer format # Test answer format
answer_tokens = item["answer"].split() answer_tokens = item["answer"].split()
assert all(token in ["A#", "#A", "B#", "#B"] for token in answer_tokens) assert all(token in ["A#", "#A", "B#", "#B"] for token in answer_tokens)
def test_ab_curriculum():
"""Test the curriculum ab dataset."""
curriculum = ABCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.length == 1
# Test and validate increase in levels
curriculum.increment_attr_level("length")
increase_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert increase_cfg.length == 10
# Test and validate decrease in levels
curriculum.decrement_attr_level("length")
decrease_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert decrease_cfg.length == 1
# Test upper bound boundary condition
for _ in range(10):
curriculum.increment_attr_level("length")
upper_bound_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert upper_bound_cfg.length == 100
# Test lower bound boundary condition
for _ in range(10):
curriculum.decrement_attr_level("length")
lower_bound_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert lower_bound_cfg.length == 1