diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index fed97fea..1816ddae 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -18,7 +18,7 @@ from .game_of_life_halting import GameOfLifeHaltingConfig, GameOfLifeHaltingData from .graph_color import GraphColorConfig, GraphColorCurriculum, GraphColorDataset from .group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupAnagramsDataset from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsCurriculum, IsomorphicStringsDataset -from .jugs import JugsConfig, JugsDataset +from .jugs import JugsConfig, JugsCurriculum, JugsDataset from .letter_counting import LetterCountingConfig, LetterCountingCurriculum, LetterCountingDataset from .letter_jumble import LetterJumbleConfig, LetterJumbleCurriculum, LetterJumbleDataset from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset @@ -148,6 +148,7 @@ __all__ = [ "RottenOrangesCurriculum", "JugsConfig", "JugsDataset", + "JugsCurriculum", "BinaryAlternationConfig", "BinaryAlternationDataset", "BinaryAlternationCurriculum", diff --git a/reasoning_gym/algorithmic/jugs.py b/reasoning_gym/algorithmic/jugs.py index 0dc7210c..c0077a23 100644 --- a/reasoning_gym/algorithmic/jugs.py +++ b/reasoning_gym/algorithmic/jugs.py @@ -6,6 +6,7 @@ from functools import reduce from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -280,7 +281,13 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil return { "question": question, "answer": json.dumps(solution), # one possible solution - "metadata": {"puzzle": puzzle}, + "metadata": { + "puzzle": puzzle, + "difficulty": { + "num_jugs": self.config.num_jugs, + "difficulty": self.config.difficulty, + }, + }, } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -310,4 +317,33 @@ Reply as a JSON-parsable list of moves which result in any of the jugs being fil return 0.0 -register_dataset("jugs", JugsDataset, JugsConfig) +class JugsCurriculum(BaseCurriculum): + """Curriculum for Jugs puzzles""" + + def __init__(self): + super().__init__(JugsCurriculum.__name__, JugsConfig) + + # Define attributes + self._define_attributes( + ScalarAttributeDefinition( + name="num_jugs", + field_name="num_jugs", + levels=[3, 4, 5, 7], + default_level=0, + description="Number of jugs in the puzzle", + attr_type=AttributeType.STATIC, + min_value=3, + ), + ScalarAttributeDefinition( + name="difficulty", + field_name="difficulty", + levels=[2, 4, 6, 8], + default_level=0, + description="Minimum required moves to solve the puzzle", + attr_type=AttributeType.STATIC, + min_value=10, + ), + ) + + +register_dataset("jugs", JugsDataset, JugsConfig, JugsCurriculum) diff --git a/tests/test_jugs.py b/tests/test_jugs.py index adc05511..384c3d3c 100644 --- a/tests/test_jugs.py +++ b/tests/test_jugs.py @@ -2,7 +2,7 @@ import json import pytest -from reasoning_gym.algorithmic.jugs import JugsConfig, JugsDataset +from reasoning_gym.algorithmic.jugs import JugsConfig, JugsCurriculum, JugsDataset def test_jugs(): @@ -44,7 +44,53 @@ def test_jugs(): assert "question" in item assert "answer" in item assert "metadata" in item + assert "difficulty" in item["metadata"] # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 + + +def test_game_of_life_curriculum(): + """Test the curriculum for complex arithmetic.""" + curriculum = JugsCurriculum() + base_value = {"size": 150, "seed": 1} + + base_cfg: JugsCurriculum = curriculum.generate_configuration(base_value) + + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.num_jugs == 3 + assert base_cfg.difficulty == 2 + + # Test and validate increase in levels + curriculum.increment_attr_level("num_jugs") + curriculum.increment_attr_level("difficulty") + + increased_cfg: JugsCurriculum = curriculum.generate_configuration(base_value) + assert increased_cfg.num_jugs == 4 + assert increased_cfg.difficulty == 4 + + # Test and validate decrease in levels + curriculum.decrement_attr_level("num_jugs") + curriculum.decrement_attr_level("difficulty") + + decreased_cfg: JugsCurriculum = curriculum.generate_configuration(base_value) + assert decreased_cfg.num_jugs == 3 + assert decreased_cfg.difficulty == 2 + + # Test upper bound boundary condition + for _ in range(10): + curriculum.increment_attr_level("num_jugs") + curriculum.increment_attr_level("difficulty") + upper_bound_cfg: JugsCurriculum = curriculum.generate_configuration(base_value) + assert upper_bound_cfg.num_jugs == 7 + assert upper_bound_cfg.difficulty == 8 + + # Test lower bound boundary condition + for _ in range(10): + curriculum.decrement_attr_level("num_jugs") + curriculum.decrement_attr_level("difficulty") + lower_bound_cfg: JugsCurriculum = curriculum.generate_configuration(base_value) + assert lower_bound_cfg.num_jugs == 3 + assert lower_bound_cfg.difficulty == 2