Mirror of https://github.com/open-thought/reasoning-gym.git (synced 2026-04-25 17:10:51 +00:00)
add ArcAgiDataset class, fix score_answer() metadata params
parent 2ad0965fdc
commit 4e49806d22
20 changed files with 194 additions and 93 deletions
reasoning_gym/arc/__init__.py

@@ -1,4 +1,5 @@
 from .arc_1d import Arc1DConfig, Arc1DDataset
+from .arc_agi import ArcAgiConfig, ArcAgiDataset
 from .rearc import ReArcConfig, ReArcDataset
 
-__all__ = ["Arc1DConfig", "Arc1DDataset", "ReArcDataset", "ReArcConfig"]
+__all__ = ["Arc1DConfig", "Arc1DDataset", "ArcAgiConfig", "ArcAgiDataset", "ReArcDataset", "ReArcConfig"]
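Importing the new module here means its `register_dataset("arc_agi", ...)` call (bottom of the new file below) runs as soon as the package is imported. A minimal usage sketch, assuming the factory module exposes a `create_dataset` helper as the counterpart to `register_dataset`:

import reasoning_gym.arc  # triggers register_dataset("arc_agi", ...)
from reasoning_gym.factory import create_dataset

# build the dataset through the registry rather than instantiating it directly
ds = create_dataset("arc_agi", seed=42, size=10)
print(ds[0]["question"])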
reasoning_gym/arc/arc_agi.py (new file, 110 lines)

@@ -0,0 +1,110 @@
+from dataclasses import dataclass, field
+from random import Random
+from typing import Any, Optional
+
+import arckit
+
+from reasoning_gym.arc.board_format import (
+    ARC_PROMPT_TEMPLATE,
+    BoardFormattingOptions,
+    format_board,
+    format_board_pair,
+    parse_board,
+)
+from reasoning_gym.dataset import ProceduralDataset
+from reasoning_gym.factory import register_dataset
+
+
+@dataclass
+class ArcAgiConfig:
+    use_train: bool = True
+    use_eval: bool = True
+    board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
+    seed: Optional[int] = None
+    size: int = 500
+
+    def validate(self):
+        assert self.size > 0, "Size of dataset must be positive."
+
+
+class ArcAgiDataset(ProceduralDataset):
+    def __init__(self, config: ArcAgiConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+        self.board_format_opts = config.board_format_opts
+        self._prompt_templates = ARC_PROMPT_TEMPLATE
+
+        self._tasks = {}
+        train_set, eval_set = arckit.load_data()
+        if config.use_train:
+            for x in train_set:
+                self._tasks[x.id] = x.to_dict()
+        if config.use_eval:
+            for x in eval_set:
+                self._tasks[x.id] = x.to_dict()
+        self._task_ids = list(self._tasks.keys())
+
+    def __getitem__(self, idx: int) -> dict:
+        """
+        Generate a single ARC-AGI-1 task
+        """
+        rng = Random(self.seed + idx)
+
+        task_id = rng.choice(self._task_ids)
+        task = self._tasks[task_id]
+
+        train = task["train"]
+        test = task["test"][0]
+
+        examples = [
+            format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(train)
+        ]
+        examples = "".join(examples)
+        test_input = format_board(test["input"], self.board_format_opts)
+        test_output = format_board(test["output"], self.board_format_opts)
+
+        input_prompt = self._prompt_templates.format(examples=examples, input_grid=test_input)
+
+        def totuple(board: list[list[int]]) -> tuple[tuple[int, ...], ...]:
+            return tuple(tuple(r) for r in board)
+
+        return {
+            "question": input_prompt,
+            "answer": test_output,
+            "metadata": {
+                "input": totuple(test["input"]),
+                "output": totuple(test["output"]),
+                "task_id": task_id,
+            },
+        }
+
+    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
+        reward = 0.0
+        metadata = entry["metadata"]
+        if answer is not None:
+            try:
+                answer_board = parse_board(answer, self.board_format_opts)
+                if answer_board == metadata["output"]:
+                    reward = 1.0
+                else:
+                    reward = 0.05
+            except:
+                reward = 0.01
+        return reward
+
+
+register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig)
+
+
+if __name__ == "__main__":
+    cfg = ArcAgiConfig(seed=99)
+    test = ArcAgiDataset(cfg)
+
+    x = test[1]
+
+    a = """1 6 7
+6 7 6
+2 2 6"""
+
+    print("q:", x["question"])
+    print("a:", x["answer"])
+    print("score:", test.score_answer(answer=a, entry=x))
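Scoring gives full credit for an exact board match, a small reward (0.05) for output that at least parses as a valid board, and 0.01 when parsing raises. A sketch of the tiers, assuming `parse_board` inverts `format_board` into the same tuple-of-tuples stored in the metadata, and that a plain space/newline grid parses under the default options:

from reasoning_gym.arc import ArcAgiConfig, ArcAgiDataset

ds = ArcAgiDataset(ArcAgiConfig(seed=42, size=10))
entry = ds[0]

print(ds.score_answer(answer=entry["answer"], entry=entry))  # 1.0: exact match
print(ds.score_answer(answer="0 0\n0 0", entry=entry))       # 0.05 expected: parses, but (almost certainly) the wrong grid
print(ds.score_answer(answer="not a grid", entry=entry))     # 0.01 expected: parse_board raises
print(ds.score_answer(answer=None, entry=entry))             # 0.0: no answer given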
reasoning_gym/arc/board_format.py

@@ -1,6 +1,16 @@
 from dataclasses import dataclass, field
 from typing import List, Tuple
 
+ARC_PROMPT_TEMPLATE = """Find the common rule that maps an input grid to an output grid, given the examples below.
+
+{examples}
+Below is a test input grid. Predict the corresponding output grid by applying the rule you found.
+Your final answer should just be the text output grid itself.
+
+Input:
+{input_grid}
+"""
+
 
 @dataclass
 class BoardFormattingOptions:
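The prompt template is now defined once in `board_format.py` and shared by both ARC datasets. A quick rendering sketch (the example grids are made up):

from reasoning_gym.arc.board_format import ARC_PROMPT_TEMPLATE

examples = "Example 1:\n\nInput:\n0 1\nOutput:\n1 0\n\n"
print(ARC_PROMPT_TEMPLATE.format(examples=examples, input_grid="1 1\n0 0"))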
@@ -10,26 +20,6 @@ class BoardFormattingOptions:
     array_brackets: bool = False
 
 
-def format_arc_task(
-    input_grid: Tuple[Tuple[int, ...], ...], output_grid: Tuple[Tuple[int, ...], ...], options: BoardFormattingOptions
-) -> str:
-    """
-    Format an ARC task as a string
-    """
-
-    buffer = []
-    if options.task_identifier:
-        buffer.append(f"ARC Task: {options.task_identifier}")
-
-    buffer.append("\nInput Grid:")
-    buffer.append(format_board(input_grid, options))
-
-    buffer.append("\n\nOutput Grid:")
-    buffer.append(format_board(output_grid, options))
-
-    return "\n".join(buffer)
-
-
 def format_board(
     board: List[List[int]], formatting_options: BoardFormattingOptions, with_board_shape: bool = False
 ) -> str:
reasoning_gym/arc/rearc.py

@@ -3,17 +3,7 @@ from random import Random
 from typing import Any, Callable, Dict, Optional
 
 from ..factory import ProceduralDataset, register_dataset
-from .board_format import BoardFormattingOptions, format_board, format_board_pair, parse_board
-
-_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below.
-
-{examples}
-Below is a test input grid. Predict the corresponding output grid by applying the rule you found.
-Your final answer should just be the text output grid itself.
-
-Input:
-{input_grid}
-"""
+from .board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board, format_board_pair, parse_board
 
 
 @dataclass
@@ -37,7 +27,7 @@ class ReArcDataset(ProceduralDataset):
     def __init__(self, config: ReArcConfig):
         super().__init__(config=config, seed=config.seed, size=config.size)
         self.board_format_opts = config.board_format_opts
-        self._prompt_templates = _REARC_PROMPT_TEMPLATES
+        self._prompt_templates = ARC_PROMPT_TEMPLATE
         self.diff_lb = config.diff_lb
         self.diff_ub = config.diff_ub
 
@@ -89,10 +79,11 @@ class ReArcDataset(ProceduralDataset):
         rng_difficulty = self.get_rng_difficulty(rng)
         pso_difficulty = self.get_pso_difficulty(task)
         input_prompt = self.format_rearc_input(rng, task, generator)
+        answer = format_board(task["output"], self.board_format_opts)
 
         return {
             "question": input_prompt,
-            "answer": task["output"],
+            "answer": answer,
             "metadata": {
                 "input": task["input"],
                 "output": task["output"],
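The entry's "answer" is now the formatted board string rather than the raw nested list, matching what `parse_board` output is compared against in scoring. For illustration, a hypothetical re-implementation of the assumed default `format_board` behaviour (space-separated cells, one row per line; the real delimiters come from `BoardFormattingOptions`):

def format_board_sketch(board: list[list[int]]) -> str:
    # hypothetical stand-in for board_format.format_board with default options
    return "\n".join(" ".join(str(cell) for cell in row) for row in board)

print(format_board_sketch([[1, 2], [3, 4]]))
# 1 2
# 3 4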
@@ -104,12 +95,13 @@ class ReArcDataset(ProceduralDataset):
             },
         }
 
-    def score_answer(self, answer: str, metadata: Dict[str, Any]) -> float:
+    def score_answer(self, answer: str, entry: Dict[str, Any]) -> float:
         reward = 0.0
+        metadata = entry["metadata"]
         if answer is not None:
             try:
-                formatted_answer = parse_board(answer, self.board_format_opts)
-                if formatted_answer == metadata["output"]:
+                answer_board = parse_board(answer, self.board_format_opts)
+                if answer_board == metadata["output"]:
                     reward = 1.0
                 else:
                     reward = 0.05
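This is the signature fix named in the commit message: `score_answer` now receives the full entry dict and extracts `entry["metadata"]` itself, matching the new `ArcAgiDataset.score_answer`. Callers change accordingly; a sketch with illustrative config values:

from reasoning_gym.arc import ReArcConfig, ReArcDataset

ds = ReArcDataset(ReArcConfig(seed=0, size=4))
entry = ds[0]

# old call (no longer matches the signature):
#   ds.score_answer(entry["answer"], metadata=entry["metadata"])
# new call:
score = ds.score_answer(entry["answer"], entry=entry)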