mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
add ArcAgiDataset class, fix score_entry() metadata params
This commit is contained in:
parent
2ad0965fdc
commit
4e49806d22
20 changed files with 194 additions and 93 deletions
|
|
@ -3,17 +3,7 @@ from random import Random
|
|||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
from .board_format import BoardFormattingOptions, format_board, format_board_pair, parse_board
|
||||
|
||||
_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below.
|
||||
|
||||
{examples}
|
||||
Below is a test input grid. Predict the corresponding output grid by applying the rule you found.
|
||||
Your final answer should just be the text output grid itself.
|
||||
|
||||
Input:
|
||||
{input_grid}
|
||||
"""
|
||||
from .board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board, format_board_pair, parse_board
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -37,7 +27,7 @@ class ReArcDataset(ProceduralDataset):
|
|||
def __init__(self, config: ReArcConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.board_format_opts = config.board_format_opts
|
||||
self._prompt_templates = _REARC_PROMPT_TEMPLATES
|
||||
self._prompt_templates = ARC_PROMPT_TEMPLATE
|
||||
self.diff_lb = config.diff_lb
|
||||
self.diff_ub = config.diff_ub
|
||||
|
||||
|
|
@ -89,10 +79,11 @@ class ReArcDataset(ProceduralDataset):
|
|||
rng_difficulty = self.get_rng_difficulty(rng)
|
||||
pso_difficulty = self.get_pso_difficulty(task)
|
||||
input_prompt = self.format_rearc_input(rng, task, generator)
|
||||
answer = format_board(task["output"], self.board_format_opts)
|
||||
|
||||
return {
|
||||
"question": input_prompt,
|
||||
"answer": task["output"],
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"input": task["input"],
|
||||
"output": task["output"],
|
||||
|
|
@ -104,12 +95,13 @@ class ReArcDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: str, metadata: Dict[str, Any]) -> float:
|
||||
def score_answer(self, answer: str, entry: Dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
metadata = entry["metadata"]
|
||||
if answer is not None:
|
||||
try:
|
||||
formatted_answer = parse_board(answer, self.board_format_opts)
|
||||
if formatted_answer == metadata["output"]:
|
||||
answer_board = parse_board(answer, self.board_format_opts)
|
||||
if answer_board == metadata["output"]:
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.05
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue