mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
re-arc cleanup
This commit is contained in:
parent
9fe245200c
commit
052c983cd5
6 changed files with 520 additions and 174 deletions
|
|
@ -3,69 +3,34 @@ from random import Random
|
|||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
from .rearc_board_format import (
|
||||
BoardFormattingOptions,
|
||||
default_board_format_opts,
|
||||
format_board,
|
||||
format_board_pair,
|
||||
parse_board,
|
||||
)
|
||||
from .rearc_utils import generators, verifiers
|
||||
from .rearc_utils.dsl import *
|
||||
from .rearc_utils.utils import *
|
||||
from .board_format import BoardFormattingOptions, format_board, format_board_pair, parse_board
|
||||
|
||||
_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below
|
||||
_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below.
|
||||
|
||||
Examples:
|
||||
{examples}
|
||||
|
||||
Below is a test input grid. Predict the corresponding output grid by applying the rule you found.
|
||||
Your final answer should just be the text output grid itself.
|
||||
|
||||
|
||||
Input Grid:
|
||||
Input:
|
||||
{input_grid}
|
||||
|
||||
Output Grid:"""
|
||||
|
||||
# Ten-entry colour palette used when plotting grids; visualise_pair pairs it
# with Normalize(vmin=0, vmax=9) so each cell value 0-9 gets its own colour.
_COLOUR_MAP = ListedColormap(
    [
        "#000",
        "#0074D9",
        "#FF4136",
        "#2ECC40",
        "#FFDC00",
        "#AAAAAA",
        "#F012BE",
        "#FF851B",
        "#7FDBFF",
        "#870C25",
    ]
)
|
||||
|
||||
|
||||
def strip_prefix(string: str, prefix: str) -> str:
    """
    Remove ``prefix`` from the start of ``string``.

    Args:
        string: The text to strip.
        prefix: The leading text to remove.

    Returns:
        ``string`` without the leading ``prefix``. When ``string`` does not
        start with ``prefix`` it is returned unchanged, rather than having its
        first ``len(prefix)`` characters cut off blindly.
    """
    if not string.startswith(prefix):
        return string
    return string[len(prefix) :]
|
||||
|
||||
|
||||
def get_generators() -> dict:
    """
    Build a mapping from task identifier to example generator function.

    Every attribute of the ``generators`` module whose name starts with
    ``generate_`` is collected, keyed by its name with that prefix removed.
    """
    prefix = "generate_"
    mapping = {}
    for attr_name in dir(generators):
        if not attr_name.startswith(prefix):
            continue
        mapping[strip_prefix(attr_name, prefix)] = getattr(generators, attr_name)
    return mapping
|
||||
|
||||
|
||||
def get_verifiers() -> dict:
    """
    Build a mapping from task identifier to example verifier function.

    Every attribute of the ``verifiers`` module whose name starts with
    ``verify_`` is collected, keyed by its name with that prefix removed.
    """
    prefix = "verify_"
    mapping = {}
    for attr_name in dir(verifiers):
        if not attr_name.startswith(prefix):
            continue
        mapping[strip_prefix(attr_name, prefix)] = getattr(verifiers, attr_name)
    return mapping
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class ReArcConfig:
    """
    Configuration for the ReArc dataset.

    ``diff_lb`` and ``diff_ub`` are handed unchanged to each task generator as
    its second and third arguments (presumably lower/upper difficulty bounds;
    confirm against the generator implementations).
    """

    min_examples: int = 3  # minimum number of board pairs shown
    max_examples: int = 5  # maximum number of board pairs shown
    diff_lb: float = 0  # lower bound passed to the task generators
    diff_ub: float = 0.2  # upper bound passed to the task generators (fractional, hence float)
    # A fresh options object per config instance, so instances never share one.
    board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
    seed: Optional[int] = None
    size: int = 500  # number of items in the dataset

    def validate(self):
        """
        Check that the configured values are mutually consistent.

        Raises:
            AssertionError: if ``min_examples`` is not positive, exceeds
                ``max_examples``, if ``diff_lb`` exceeds ``diff_ub``, or if
                ``size`` is not positive.
        """
        assert self.min_examples > 0, "min_examples must be positive"
        assert self.min_examples <= self.max_examples, "min_examples must be <= max_examples"
        assert self.diff_lb <= self.diff_ub, "diff_lb must be <= diff_ub."
        assert self.size > 0, "Size of dataset must be positive."
|
||||
|
||||
|
||||
|
|
@ -76,8 +41,13 @@ class ReArcDataset(ProceduralDataset):
|
|||
self._prompt_templates = _REARC_PROMPT_TEMPLATES
|
||||
self.diff_lb = config.diff_lb
|
||||
self.diff_ub = config.diff_ub
|
||||
self._generators = get_generators()
|
||||
self._verifiers = get_verifiers()
|
||||
|
||||
# lazy import of re-arc dsl & generators
|
||||
from .rearc_utils import generators
|
||||
from .rearc_utils.utils import get_generators, get_pso_difficulty
|
||||
|
||||
self._generators = get_generators(generators)
|
||||
self.get_pso_difficulty = get_pso_difficulty
|
||||
|
||||
@staticmethod
|
||||
def get_rng_difficulty(rng: Random) -> float:
|
||||
|
|
@ -88,57 +58,22 @@ class ReArcDataset(ProceduralDataset):
|
|||
rng.difficulty_samples = []
|
||||
return avg
|
||||
|
||||
@staticmethod
def get_pso_difficulty(example: dict) -> float:
    """
    PSO-Difficulty: proxy measure for example difficulty built from #Pixels, #Symbols and #Objects.

    The result is the plain mean of three terms computed from
    ``example["input"]`` and ``example["output"]``: the combined grid area
    divided by 1800, the number of distinct palette entries across both grids
    divided by 10, and the average number of objects per cell of the two grids.
    """
    grid_in, grid_out = example["input"], example["output"]
    area_in = height(grid_in) * width(grid_in)
    area_out = height(grid_out) * width(grid_out)

    pixel_share = (area_in + area_out) / 1800
    colour_share = len(palette(grid_in) | palette(grid_out)) / 10

    density_in = len(objects(grid_in, T, F, F)) / area_in
    density_out = len(objects(grid_out, T, F, F)) / area_out
    object_density = (density_in + density_out) / 2

    return (pixel_share + colour_share + object_density) / 3
|
||||
|
||||
def __len__(self) -> int:
    """Return the number of items in the dataset, as stored in ``self.size``."""
    dataset_size = self.size
    return dataset_size
|
||||
|
||||
@staticmethod
def visualise_pair(example: Dict[str, Any]) -> None:
    """
    Draw the input and output grids of one ReArc example side by side.

    Reads ``example["metadata"]["input"]`` and ``example["metadata"]["output"]``
    and renders each with ``imshow`` using the module colour map and a 0-9
    normalisation. Nothing is returned; the figure is created via ``plt.subplots``.
    """
    imshow_kwargs = {"cmap": _COLOUR_MAP, "norm": Normalize(vmin=0, vmax=9)}

    # One row, two columns: input on the left, output on the right.
    n_rows, n_cols = 1, 2
    figure, (input_axis, output_axis) = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols * 3, n_rows * 3))

    input_axis.imshow(example["metadata"]["input"], **imshow_kwargs)
    output_axis.imshow(example["metadata"]["output"], **imshow_kwargs)

    # Titles distinguish the two plots.
    input_axis.set_title("Input")
    output_axis.set_title("Output")
|
||||
|
||||
def format_rearc_input(self, rng: Random, task: dict, generator: Callable) -> str:
    """
    Format a ReArc task input with multiple examples and test input.

    Args:
        rng: Random source. It picks how many example pairs to show and is
            passed to ``generator`` for every example, so the prompt is
            reproducible for a given rng state and needs no seed arithmetic
            (which breaks when the configured seed is ``None``).
        task: Mapping whose ``"input"`` entry is the test grid to render.
        generator: Callable invoked as ``generator(rng, diff_lb, diff_ub)``;
            its result is passed to ``format_board_pair``.

    Returns:
        The prompt template filled with the formatted examples and test grid.
    """
    # Number of demonstration pairs is drawn from the inclusive configured range.
    num_examples = rng.randint(self.config.min_examples, self.config.max_examples)
    examples = [
        format_board_pair(
            i + 1, generator(rng, self.diff_lb, self.diff_ub), formatting_options=self.config.board_format_opts
        )
        for i in range(num_examples)
    ]
    examples = "".join(examples)
    input_grid = format_board(task["input"], self.board_format_opts)

    return self._prompt_templates.format(examples=examples, input_grid=input_grid)
|
||||
|
|
@ -154,7 +89,7 @@ class ReArcDataset(ProceduralDataset):
|
|||
|
||||
rng_difficulty = self.get_rng_difficulty(rng)
|
||||
pso_difficulty = self.get_pso_difficulty(task)
|
||||
input_prompt = self.format_rearc_input(idx, task, generator)
|
||||
input_prompt = self.format_rearc_input(rng, task, generator)
|
||||
|
||||
return {
|
||||
"question": input_prompt,
|
||||
|
|
@ -163,8 +98,10 @@ class ReArcDataset(ProceduralDataset):
|
|||
"input": task["input"],
|
||||
"output": task["output"],
|
||||
"task_id": task_id,
|
||||
"rng": rng_difficulty,
|
||||
"pso": pso_difficulty,
|
||||
"difficulty": {
|
||||
"rng": rng_difficulty,
|
||||
"pso": pso_difficulty,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue