diff --git a/reasoning_gym/arc/rearc.py b/reasoning_gym/arc/rearc.py index 2d567d40..029da94e 100644 --- a/reasoning_gym/arc/rearc.py +++ b/reasoning_gym/arc/rearc.py @@ -1,9 +1,30 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from random import Random +from typing import Any, Callable, Dict, Optional from ..factory import ProceduralDataset, register_dataset -from .rearc_board_format import BoardFormattingOptions, format_arc_task, format_board, format_board_pair +from .rearc_board_format import BoardFormattingOptions, default_board_format_opts, format_board, format_board_pair from .rearc_utils import generators, verifiers +from .rearc_utils.dsl import * +from .rearc_utils.utils import * + +_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below + +Examples: +{examples} + +Below is a test input grid. Predict the corresponding output grid by applying the rule you found. +Your final answer should just be the text output grid itself. + + +Input Grid: +{input_grid} + +Output Grid:""" + +_COLOUR_MAP = ListedColormap( + ["#000", "#0074D9", "#FF4136", "#2ECC40", "#FFDC00", "#AAAAAA", "#F012BE", "#FF851B", "#7FDBFF", "#870C25"] +) def strip_prefix(string: str, prefix: str) -> str: @@ -31,41 +52,30 @@ def get_verifiers() -> dict: @dataclass class ReArcConfig: + diff_lb: int = 0 + diff_ub: int = 1 + board_format_opts: BoardFormattingOptions = field(default_factory=default_board_format_opts) seed: Optional[int] = None size: int = 500 - board_format_opts: BoardFormattingOptions - - -REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below - -Examples: -{examples} - -Below is a test input grid. Predict the corresponding output grid by applying the rule you found. -Your final answer should just be the text output grid itself. - - -Input Grid: -{input_grid_test} - -Output Grid:""" class ReArcDataset(ProceduralDataset): def __init__(self, config: ReArcConfig): super().__init__(config=config, seed=config.seed, size=config.size) - board_format_opts = config.board_format_opts - self._prompt_templates = REARC_PROMPT_TEMPLATES + self.board_format_opts = config.board_format_opts + self._prompt_templates = _REARC_PROMPT_TEMPLATES + self.diff_lb = config.diff_lb + self.diff_ub = config.diff_ub self._generators = get_generators() self._verifiers = get_verifiers() @staticmethod - def get_rng_difficulty(example: dict) -> float: + def get_rng_difficulty(rng: Random) -> float: if not hasattr(rng, "difficulty_samples"): return 0.0 samples = rng.difficulty_samples avg = sum(samples) / len(samples) if samples else 0.0 - rng.difficulty_samples = [] # Reset for next generation + rng.difficulty_samples = [] return avg @staticmethod @@ -81,13 +91,38 @@ class ReArcDataset(ProceduralDataset): obj_dens = (len(objects(i, T, F, F)) / hwi + len(objects(o, T, F, F)) / hwo) / 2 return (pix_pct + col_pct + obj_dens) / 3 + def __len__(self) -> int: + return self.size + + @staticmethod + def visualise_pair(example: Dict[str, Any]) -> None: + """ + Visualise a ReArc task pair + """ + norm = Normalize(vmin=0, vmax=9) + args = {"cmap": _COLOUR_MAP, "norm": norm} + + # Change to 1 row, 2 columns + height = 1 + width = 2 + figure_size = (3 * width * 3, height * 3) + figure, axes = plt.subplots(height, width, figsize=figure_size) + + # Plot input and output side by side + axes[0].imshow(example["metadata"]["input"], **args) + axes[1].imshow(example["metadata"]["output"], **args) + + # Add titles to distinguish the plots + axes[0].set_title("Input") + axes[1].set_title("Output") + def format_rearc_input(self, idx: int, task: dict, generator: Callable) -> str: """ Format a ReArc task input with multiple examples and test input. """ - example_1 = generator(Random((self.seed + idx) * 1 * self.size)) - example_2 = generator(Random((self.seed + idx) * 2 * self.size)) - example_3 = generator(Random((self.seed + idx) * 3 * self.size)) + example_1 = generator(Random((self.seed + idx) * 1 * self.size), self.diff_lb, self.diff_ub) + example_2 = generator(Random((self.seed + idx) * 2 * self.size), self.diff_lb, self.diff_ub) + example_3 = generator(Random((self.seed + idx) * 3 * self.size), self.diff_lb, self.diff_ub) examples = ( format_board_pair(1, example_1, self.board_format_opts) @@ -105,10 +140,7 @@ class ReArcDataset(ProceduralDataset): rng = Random(self.seed + idx) task_id = rng.choice(list(self._generators.keys())) generator = self._generators[task_id] - task = generator(rng) - - input_grid = format_board(task["input"], self.board_format_opts) - output_grid = format_board(task["output"], self.board_format_opts) + task = generator(rng, self.diff_lb, self.diff_ub) rng_difficulty = self.get_rng_difficulty(rng) pso_difficulty = self.get_pso_difficulty(task) @@ -117,11 +149,28 @@ class ReArcDataset(ProceduralDataset): return { "question": input_prompt, "answer": task["output"], - "metadata": {"difficulty": {"output_grid": output_grid, "rng": rng_difficulty, "pso": pso_difficulty}}, + "metadata": { + "input": task["input"], + "output": task["output"], + "task_id": task_id, + "rng": rng_difficulty, + "pso": pso_difficulty, + }, } def score_answer(self, answer: str, metadata: Dict[str, Any]) -> float: - """Todo""" + reward = 0.0 + if answer is not None: + try: + task_id = metadata["task_id"] + verifier = self._verifiers[task_id] + if verifier(metadata["input"]) == metadata["output"]: + reward = 1.0 + else: + reward = 0.05 + except: + reward = 0.01 + return reward -dataset = register_dataset("rearc", ReArcDataset) +register_dataset("rearc", ReArcDataset, ReArcConfig)