mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
re-arc cleanup
This commit is contained in:
parent
9fe245200c
commit
052c983cd5
6 changed files with 520 additions and 174 deletions
|
|
@ -1,12 +1,45 @@
|
|||
import random
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.colors import ListedColormap, Normalize
|
||||
|
||||
from .dsl import *
|
||||
|
||||
|
||||
def strip_prefix(string: str, prefix: str) -> str:
|
||||
"""
|
||||
removes prefix
|
||||
"""
|
||||
return string[len(prefix) :]
|
||||
|
||||
|
||||
def get_generators(generators) -> dict:
|
||||
"""
|
||||
returns mapper from task identifiers (keys) to example generator functions
|
||||
"""
|
||||
prefix = "generate_"
|
||||
return {strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)}
|
||||
|
||||
|
||||
def get_verifiers(verifiers) -> dict:
|
||||
"""
|
||||
returns mapper from task identifiers (keys) to example verifier functions
|
||||
"""
|
||||
prefix = "verify_"
|
||||
return {strip_prefix(n, prefix): getattr(verifiers, n) for n in dir(verifiers) if n.startswith(prefix)}
|
||||
|
||||
|
||||
def get_pso_difficulty(example: dict) -> float:
|
||||
"""
|
||||
PSO-Difficulty: proxy measure for example difficulty, defined as weighted sum of #Pixels, #Symbols, #Objects
|
||||
"""
|
||||
i, o = example["input"], example["output"]
|
||||
hwi = height(i) * width(i)
|
||||
hwo = height(o) * width(o)
|
||||
pix_pct = (hwi + hwo) / 1800
|
||||
col_pct = len(palette(i) | palette(o)) / 10
|
||||
obj_dens = (len(objects(i, T, F, F)) / hwi + len(objects(o, T, F, F)) / hwo) / 2
|
||||
return (pix_pct + col_pct + obj_dens) / 3
|
||||
|
||||
|
||||
def unifint(rng: random.Random, diff_lb: float, diff_ub: float, bounds: Tuple[int, int]) -> int:
|
||||
"""
|
||||
rng
|
||||
|
|
@ -74,30 +107,6 @@ def format_task(task: dict) -> dict:
|
|||
}
|
||||
|
||||
|
||||
def plot_task(task: List[dict], title: str = None) -> None:
|
||||
"""
|
||||
displays a task
|
||||
"""
|
||||
cmap = ListedColormap(
|
||||
["#000", "#0074D9", "#FF4136", "#2ECC40", "#FFDC00", "#AAAAAA", "#F012BE", "#FF851B", "#7FDBFF", "#870C25"]
|
||||
)
|
||||
norm = Normalize(vmin=0, vmax=9)
|
||||
args = {"cmap": cmap, "norm": norm}
|
||||
height = 2
|
||||
width = len(task)
|
||||
figure_size = (width * 3, height * 3)
|
||||
figure, axes = plt.subplots(height, width, figsize=figure_size)
|
||||
for column, example in enumerate(task):
|
||||
axes[0, column].imshow(example["input"], **args)
|
||||
axes[1, column].imshow(example["output"], **args)
|
||||
axes[0, column].axis("off")
|
||||
axes[1, column].axis("off")
|
||||
if title is not None:
|
||||
figure.suptitle(title, fontsize=20)
|
||||
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
||||
plt.show()
|
||||
|
||||
|
||||
def fix_bugs(dataset: dict) -> None:
|
||||
"""
|
||||
fixes bugs in the original ARC training dataset
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue