Add curriculum to ab dataset (#345)

* Add curriculum to ab dataset

* Add difficulty to metadata
This commit is contained in:
Adefioye 2025-03-13 15:03:02 -05:00 committed by GitHub
parent 596de511f0
commit 4ec1154b47
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 66 additions and 4 deletions

View file

@ -6,7 +6,7 @@ Algorithmic tasks for training reasoning capabilities:
- Pattern matching
"""
from .ab import ABConfig, ABDataset
from .ab import ABConfig, ABCurriculum, ABDataset
from .base_conversion import BaseConversionConfig, BaseConversionCurriculum, BaseConversionDataset
from .binary_alternation import BinaryAlternationConfig, BinaryAlternationCurriculum, BinaryAlternationDataset
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixCurriculum, BinaryMatrixDataset
@ -121,6 +121,7 @@ __all__ = [
"PoolMatrixCurriculum",
"ABConfig",
"ABDataset",
"ABCurriculum",
"CountPrimesConfig",
"CountPrimesDataset",
"CountPrimesCurriculum",

View file

@ -2,6 +2,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -114,7 +115,11 @@ Return the final state of the program.
return {
"question": prompt,
"answer": " ".join(steps[-1]),
"metadata": {},
"metadata": {
"difficulty": {
"length": self.config.length,
}
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -135,5 +140,25 @@ Return the final state of the program.
return 0.0
class ABCurriculum(BaseCurriculum):
"""Curriculum for A::B dataset"""
def __init__(self):
super().__init__(ABCurriculum.__name__, ABConfig)
# Define attributes
self._define_attributes(
ScalarAttributeDefinition(
name="length",
field_name="length",
levels=[1, 10, 50, 100],
default_level=0,
description="Length of the A::B program",
attr_type=AttributeType.STATIC,
min_value=1,
)
)
# Register the dataset
register_dataset("ab", ABDataset, ABConfig)
register_dataset("ab", ABDataset, ABConfig, ABCurriculum)

View file

@ -2,7 +2,7 @@ import random
import pytest
from reasoning_gym.algorithmic.ab import ABConfig, ABDataset, compute_steps, generate_program
from reasoning_gym.algorithmic.ab import ABConfig, ABCurriculum, ABDataset, compute_steps, generate_program
def test_ab_config_validation():
@ -98,3 +98,39 @@ def test_ab_item_structure():
# Test answer format
answer_tokens = item["answer"].split()
assert all(token in ["A#", "#A", "B#", "#B"] for token in answer_tokens)
def test_ab_curriculum():
"""Test the curriculum ab dataset."""
curriculum = ABCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.length == 1
# Test and validate increase in levels
curriculum.increment_attr_level("length")
increase_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert increase_cfg.length == 10
# Test and validate decrease in levels
curriculum.decrement_attr_level("length")
decrease_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert decrease_cfg.length == 1
# Test upper bound boundary condition
for _ in range(10):
curriculum.increment_attr_level("length")
upper_bound_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert upper_bound_cfg.length == 100
# Test lower bound boundary condition
for _ in range(10):
curriculum.decrement_attr_level("length")
lower_bound_cfg: ABCurriculum = curriculum.generate_configuration(base_value)
assert lower_bound_cfg.length == 1