mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* Add curriculum for arc_agi * Resolve conflicts * Remove code smell * Remove unwanted code
259 lines
8.6 KiB
Python
259 lines
8.6 KiB
Python
from dataclasses import dataclass, field
|
|
from random import Random
|
|
from typing import Any, Callable, Optional
|
|
|
|
import arckit
|
|
|
|
from reasoning_gym.arc.board_format import (
|
|
ARC_PROMPT_TEMPLATE,
|
|
BoardFormattingOptions,
|
|
format_board,
|
|
format_board_pair,
|
|
parse_board,
|
|
)
|
|
from reasoning_gym.dataset import ProceduralDataset
|
|
from reasoning_gym.factory import register_dataset
|
|
|
|
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
|
|
|
# Name of this dataset; stored as "source_dataset" in the metadata of every generated item.
DATASET_NAME = "arc_agi"
|
|
|
|
|
|
@dataclass
class ArcAgiConfig:
    """Configuration of the ARC-AGI dataset: task splits, board formatting and augmentations."""

    use_train: bool = True  # include the tasks of the training split
    use_eval: bool = True  # include the tasks of the evaluation split
    board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())

    # Augmentation options
    rotations: list[str] = field(default_factory=lambda: ["90", "180", "270"])  # empty list for no rotations
    mirrors: list[str] = field(
        default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"]
    )  # empty list for no mirrors
    use_color_permutation: bool = True
    shuffle_example_order: bool = True  # whether to shuffle the order of example board pairs for each riddle

    # Sampling weights. The first entry belongs to the identity transform, the
    # remaining entries follow the order of `rotations` / `mirrors`.
    rotations_weights: list[float] = field(
        default_factory=lambda: [0.25, 0.25, 0.25, 0.25]
    )  # ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
    mirrors_weights: list[float] = field(
        default_factory=lambda: [0.2, 0.2, 0.2, 0.2, 0.2]
    )  # MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]

    seed: Optional[int] = None
    size: int = 500

    def validate(self):
        """Check the configuration for consistency.

        Raises:
            AssertionError: if `size` is not positive, an unknown rotation or
                mirror name is configured, or the weights of an enabled
                augmentation cannot be used for weighted sampling.
        """
        # NOTE: validation deliberately stays assert-based so that callers
        # relying on AssertionError keep working.
        assert self.size > 0, "Size of dataset must be positive."
        valid_rotations = ["90", "180", "270"]
        valid_mirrors = ["horizontal", "vertical", "diagonal", "counterdiagonal"]
        for rot in self.rotations:
            assert rot in valid_rotations, f"Invalid rotation option: {rot}"
        for mirror in self.mirrors:
            assert mirror in valid_mirrors, f"Invalid mirror option: {mirror}"

        for name, options, weights in (
            ("rotations", self.rotations, self.rotations_weights),
            ("mirrors", self.mirrors, self.mirrors_weights),
        ):
            if not options:
                # The weights are only consumed when the augmentation is enabled.
                continue
            expected = len(options) + 1  # +1 for the implicit identity candidate
            assert len(weights) == expected, (
                f"{name}_weights must have {expected} entries "
                f"(identity + {len(options)} {name}), got {len(weights)}"
            )
            assert all(w >= 0 for w in weights), f"{name}_weights must be non-negative"
            assert sum(weights) > 0, f"{name}_weights must not all be zero"
|
|
|
|
|
|
# Type alias for a board: a list of rows, each row a list of integer cells.
Board = list[list[int]]
|
|
|
|
|
|
def identity(board: Board) -> Board:
    """No-op augmentation: hand back the very same board object."""
    return board
|
|
|
|
|
|
def rot90(board: Board) -> Board:
    """Rotate the board a quarter turn clockwise.

    The rows of the result are tuples, as produced by ``zip``.
    """
    bottom_up = board[::-1]
    return list(zip(*bottom_up))
|
|
|
|
|
|
def rot180(board: Board) -> Board:
    """Rotate the board by a half turn (rows and cells both reversed)."""
    rotated = []
    for row in reversed(board):
        rotated.append(row[::-1])
    return rotated
|
|
|
|
|
|
def rot270(board: Board) -> Board:
    """Rotate the board a quarter turn anticlockwise.

    The rows of the result are tuples, as produced by ``zip``.
    """
    columns = list(zip(*board))  # column j of the board, read top to bottom
    columns.reverse()  # the last column becomes the first row
    return columns
|
|
|
|
|
|
def hmirror(board: Board) -> Board:
    """Mirror along the horizontal axis: the row order is reversed, rows are reused as-is."""
    flipped = board[::-1]
    return flipped
|
|
|
|
|
|
def vmirror(board: Board) -> Board:
    """Mirror along the vertical axis: every row is reversed, the row order is kept."""
    mirrored = []
    for row in board:
        mirrored.append(row[::-1])
    return mirrored
|
|
|
|
|
|
def dmirror(board: Board) -> Board:
    """Mirror along the main diagonal (transpose); the rows of the result are tuples."""
    return [column for column in zip(*board)]
|
|
|
|
|
|
def cmirror(board: Board) -> Board:
    """Mirror along the counterdiagonal; the rows of the result are tuples."""
    # Half-turn the board first, then transpose it.
    half_turn = [row[::-1] for row in reversed(board)]
    return [column for column in zip(*half_turn)]
|
|
|
|
|
|
def cmap(board: Board, colors: list[int]) -> Board:
    """Recolor the board: every cell value ``c`` is replaced by ``colors[c]``."""
    recolored = []
    for row in board:
        new_row = []
        for cell in row:
            new_row.append(colors[cell])
        recolored.append(new_row)
    return recolored
|
|
|
|
|
|
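# With the default `rotations` config, the candidates that `rotations_weights` applies to are, in order:
# ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
# With the default `mirrors` config, the candidates that `mirrors_weights` applies to are, in order:
# MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]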
|
|
|
|
|
|
class ArcAgiDataset(ProceduralDataset):
    """Dataset of ARC-AGI-1 riddles with optional rotation, mirror and color augmentation.

    The tasks are loaded through ``arckit.load_data()`` when the dataset is
    constructed. Each item draws all of its randomness from
    ``Random(self.seed + idx)``.
    """

    def __init__(self, config: ArcAgiConfig):
        """Load the tasks of the splits enabled in `config`."""
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.board_format_opts = config.board_format_opts
        self._prompt_templates = ARC_PROMPT_TEMPLATE

        # task id -> task dict (the result of ``to_dict()``) for every enabled split
        self._tasks = {}
        train_set, eval_set = arckit.load_data()
        if config.use_train:
            self._tasks.update((task.id, task.to_dict()) for task in train_set)
        if config.use_eval:
            self._tasks.update((task.id, task.to_dict()) for task in eval_set)
        self._task_ids = list(self._tasks)

    def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]:
        """Create a composite augmentation function from enabled options

        Args:
            rng: random generator used to pick the rotation, the mirror and
                the color permutation.

        Returns:
            A function applying the chosen rotation, mirror and color mapping
            (in that order) to a board.
        """
        fns = []

        # Map rotation strings to functions
        rotation_map = {"90": rot90, "180": rot180, "270": rot270}
        if self.config.rotations:
            # identity is always the first candidate, hence the extra weight entry
            rotation_candidates = [identity] + [rotation_map[r] for r in self.config.rotations]
            fns.append(rng.choices(rotation_candidates, weights=self.config.rotations_weights, k=1)[0])

        # Map mirror strings to functions
        mirror_map = {"horizontal": hmirror, "vertical": vmirror, "diagonal": dmirror, "counterdiagonal": cmirror}
        if self.config.mirrors:
            mirror_candidates = [identity] + [mirror_map[m] for m in self.config.mirrors]
            fns.append(rng.choices(mirror_candidates, weights=self.config.mirrors_weights, k=1)[0])

        if self.config.use_color_permutation:
            color_table = list(range(10))
            rng.shuffle(color_table)
            fns.append(lambda x: cmap(x, color_table))

        def composite_fn(board: Board) -> Board:
            result = board
            for fn in fns:
                result = fn(result)
            return result

        return composite_fn

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single ARC-AGI-1 task

        Returns a dict with the formatted prompt under "question", the
        formatted expected test output under "answer", and a "metadata" dict
        holding the augmented test boards (as tuples), the task id and the
        configured augmentation weights.
        """
        rng = Random(self.seed + idx)

        task_id = rng.choice(self._task_ids)
        task = self._tasks[task_id]

        # Create augmentation function to be used for all examples
        augment = self._create_augmentation_fn(rng)

        train = task["train"]
        test = task["test"][0]  # only the first test pair of a task is used

        # Apply augmentation to all train examples
        augmented_train = [{"input": augment(p["input"]), "output": augment(p["output"])} for p in train]

        if self.config.shuffle_example_order:
            rng.shuffle(augmented_train)

        # NOTE(review): the examples are formatted with self.config.board_format_opts
        # while the test boards use self.board_format_opts; both are presumably the
        # same object -- confirm against ProceduralDataset.
        examples = "".join(
            format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
            for i, p in enumerate(augmented_train)
        )

        # Apply augmentation to test example
        augmented_test_input = augment(test["input"])
        augmented_test_output = augment(test["output"])

        test_input = format_board(augmented_test_input, self.board_format_opts)
        test_output = format_board(augmented_test_output, self.board_format_opts)

        input_prompt = self._prompt_templates.format(examples=examples, input_grid=test_input)

        def totuple(board: list[list[int]]) -> tuple[tuple[int, ...], ...]:
            return tuple(tuple(r) for r in board)

        return {
            "question": input_prompt,
            "answer": test_output,
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "input": totuple(augmented_test_input),
                "output": totuple(augmented_test_output),
                "task_id": task_id,
                "difficulty": {
                    "rotations_weights": self.config.rotations_weights,
                    "mirrors_weights": self.config.mirrors_weights,
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score `answer` against the expected output board of `entry`.

        Returns:
            1.0 if the parsed board equals ``entry["metadata"]["output"]``,
            0.05 if the answer parses but differs, and 0.0 if there is no
            answer or parsing/comparison raises.
        """
        reward = 0.0
        metadata = entry["metadata"]
        if answer is not None:
            try:
                answer_board = parse_board(answer, self.board_format_opts)
                if answer_board == metadata["output"]:
                    reward = 1.0
                else:
                    reward = 0.05
            except Exception:
                # Unparsable answers score zero. KeyboardInterrupt and SystemExit
                # are deliberately no longer swallowed here.
                reward = 0.0
        return reward
|
|
|
|
|
|
class ArcAgiCurriculum(BaseCurriculum):
    """Curriculum for ARC-AGI-1 tasks"""

    def __init__(self):
        super().__init__(ArcAgiCurriculum.__name__, ArcAgiConfig)

        # One weight list per level.
        # ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
        rotation_levels = [
            [0.3, 0.2, 0.3, 0.2],
            [0.15, 0.3, 0.25, 0.3],
            [0.1, 0.35, 0.2, 0.35],
            [0.0, 0.4, 0.2, 0.4],
        ]
        # MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
        mirror_levels = [
            [0.3, 0.3, 0.2, 0.1, 0.1],
            [0.2, 0.2, 0.2, 0.2, 0.2],
            [0.1, 0.1, 0.2, 0.3, 0.3],
            [0.05, 0.05, 0.1, 0.4, 0.4],
        ]

        rotation_attribute = ScalarAttributeDefinition(
            name="rotations_weights",
            field_name="rotations_weights",
            levels=rotation_levels,
            description="Rotation augmentation weights",
        )
        mirror_attribute = ScalarAttributeDefinition(
            name="mirrors_weights",
            field_name="mirrors_weights",
            levels=mirror_levels,
            description="Mirror augmentation weights",
        )

        # Define attributes
        self._define_attributes(rotation_attribute, mirror_attribute)
|
|
|
|
|
|
# Register the dataset class together with its config and curriculum under the dataset name.
register_dataset(DATASET_NAME, ArcAgiDataset, ArcAgiConfig, ArcAgiCurriculum)
|