diff --git a/reasoning_gym/algorithmic/game_of_life_halting.py b/reasoning_gym/algorithmic/game_of_life_halting.py index 99b9661e..a9c19868 100644 --- a/reasoning_gym/algorithmic/game_of_life_halting.py +++ b/reasoning_gym/algorithmic/game_of_life_halting.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional import cellpylib as cpl +from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -389,4 +390,60 @@ class GameOfLifeHaltingDataset(ProceduralDataset): return 0.0 -register_dataset("game_of_life_halting", GameOfLifeHaltingDataset, GameOfLifeHaltingConfig) +class GameOfLifeHaltingCurriculum(BaseCurriculum): + """Curriculum for Game of Life Halting dataset""" + + def __init__(self): + super().__init__(GameOfLifeHaltingCurriculum.__name__, GameOfLifeHaltingConfig) + + # Define attributes + self._define_attributes( + ScalarAttributeDefinition( + name="grid_size_x", + field_name="grid_size_x", + levels=[12, 25, 50, 200], + default_level=0, + description="Grid size in the x direction", + attr_type=AttributeType.STATIC, + min_value=12, + ), + ScalarAttributeDefinition( + name="grid_size_y", + field_name="grid_size_y", + levels=[12, 25, 50, 200], + default_level=0, + description="Grid size in the y direction", + attr_type=AttributeType.STATIC, + min_value=12, + ), + ScalarAttributeDefinition( + name="difficulty", + field_name="difficulty", + levels=[1, 2, 3], + default_level=0, + description="Oscillator type difficulty", + attr_type=AttributeType.STATIC, + min_value=1, + ), + ScalarAttributeDefinition( + name="num_oscillators", + field_name="num_oscillators", + levels=[3, 7, 10, 20], + default_level=0, + description="Number of oscillators to place", + attr_type=AttributeType.STATIC, + min_value=3, + ), + ScalarAttributeDefinition( + name="max_simulation_steps", + field_name="max_simulation_steps", + levels=[20, 50, 100, 200], + default_level=0, + description="Number of simulation steps to query", + attr_type=AttributeType.STATIC, + min_value=20, + ), + ) + + +register_dataset("game_of_life_halting", GameOfLifeHaltingDataset, GameOfLifeHaltingConfig, GameOfLifeHaltingCurriculum) diff --git a/tests/test_game_of_life_halting.py b/tests/test_game_of_life_halting.py index b09d0989..fae6bd62 100644 --- a/tests/test_game_of_life_halting.py +++ b/tests/test_game_of_life_halting.py @@ -1,6 +1,10 @@ import pytest -from reasoning_gym.algorithmic.game_of_life_halting import GameOfLifeHaltingConfig, GameOfLifeHaltingDataset +from reasoning_gym.algorithmic.game_of_life_halting import ( + GameOfLifeHaltingConfig, + GameOfLifeHaltingCurriculum, + GameOfLifeHaltingDataset, +) def test_game_of_life(): @@ -38,3 +42,33 @@ def test_game_of_life_halting_deterministic(): for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] assert dataset1[i] != dataset3[i] + + +def test_game_of_life_halting_curriculum(): + """Test the curriculum for complex arithmetic.""" + curriculum = GameOfLifeHaltingCurriculum() + base_value = {"size": 150, "seed": 1} + + base_cfg: GameOfLifeHaltingCurriculum = curriculum.generate_configuration(base_value) + + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.grid_size_x == 12 + assert base_cfg.grid_size_y == 12 + assert base_cfg.difficulty == 1 + assert base_cfg.num_oscillators == 3 + assert base_cfg.max_simulation_steps == 20 + + # Test and validate increase in levels + curriculum.increment_attr_level("grid_size_x") + curriculum.increment_attr_level("grid_size_y") + curriculum.increment_attr_level("difficulty") + curriculum.increment_attr_level("num_oscillators") + curriculum.increment_attr_level("max_simulation_steps") + + increased_cfg: GameOfLifeHaltingCurriculum = curriculum.generate_configuration(base_value) + assert increased_cfg.grid_size_x == 25 + assert increased_cfg.grid_size_y == 25 + assert increased_cfg.difficulty == 2 + assert increased_cfg.num_oscillators == 7 + assert increased_cfg.max_simulation_steps == 50