add ArcAgiDataset class, fix score_entry() metadata params

This commit is contained in:
Andreas Koepf 2025-02-08 23:18:18 +01:00
parent 2ad0965fdc
commit 4e49806d22
20 changed files with 194 additions and 93 deletions

View file

@ -0,0 +1,110 @@
from dataclasses import dataclass, field
from random import Random
from typing import Any, Optional
import arckit
from reasoning_gym.arc.board_format import (
ARC_PROMPT_TEMPLATE,
BoardFormattingOptions,
format_board,
format_board_pair,
parse_board,
)
from reasoning_gym.dataset import ProceduralDataset
from reasoning_gym.factory import register_dataset
@dataclass
class ArcAgiConfig:
use_train: bool = True
use_eval: bool = True
board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
seed: Optional[int] = None
size: int = 500
def validate(self):
assert self.size > 0, "Size of dataset must be positive."
class ArcAgiDataset(ProceduralDataset):
def __init__(self, config: ArcAgiConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.board_format_opts = config.board_format_opts
self._prompt_templates = ARC_PROMPT_TEMPLATE
self._tasks = {}
train_set, eval_set = arckit.load_data()
if config.use_train:
for x in train_set:
self._tasks[x.id] = x.to_dict()
if config.use_eval:
for x in eval_set:
self._tasks[x.id] = x.to_dict()
self._task_ids = list(self._tasks.keys())
def __getitem__(self, idx: int) -> dict:
"""
Generate a single ARC-AGI-1 task
"""
rng = Random(self.seed + idx)
task_id = rng.choice(self._task_ids)
task = self._tasks[task_id]
train = task["train"]
test = task["test"][0]
examples = [
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(train)
]
examples = "".join(examples)
test_input = format_board(test["input"], self.board_format_opts)
test_output = format_board(test["output"], self.board_format_opts)
input_prompt = self._prompt_templates.format(examples=examples, input_grid=test_input)
def totuple(board: list[list[int]]) -> tuple[tuple[int, ...], ...]:
return tuple(tuple(r) for r in board)
return {
"question": input_prompt,
"answer": test_output,
"metadata": {
"input": totuple(test["input"]),
"output": totuple(test["output"]),
"task_id": task_id,
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
try:
answer_board = parse_board(answer, self.board_format_opts)
if answer_board == metadata["output"]:
reward = 1.0
else:
reward = 0.05
except:
reward = 0.01
return reward
register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig)
if __name__ == "__main__":
cfg = ArcAgiConfig(seed=99)
test = ArcAgiDataset(cfg)
x = test[1]
a = """1 6 7
6 7 6
2 2 6"""
print("q:", x["question"])
print("a:", x["answer"])
print("score:", test.score_answer(answer=a, entry=x))