add ArcAgiDataset class, fix score_entry() metadata params

This commit is contained in:
Andreas Koepf 2025-02-08 23:18:18 +01:00
parent 2ad0965fdc
commit 4e49806d22
20 changed files with 194 additions and 93 deletions

View file

@ -3,17 +3,7 @@ from random import Random
from typing import Any, Callable, Dict, Optional
from ..factory import ProceduralDataset, register_dataset
from .board_format import BoardFormattingOptions, format_board, format_board_pair, parse_board
_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below.
{examples}
Below is a test input grid. Predict the corresponding output grid by applying the rule you found.
Your final answer should just be the text output grid itself.
Input:
{input_grid}
"""
from .board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board, format_board_pair, parse_board
@dataclass
@ -37,7 +27,7 @@ class ReArcDataset(ProceduralDataset):
def __init__(self, config: ReArcConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.board_format_opts = config.board_format_opts
self._prompt_templates = _REARC_PROMPT_TEMPLATES
self._prompt_templates = ARC_PROMPT_TEMPLATE
self.diff_lb = config.diff_lb
self.diff_ub = config.diff_ub
@ -89,10 +79,11 @@ class ReArcDataset(ProceduralDataset):
rng_difficulty = self.get_rng_difficulty(rng)
pso_difficulty = self.get_pso_difficulty(task)
input_prompt = self.format_rearc_input(rng, task, generator)
answer = format_board(task["output"], self.board_format_opts)
return {
"question": input_prompt,
"answer": task["output"],
"answer": answer,
"metadata": {
"input": task["input"],
"output": task["output"],
@ -104,12 +95,13 @@ class ReArcDataset(ProceduralDataset):
},
}
def score_answer(self, answer: str, metadata: Dict[str, Any]) -> float:
def score_answer(self, answer: str, entry: Dict[str, Any]) -> float:
reward = 0.0
metadata = entry["metadata"]
if answer is not None:
try:
formatted_answer = parse_board(answer, self.board_format_opts)
if formatted_answer == metadata["output"]:
answer_board = parse_board(answer, self.board_format_opts)
if answer_board == metadata["output"]:
reward = 1.0
else:
reward = 0.05