diff --git a/reasoning_gym/arc/__init__.py b/reasoning_gym/arc/__init__.py index c7c48b94..d422e098 100644 --- a/reasoning_gym/arc/__init__.py +++ b/reasoning_gym/arc/__init__.py @@ -1,5 +1,13 @@ from .arc_1d import Arc1DConfig, Arc1DDataset from .arc_agi import ArcAgiConfig, ArcAgiDataset -from .rearc import ReArcConfig, ReArcDataset +from .rearc import ReArcConfig, ReArcCurriculum, ReArcDataset -__all__ = ["Arc1DConfig", "Arc1DDataset", "ArcAgiConfig", "ArcAgiDataset", "ReArcDataset", "ReArcConfig"] +__all__ = [ + "Arc1DConfig", + "Arc1DDataset", + "ArcAgiConfig", + "ArcAgiDataset", + "ReArcDataset", + "ReArcConfig", + "ReArcCurriculum", +] diff --git a/reasoning_gym/arc/rearc.py b/reasoning_gym/arc/rearc.py index dabbe808..e84b49fc 100644 --- a/reasoning_gym/arc/rearc.py +++ b/reasoning_gym/arc/rearc.py @@ -2,9 +2,20 @@ from dataclasses import dataclass, field from random import Random from typing import Any, Callable, Optional +from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset from .board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board, format_board_pair, parse_board +RNG_DIFFICULTY_LEVELS = [0.0, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.2] +RNG_DIFFICULTY_RANGES = [ + (RNG_DIFFICULTY_LEVELS[i], RNG_DIFFICULTY_LEVELS[i + 1]) for i in range(len(RNG_DIFFICULTY_LEVELS) - 1) +] + +PSO_DIFFICULTY_LEVELS = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 1] +PSO_DIFFICULTY_RANGES = [ + (PSO_DIFFICULTY_LEVELS[i], PSO_DIFFICULTY_LEVELS[i + 1]) for i in range(len(PSO_DIFFICULTY_LEVELS) - 1) +] + @dataclass class ReArcConfig: @@ -15,6 +26,14 @@ class ReArcConfig: board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions()) seed: Optional[int] = None size: int = 500 + rng_difficulty_ranges: list[tuple[float, float]] = field(default_factory=lambda: RNG_DIFFICULTY_RANGES) + rng_difficulty_weights: list[float] = field( + default_factory=lambda: [1 / len(RNG_DIFFICULTY_RANGES)] * len(RNG_DIFFICULTY_RANGES) + ) + pso_difficulty_ranges: list[tuple[float, float]] = field(default_factory=lambda: PSO_DIFFICULTY_RANGES) + pso_difficulty_weights: list[float] = field( + default_factory=lambda: [1 / len(PSO_DIFFICULTY_RANGES)] * len(PSO_DIFFICULTY_RANGES) + ) def validate(self): assert self.min_examples > 0, "min_examples must be positive" @@ -72,12 +91,22 @@ class ReArcDataset(ProceduralDataset): Generate a single ReArc task """ rng = Random(self.seed + idx) - task_id = rng.choice(list(self._generators.keys())) - generator = self._generators[task_id] - task = generator(rng, self.diff_lb, self.diff_ub) + pso_difficulty_range = rng.choices( + self.config.pso_difficulty_ranges, weights=self.config.pso_difficulty_weights, k=1 + )[0] + + while True: + task_id = rng.choice(list(self._generators.keys())) + generator = self._generators[task_id] + difficulty_range = rng.choices( + self.config.rng_difficulty_ranges, weights=self.config.rng_difficulty_weights, k=1 + )[0] + task = generator(rng, difficulty_range[0], difficulty_range[1]) + pso_difficulty = self.get_pso_difficulty(task) + if (pso_difficulty_range[0] <= pso_difficulty) and (pso_difficulty <= pso_difficulty_range[1]): + break rng_difficulty = self.get_rng_difficulty(rng) - pso_difficulty = self.get_pso_difficulty(task) input_prompt = self.format_rearc_input(rng, task, generator) answer = format_board(task["output"], self.board_format_opts) @@ -110,4 +139,43 @@ class ReArcDataset(ProceduralDataset): return reward -register_dataset("rearc", ReArcDataset, ReArcConfig) +class ReArcCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(ReArcCurriculum.__name__, ReArcConfig) + self._define_attributes( + ScalarAttributeDefinition( + name="pso_difficulty", + field_name="pso_difficulty_weights", + description="The range of PSO difficulty for the Arc problem", + default_level=0, + levels=[ + [1, 0, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs PSO difficulty + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + ], # only sample/generate the hardest tasks PSO difficulty + ), + ScalarAttributeDefinition( + name="rng_difficulty", + field_name="rng_difficulty_weights", + description="The range of RNG difficulty for the Arc problem", + default_level=0, + levels=[ + [1, 0, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs RNG difficulty + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + ], # only sample/generate the hardest tasks wrs RNG difficulty + ), + ) + + +register_dataset("rearc", ReArcDataset, ReArcConfig, ReArcCurriculum) diff --git a/tests/test_rearc.py b/tests/test_rearc.py index a23d8000..1d035774 100644 --- a/tests/test_rearc.py +++ b/tests/test_rearc.py @@ -1,7 +1,7 @@ import pytest from reasoning_gym.arc.board_format import format_board -from reasoning_gym.arc.rearc import ReArcConfig, ReArcDataset +from reasoning_gym.arc.rearc import ReArcConfig, ReArcCurriculum, ReArcDataset def test_rearc_config_validation(): @@ -85,3 +85,55 @@ def test_rearc_scoring_edge_cases(): # Case sensitivity answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower() assert dataset.score_answer(answer, entry=item) == 1.0 + + +def test_rearc_curriculum(): + """Test the ReArc curriculum functionality""" + curriculum = ReArcCurriculum() + + base_value = {"size": 50, "seed": 42} + + # Test default configuration + base_cfg: ReArcConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 42 + assert base_cfg.size == 50 + + # Default levels should have weights that select only the easiest tasks + assert base_cfg.pso_difficulty_weights == [1, 0, 0, 0, 0, 0, 0, 0] + assert base_cfg.rng_difficulty_weights == [1, 0, 0, 0, 0, 0, 0, 0] + + # Test incrementing pso_difficulty attribute + curriculum.increment_attr_level("pso_difficulty") + pso_cfg = curriculum.generate_configuration(base_value) + assert pso_cfg.pso_difficulty_weights == [0, 1, 0, 0, 0, 0, 0, 0] # Level 1: second difficulty range + assert pso_cfg.rng_difficulty_weights == [1, 0, 0, 0, 0, 0, 0, 0] # RNG unchanged + + # Test incrementing rng_difficulty attribute + curriculum.increment_attr_level("rng_difficulty") + rng_cfg = curriculum.generate_configuration(base_value) + assert rng_cfg.pso_difficulty_weights == [0, 1, 0, 0, 0, 0, 0, 0] # PSO unchanged + assert rng_cfg.rng_difficulty_weights == [0, 1, 0, 0, 0, 0, 0, 0] # Level 1: second difficulty range + + # Test decrementing pso_difficulty attribute + curriculum.decrement_attr_level("pso_difficulty") + decr_cfg = curriculum.generate_configuration(base_value) + assert decr_cfg.pso_difficulty_weights == [1, 0, 0, 0, 0, 0, 0, 0] # Back to level 0 + assert decr_cfg.rng_difficulty_weights == [0, 1, 0, 0, 0, 0, 0, 0] # RNG unchanged + + # Test global level setting to higher level + curriculum.set_global_level(3) # Set all attributes to level 3 + global_cfg = curriculum.generate_configuration(base_value) + assert global_cfg.pso_difficulty_weights == [0, 0, 0, 1, 0, 0, 0, 0] # Level 3 + assert global_cfg.rng_difficulty_weights == [0, 0, 0, 1, 0, 0, 0, 0] # Level 3 + + # Test increment global level + curriculum.increment_global_level() # Should go to level 4 + incr_global_cfg = curriculum.generate_configuration(base_value) + assert incr_global_cfg.pso_difficulty_weights == [0, 0, 0, 0, 1, 0, 0, 0] # Level 4 + assert incr_global_cfg.rng_difficulty_weights == [0, 0, 0, 0, 1, 0, 0, 0] # Level 4 + + # Test decrement global level + curriculum.decrement_global_level() # Should go back to level 3 + decr_global_cfg = curriculum.generate_configuration(base_value) + assert decr_global_cfg.pso_difficulty_weights == [0, 0, 0, 1, 0, 0, 0, 0] # Level 3 + assert decr_global_cfg.rng_difficulty_weights == [0, 0, 0, 1, 0, 0, 0, 0] # Level 3