From 454250a4eae1fd22b2928cef0559d7ee6d8bd478 Mon Sep 17 00:00:00 2001
From: Adefioye <47661641+Adefioye@users.noreply.github.com>
Date: Thu, 13 Mar 2025 15:03:02 -0500
Subject: [PATCH] Add curriculum to ab dataset (#345)

* Add curriculum to ab dataset

* Add difficulty to metadata
---
 reasoning_gym/algorithmic/__init__.py |  3 ++-
 reasoning_gym/algorithmic/ab.py       | 29 ++++++++++++++++++--
 tests/test_ab.py                      | 38 ++++++++++++++++++++++++++-
 3 files changed, 66 insertions(+), 4 deletions(-)

diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py
index 190bbf32..6620310a 100644
--- a/reasoning_gym/algorithmic/__init__.py
+++ b/reasoning_gym/algorithmic/__init__.py
@@ -6,7 +6,7 @@ Algorithmic tasks for training reasoning capabilities:
 - Pattern matching
 """
 
-from .ab import ABConfig, ABDataset
+from .ab import ABConfig, ABCurriculum, ABDataset
 from .base_conversion import BaseConversionConfig, BaseConversionCurriculum, BaseConversionDataset
 from .binary_alternation import BinaryAlternationConfig, BinaryAlternationCurriculum, BinaryAlternationDataset
 from .binary_matrix import BinaryMatrixConfig, BinaryMatrixCurriculum, BinaryMatrixDataset
@@ -121,6 +121,7 @@ __all__ = [
     "PoolMatrixCurriculum",
     "ABConfig",
     "ABDataset",
+    "ABCurriculum",
     "CountPrimesConfig",
     "CountPrimesDataset",
     "CountPrimesCurriculum",
diff --git a/reasoning_gym/algorithmic/ab.py b/reasoning_gym/algorithmic/ab.py
index 3e251d31..ff488a0e 100644
--- a/reasoning_gym/algorithmic/ab.py
+++ b/reasoning_gym/algorithmic/ab.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from random import Random
 from typing import Any, Optional
 
+from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
 from ..factory import ProceduralDataset, register_dataset
 
 
@@ -114,7 +115,11 @@ Return the final state of the program.
         return {
             "question": prompt,
             "answer": " ".join(steps[-1]),
-            "metadata": {},
+            "metadata": {
+                "difficulty": {
+                    "length": self.config.length,
+                }
+            },
         }
 
     def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@@ -135,5 +140,25 @@ Return the final state of the program.
         return 0.0
 
 
+class ABCurriculum(BaseCurriculum):
+    """Curriculum for the A::B dataset: steps the program length through levels 1, 10, 50 and 100."""
+
+    def __init__(self):
+        super().__init__(ABCurriculum.__name__, ABConfig)
+
+        # Single difficulty attribute: program length, starting at the lowest level (index 0)
+        self._define_attributes(
+            ScalarAttributeDefinition(
+                name="length",
+                field_name="length",
+                levels=[1, 10, 50, 100],
+                default_level=0,
+                description="Length of the A::B program",
+                attr_type=AttributeType.STATIC,
+                min_value=1,
+            )
+        )
+
+
 # Register the dataset
-register_dataset("ab", ABDataset, ABConfig)
+register_dataset("ab", ABDataset, ABConfig, ABCurriculum)
diff --git a/tests/test_ab.py b/tests/test_ab.py
index e63a07bc..aa3bb81f 100644
--- a/tests/test_ab.py
+++ b/tests/test_ab.py
@@ -2,7 +2,7 @@ import random
 
 import pytest
 
-from reasoning_gym.algorithmic.ab import ABConfig, ABDataset, compute_steps, generate_program
+from reasoning_gym.algorithmic.ab import ABConfig, ABCurriculum, ABDataset, compute_steps, generate_program
 
 
 def test_ab_config_validation():
@@ -98,3 +98,39 @@ def test_ab_item_structure():
     # Test answer format
     answer_tokens = item["answer"].split()
     assert all(token in ["A#", "#A", "B#", "#B"] for token in answer_tokens)
+
+
+def test_ab_curriculum():
+    """Check that ABCurriculum maps "length" levels onto ABConfig.length and stops at both ends."""
+    curriculum = ABCurriculum()
+    base_value = {"size": 150, "seed": 1}
+
+    base_cfg: ABConfig = curriculum.generate_configuration(base_value)
+
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    assert base_cfg.length == 1
+
+    # Test and validate increase in levels
+    curriculum.increment_attr_level("length")
+    increase_cfg: ABConfig = curriculum.generate_configuration(base_value)
+
+    assert increase_cfg.length == 10
+
+    # Test and validate decrease in levels
+    curriculum.decrement_attr_level("length")
+    decrease_cfg: ABConfig = curriculum.generate_configuration(base_value)
+
+    assert decrease_cfg.length == 1
+
+    # Test upper bound boundary condition
+    for _ in range(10):
+        curriculum.increment_attr_level("length")
+    upper_bound_cfg: ABConfig = curriculum.generate_configuration(base_value)
+    assert upper_bound_cfg.length == 100
+
+    # Test lower bound boundary condition
+    for _ in range(10):
+        curriculum.decrement_attr_level("length")
+    lower_bound_cfg: ABConfig = curriculum.generate_configuration(base_value)
+    assert lower_bound_cfg.length == 1