Add curriculum for arc_agi (#372)

* Add curriculum for arc_agi

* Resolve conflicts

* Remove code smell

* Remove unwanted code
This commit is contained in:
Adefioye 2025-04-01 14:17:52 -05:00 committed by GitHub
parent 50846c3534
commit e3af2dd2bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 136 additions and 9 deletions

View file

@ -14,6 +14,8 @@ from reasoning_gym.arc.board_format import (
from reasoning_gym.dataset import ProceduralDataset
from reasoning_gym.factory import register_dataset
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
DATASET_NAME = "arc_agi"
@ -31,6 +33,13 @@ class ArcAgiConfig:
use_color_permutation: bool = True
shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle
rotations_weights: list[float] = field(
default_factory=lambda: [0.25, 0.25, 0.25, 0.25]
) # ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
mirrors_weights: list[float] = field(
default_factory=lambda: [0.2, 0.2, 0.2, 0.2, 0.2]
) # MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
seed: Optional[int] = None
size: int = 500
@ -117,13 +126,19 @@ class ArcAgiDataset(ProceduralDataset):
# Map rotation strings to functions
rotation_map = {"90": rot90, "180": rot180, "270": rot270}
if self.config.rotations:
chosen_rot = rng.choice([identity] + [rotation_map[r] for r in self.config.rotations])
chosen_rot = rng.choices(
[identity] + [rotation_map[r] for r in self.config.rotations],
weights=self.config.rotations_weights,
k=1,
)[0]
fns.append(chosen_rot)
# Map mirror strings to functions
mirror_map = {"horizontal": hmirror, "vertical": vmirror, "diagonal": dmirror, "counterdiagonal": cmirror}
if self.config.mirrors:
chosen_mirror = rng.choice([identity] + [mirror_map[m] for m in self.config.mirrors])
chosen_mirror = rng.choices(
[identity] + [mirror_map[m] for m in self.config.mirrors], weights=self.config.mirrors_weights, k=1
)[0]
fns.append(chosen_mirror)
if self.config.use_color_permutation:
@ -189,6 +204,10 @@ class ArcAgiDataset(ProceduralDataset):
"input": totuple(augmented_test_input),
"output": totuple(augmented_test_output),
"task_id": task_id,
"difficulty": {
"rotations_weights": self.config.rotations_weights,
"mirrors_weights": self.config.mirrors_weights,
},
},
}
@ -207,4 +226,34 @@ class ArcAgiDataset(ProceduralDataset):
return reward
register_dataset(DATASET_NAME, ArcAgiDataset, ArcAgiConfig)
class ArcAgiCurriculum(BaseCurriculum):
    """Curriculum for ARC-AGI-1 tasks.

    Exposes two attributes, each with four levels, that set the
    ``rotations_weights`` and ``mirrors_weights`` fields of ``ArcAgiConfig``.
    Every level is a complete weight vector rather than a single number.
    Across the levels the weight on the identity transform shrinks, so
    higher levels favour non-trivial augmentations.
    """

    def __init__(self):
        super().__init__(ArcAgiCurriculum.__name__, ArcAgiConfig)

        # Weight order follows ROTATION_AUGMENTATIONS:
        # [identity, rot90, rot180, rot270]
        rotation_levels = [
            [0.3, 0.2, 0.3, 0.2],
            [0.15, 0.3, 0.25, 0.3],
            [0.1, 0.35, 0.2, 0.35],
            [0.0, 0.4, 0.2, 0.4],
        ]

        # Weight order follows MIRROR_AUGMENTATIONS:
        # [identity, hmirror, vmirror, dmirror, cmirror]
        mirror_levels = [
            [0.3, 0.3, 0.2, 0.1, 0.1],
            [0.2, 0.2, 0.2, 0.2, 0.2],
            [0.1, 0.1, 0.2, 0.3, 0.3],
            [0.05, 0.05, 0.1, 0.4, 0.4],
        ]

        rotation_attribute = ScalarAttributeDefinition(
            name="rotations_weights",
            field_name="rotations_weights",
            levels=rotation_levels,
            description="Rotation augmentation weights",
        )
        mirror_attribute = ScalarAttributeDefinition(
            name="mirrors_weights",
            field_name="mirrors_weights",
            levels=mirror_levels,
            description="Mirror augmentation weights",
        )

        # Register both attributes with the base curriculum in one call.
        self._define_attributes(rotation_attribute, mirror_attribute)
# Register the dataset, its config and its curriculum under the module-level
# DATASET_NAME constant so the registry key is defined in exactly one place.
register_dataset(DATASET_NAME, ArcAgiDataset, ArcAgiConfig, ArcAgiCurriculum)