mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Add curriculum for arc_agi (#372)
* Add curriculum for arc_agi * Resolve conflicts * Remove code smell * Remove unwanted code
This commit is contained in:
parent
50846c3534
commit
e3af2dd2bd
3 changed files with 136 additions and 9 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from .arc_1d import Arc1DConfig, Arc1DCurriculum, Arc1DDataset
|
||||
from .arc_agi import ArcAgiConfig, ArcAgiDataset
|
||||
from .arc_agi import ArcAgiConfig, ArcAgiCurriculum, ArcAgiDataset
|
||||
from .rearc import ReArcConfig, ReArcCurriculum, ReArcDataset
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -8,6 +8,7 @@ __all__ = [
|
|||
"Arc1DCurriculum",
|
||||
"ArcAgiConfig",
|
||||
"ArcAgiDataset",
|
||||
"ArcAgiCurriculum",
|
||||
"ReArcDataset",
|
||||
"ReArcConfig",
|
||||
"ReArcCurriculum",
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from reasoning_gym.arc.board_format import (
|
|||
from reasoning_gym.dataset import ProceduralDataset
|
||||
from reasoning_gym.factory import register_dataset
|
||||
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
|
||||
DATASET_NAME = "arc_agi"
|
||||
|
||||
|
||||
|
|
@ -31,6 +33,13 @@ class ArcAgiConfig:
|
|||
use_color_permutation: bool = True
|
||||
shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle
|
||||
|
||||
rotations_weights: list[float] = field(
|
||||
default_factory=lambda: [0.25, 0.25, 0.25, 0.25]
|
||||
) # ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
|
||||
mirrors_weights: list[float] = field(
|
||||
default_factory=lambda: [0.2, 0.2, 0.2, 0.2, 0.2]
|
||||
) # MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
|
||||
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
|
|
@ -117,13 +126,19 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
# Map rotation strings to functions
|
||||
rotation_map = {"90": rot90, "180": rot180, "270": rot270}
|
||||
if self.config.rotations:
|
||||
chosen_rot = rng.choice([identity] + [rotation_map[r] for r in self.config.rotations])
|
||||
chosen_rot = rng.choices(
|
||||
[identity] + [rotation_map[r] for r in self.config.rotations],
|
||||
weights=self.config.rotations_weights,
|
||||
k=1,
|
||||
)[0]
|
||||
fns.append(chosen_rot)
|
||||
|
||||
# Map mirror strings to functions
|
||||
mirror_map = {"horizontal": hmirror, "vertical": vmirror, "diagonal": dmirror, "counterdiagonal": cmirror}
|
||||
if self.config.mirrors:
|
||||
chosen_mirror = rng.choice([identity] + [mirror_map[m] for m in self.config.mirrors])
|
||||
chosen_mirror = rng.choices(
|
||||
[identity] + [mirror_map[m] for m in self.config.mirrors], weights=self.config.mirrors_weights, k=1
|
||||
)[0]
|
||||
fns.append(chosen_mirror)
|
||||
|
||||
if self.config.use_color_permutation:
|
||||
|
|
@ -189,6 +204,10 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
"input": totuple(augmented_test_input),
|
||||
"output": totuple(augmented_test_output),
|
||||
"task_id": task_id,
|
||||
"difficulty": {
|
||||
"rotations_weights": self.config.rotations_weights,
|
||||
"mirrors_weights": self.config.mirrors_weights,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -207,4 +226,34 @@ class ArcAgiDataset(ProceduralDataset):
|
|||
return reward
|
||||
|
||||
|
||||
register_dataset(DATASET_NAME, ArcAgiDataset, ArcAgiConfig)
|
||||
class ArcAgiCurriculum(BaseCurriculum):
    """Curriculum for ARC-AGI-1 tasks.

    Exposes two attributes, ``rotations_weights`` and ``mirrors_weights``.
    Each level is a complete weight list written to the ``ArcAgiConfig``
    field of the same name; successive levels move weight away from the
    identity transform.
    """

    def __init__(self):
        super().__init__(ArcAgiCurriculum.__name__, ArcAgiConfig)

        # Weight order: [identity, rot90, rot180, rot270]
        rotation_levels = [
            [0.3, 0.2, 0.3, 0.2],
            [0.15, 0.3, 0.25, 0.3],
            [0.1, 0.35, 0.2, 0.35],
            [0.0, 0.4, 0.2, 0.4],
        ]
        # Weight order: [identity, hmirror, vmirror, dmirror, cmirror]
        mirror_levels = [
            [0.3, 0.3, 0.2, 0.1, 0.1],
            [0.2, 0.2, 0.2, 0.2, 0.2],
            [0.1, 0.1, 0.2, 0.3, 0.3],
            [0.05, 0.05, 0.1, 0.4, 0.4],
        ]

        rotation_attr = ScalarAttributeDefinition(
            name="rotations_weights",
            field_name="rotations_weights",
            levels=rotation_levels,
            description="Rotation augmentation weights",
        )
        mirror_attr = ScalarAttributeDefinition(
            name="mirrors_weights",
            field_name="mirrors_weights",
            levels=mirror_levels,
            description="Mirror augmentation weights",
        )

        # Register both attributes in the same order as before: rotations, then mirrors.
        self._define_attributes(rotation_attr, mirror_attr)
|
||||
|
||||
|
||||
# Register under the module-level DATASET_NAME constant (value "arc_agi") instead of
# repeating the literal, so the registry key cannot drift from the constant.
register_dataset(DATASET_NAME, ArcAgiDataset, ArcAgiConfig, ArcAgiCurriculum)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiDataset
|
||||
from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiCurriculum, ArcAgiDataset
|
||||
|
||||
|
||||
def test_arc_agi_config_validation():
|
||||
|
|
@ -65,12 +65,28 @@ def test_arc_agi_items():
|
|||
def test_arc_agi_augmentations():
|
||||
"""Test that augmentations can be selectively enabled/disabled"""
|
||||
# Test with all augmentations disabled
|
||||
config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=False)
|
||||
config = ArcAgiConfig(
|
||||
seed=42,
|
||||
size=10,
|
||||
rotations=[],
|
||||
mirrors=[],
|
||||
rotations_weights=[1.0],
|
||||
mirrors_weights=[1.0],
|
||||
use_color_permutation=False,
|
||||
)
|
||||
base_dataset = ArcAgiDataset(config)
|
||||
base_items = list(base_dataset)
|
||||
|
||||
# Test with specific rotation only
|
||||
rot_config = ArcAgiConfig(seed=42, size=10, rotations=["90"], mirrors=[], use_color_permutation=False)
|
||||
rot_config = ArcAgiConfig(
|
||||
seed=42,
|
||||
size=10,
|
||||
rotations=["90"],
|
||||
mirrors=[],
|
||||
rotations_weights=[0.5, 0.5],
|
||||
mirrors_weights=[1.0],
|
||||
use_color_permutation=False,
|
||||
)
|
||||
rot_dataset = ArcAgiDataset(rot_config)
|
||||
rot_items = list(rot_dataset)
|
||||
|
||||
|
|
@ -80,7 +96,15 @@ def test_arc_agi_augmentations():
|
|||
), "90-degree rotation augmentation had no effect"
|
||||
|
||||
# Test with specific mirror only
|
||||
mirror_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=["horizontal"], use_color_permutation=False)
|
||||
mirror_config = ArcAgiConfig(
|
||||
seed=42,
|
||||
size=10,
|
||||
rotations=[],
|
||||
mirrors=["horizontal"],
|
||||
rotations_weights=[1.0],
|
||||
mirrors_weights=[0.5, 0.5],
|
||||
use_color_permutation=False,
|
||||
)
|
||||
mirror_dataset = ArcAgiDataset(mirror_config)
|
||||
mirror_items = list(mirror_dataset)
|
||||
|
||||
|
|
@ -90,7 +114,15 @@ def test_arc_agi_augmentations():
|
|||
), "Horizontal mirror augmentation had no effect"
|
||||
|
||||
# Test with color permutation only
|
||||
color_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=True)
|
||||
color_config = ArcAgiConfig(
|
||||
seed=42,
|
||||
size=10,
|
||||
rotations=[],
|
||||
mirrors=[],
|
||||
rotations_weights=[1.0],
|
||||
mirrors_weights=[1.0],
|
||||
use_color_permutation=True,
|
||||
)
|
||||
color_dataset = ArcAgiDataset(color_config)
|
||||
color_items = list(color_dataset)
|
||||
|
||||
|
|
@ -166,3 +198,48 @@ def test_arc_agi_shuffled_order():
|
|||
for a, b in zip(shuffled, unshuffled):
|
||||
assert a["question"] != b["question"]
|
||||
assert a["answer"] == b["answer"]
|
||||
|
||||
|
||||
def test_arc_agi_curriculum():
    """Test the ARC-AGI curriculum's weight levels.

    Checks the configuration generated at the initial level, after one
    increment and one decrement of each attribute, and after stepping well
    past either end of the level list (the level is expected to stay at the
    last / first entry).
    """
    curriculum = ArcAgiCurriculum()
    base_value = {"size": 150, "seed": 1}

    # generate_configuration yields a config object, not a curriculum.
    base_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)

    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2]
    assert base_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1]

    # Test and validate increase in levels
    curriculum.increment_attr_level("rotations_weights")
    curriculum.increment_attr_level("mirrors_weights")

    increased_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert increased_cfg.rotations_weights == [0.15, 0.3, 0.25, 0.3]
    assert increased_cfg.mirrors_weights == [0.2, 0.2, 0.2, 0.2, 0.2]

    # Test and validate decrease in levels
    curriculum.decrement_attr_level("rotations_weights")
    curriculum.decrement_attr_level("mirrors_weights")

    decreased_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert decreased_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2]
    assert decreased_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1]

    # Test upper bound boundary condition: more increments than there are levels
    for _ in range(10):
        curriculum.increment_attr_level("rotations_weights")
        curriculum.increment_attr_level("mirrors_weights")
    upper_bound_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert upper_bound_cfg.rotations_weights == [0.0, 0.4, 0.2, 0.4]
    assert upper_bound_cfg.mirrors_weights == [0.05, 0.05, 0.1, 0.4, 0.4]

    # Test lower bound boundary condition: more decrements than there are levels
    for _ in range(10):
        curriculum.decrement_attr_level("rotations_weights")
        curriculum.decrement_attr_level("mirrors_weights")
    lower_bound_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert lower_bound_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2]
    assert lower_bound_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue