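"""ARC-AGI-1 procedural dataset for reasoning_gym.

Wraps the original ARC-AGI-1 tasks, loaded via ``arckit``, as a
``ProceduralDataset`` and registers the dataset under the name ``arc_agi``.
"""
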
from dataclasses import dataclass, field
from random import Random
from typing import Any, Optional

import arckit

from reasoning_gym.arc.board_format import (
    ARC_PROMPT_TEMPLATE,
    BoardFormattingOptions,
    format_board,
    format_board_pair,
    parse_board,
)
from reasoning_gym.dataset import ProceduralDataset
from reasoning_gym.factory import register_dataset


@dataclass
class ArcAgiConfig:
    """Configuration options for the ARC-AGI-1 dataset."""

    use_train: bool = True
    use_eval: bool = True
    board_format_opts: BoardFormattingOptions = field(default_factory=BoardFormattingOptions)
    seed: Optional[int] = None
    size: int = 500

    def validate(self):
        assert self.size > 0, "Size of dataset must be positive."


class ArcAgiDataset(ProceduralDataset):
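    """ARC-AGI-1 tasks served as question/answer pairs.

    Each item formats a task's training pairs plus the first test input into a
    prompt and uses the corresponding test output board as the answer.
    """
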
    def __init__(self, config: ArcAgiConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.board_format_opts = config.board_format_opts
        self._prompt_templates = ARC_PROMPT_TEMPLATE

        # Load the official ARC-AGI-1 tasks via arckit and index them by task id
        self._tasks = {}
        train_set, eval_set = arckit.load_data()
        if config.use_train:
            for x in train_set:
                self._tasks[x.id] = x.to_dict()
        if config.use_eval:
            for x in eval_set:
                self._tasks[x.id] = x.to_dict()
        self._task_ids = list(self._tasks.keys())

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single ARC-AGI-1 task
        """
        # Deterministic per-index sampling of a task id
        rng = Random(self.seed + idx)

        task_id = rng.choice(self._task_ids)
        task = self._tasks[task_id]

        train = task["train"]
        test = task["test"][0]

        examples = [
            format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(train)
        ]
        examples = "".join(examples)
        test_input = format_board(test["input"], self.board_format_opts)
        test_output = format_board(test["output"], self.board_format_opts)

        input_prompt = self._prompt_templates.format(examples=examples, input_grid=test_input)

        def totuple(board: list[list[int]]) -> tuple[tuple[int, ...], ...]:
            return tuple(tuple(r) for r in board)

        return {
            "question": input_prompt,
            "answer": test_output,
            "metadata": {
                "input": totuple(test["input"]),
                "output": totuple(test["output"]),
                "task_id": task_id,
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score an answer: 1.0 for an exact match, small partial credit otherwise."""
        reward = 0.0
        metadata = entry["metadata"]
        if answer is not None:
            try:
                answer_board = parse_board(answer, self.board_format_opts)
                # Compare as tuples of tuples to match how the board is stored in metadata
                if tuple(tuple(row) for row in answer_board) == metadata["output"]:
                    reward = 1.0
                else:
                    # Parseable board, but not the expected output
                    reward = 0.05
            except Exception:
                # Answer could not be parsed into a board at all
                reward = 0.01
        return reward


register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig)


if __name__ == "__main__":
    cfg = ArcAgiConfig(seed=99)
    test = ArcAgiDataset(cfg)

    x = test[1]

    a = """1 6 7
6 7 6
2 2 6"""

    print("q:", x["question"])
    print("a:", x["answer"])
    print("score:", test.score_answer(answer=a, entry=x))
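
    # Hedged sketch: because the dataset is registered as "arc_agi" above, it
    # should also be constructible through the reasoning_gym factory. The
    # create_dataset call and its keyword arguments below are an assumption
    # about the factory API, mirroring how other registered datasets are used,
    # rather than something this module defines.
    from reasoning_gym.factory import create_dataset

    factory_ds = create_dataset("arc_agi", seed=99, size=10)
    print("factory task_id:", factory_ds[0]["metadata"]["task_id"])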