# reasoning-gym/reasoning_gym/arc/rearc.py
# 2025-03-13 21:14:03 +01:00
#
# 181 lines
# 7.2 KiB
# Python

from dataclasses import dataclass, field
from random import Random
from typing import Any, Callable, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
from .board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board, format_board_pair, parse_board
# Bucket boundaries for the generator ("RNG") difficulty; each pair of neighbouring
# boundaries becomes one (low, high) range that can be sampled.
# NOTE(review): 0.175 is not in the list, so the last range (0.15, 0.2) is twice as wide
# as the others — confirm this is intended.
RNG_DIFFICULTY_LEVELS = [0.0, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.2]
RNG_DIFFICULTY_RANGES = list(zip(RNG_DIFFICULTY_LEVELS[:-1], RNG_DIFFICULTY_LEVELS[1:]))

# Bucket boundaries for the PSO difficulty of a generated task; the last range (0.35, 1)
# covers the whole upper tail.
PSO_DIFFICULTY_LEVELS = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 1]
PSO_DIFFICULTY_RANGES = list(zip(PSO_DIFFICULTY_LEVELS[:-1], PSO_DIFFICULTY_LEVELS[1:]))
@dataclass
class ReArcConfig:
    """Configuration for the re-arc procedural ARC dataset.

    The ``*_difficulty_ranges`` / ``*_difficulty_weights`` pairs are used together as the
    population and weights of ``random.Random.choices``. Each weights list must therefore
    have exactly one entry per range.
    """

    min_examples: int = 3  # minimum number of board pairs shown
    max_examples: int = 5  # maximum number of board pairs shown
    diff_lb: float = 0.0  # lower difficulty bound passed to the generator for example pairs
    diff_ub: float = 0.2  # upper difficulty bound passed to the generator for example pairs
    board_format_opts: BoardFormattingOptions = field(default_factory=BoardFormattingOptions)
    seed: Optional[int] = None
    size: int = 500
    # Copy the module-level lists so configs never share (and mutate) the same list object.
    rng_difficulty_ranges: list[tuple[float, float]] = field(default_factory=lambda: list(RNG_DIFFICULTY_RANGES))
    rng_difficulty_weights: list[float] = field(
        default_factory=lambda: [1 / len(RNG_DIFFICULTY_RANGES)] * len(RNG_DIFFICULTY_RANGES)
    )
    pso_difficulty_ranges: list[tuple[float, float]] = field(default_factory=lambda: list(PSO_DIFFICULTY_RANGES))
    pso_difficulty_weights: list[float] = field(
        default_factory=lambda: [1 / len(PSO_DIFFICULTY_RANGES)] * len(PSO_DIFFICULTY_RANGES)
    )

    def validate(self):
        """Check the configuration for consistency.

        Raises:
            AssertionError: if any field is out of range or the difficulty weights do not
                line up with their ranges.
        """
        assert self.min_examples > 0, "min_examples must be positive"
        assert self.min_examples <= self.max_examples, "min_examples must be <= max_examples"
        assert self.diff_lb <= self.diff_ub, "diff_lb must be <= diff_ub."
        assert self.size > 0, "Size of dataset must be positive."
        for name, ranges, weights in (
            ("rng", self.rng_difficulty_ranges, self.rng_difficulty_weights),
            ("pso", self.pso_difficulty_ranges, self.pso_difficulty_weights),
        ):
            # random.choices needs one weight per population element and a positive total.
            assert len(weights) == len(
                ranges
            ), f"{name}_difficulty_weights must have one entry per {name}_difficulty_ranges element"
            assert all(w >= 0 for w in weights), f"{name}_difficulty_weights must be non-negative"
            assert sum(weights) > 0, f"{name}_difficulty_weights must have a positive sum"
            # An inverted range can never accept a value, which would make sampling retry forever.
            assert all(lo <= hi for lo, hi in ranges), f"each {name}_difficulty_ranges entry must satisfy low <= high"
class ReArcDataset(ProceduralDataset):
    """ARC-style grid puzzles sampled from the re-arc task generators.

    Each item picks a generator and samples a test task whose PSO difficulty falls inside a
    weighted, randomly chosen range. It then renders a prompt with several freshly generated
    example pairs followed by the test input grid. The answer is the formatted test output grid.
    """

    def __init__(self, config: ReArcConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.board_format_opts = config.board_format_opts
        self._prompt_templates = ARC_PROMPT_TEMPLATE
        # difficulty bounds used when generating the in-context example pairs
        self.diff_lb = config.diff_lb
        self.diff_ub = config.diff_ub

        # lazy import of re-arc dsl & generators
        from .rearc_utils import generators
        from .rearc_utils.utils import get_generators, get_pso_difficulty

        self._generators = get_generators(generators)
        # Stored as a plain (unbound) function attribute; called as self.get_pso_difficulty(task).
        self.get_pso_difficulty = get_pso_difficulty

    @staticmethod
    def get_rng_difficulty(rng: Random) -> float:
        """Return the mean of the difficulty samples recorded on ``rng``, then clear them.

        Returns 0.0 when ``rng`` has no ``difficulty_samples`` attribute or the list is empty.
        NOTE(review): the samples are presumably appended by the re-arc generator helpers — confirm.
        """
        if not hasattr(rng, "difficulty_samples"):
            return 0.0
        samples = rng.difficulty_samples
        avg = sum(samples) / len(samples) if samples else 0.0
        rng.difficulty_samples = []
        return avg

    def __len__(self) -> int:
        return self.size

    def format_rearc_input(self, rng: Random, task: dict, generator: Callable) -> str:
        """Format a ReArc task input with multiple examples and test input.

        Draws between ``min_examples`` and ``max_examples`` (inclusive) new example pairs from
        ``generator`` using ``diff_lb``/``diff_ub``. It then fills the ARC prompt template with
        them and the formatted ``task["input"]`` grid.
        """
        num_examples = rng.randint(self.config.min_examples, self.config.max_examples)
        # Use the same formatting options for the examples and the test input so the prompt is uniform.
        examples = [
            format_board_pair(
                i + 1, generator(rng, self.diff_lb, self.diff_ub), formatting_options=self.board_format_opts
            )
            for i in range(num_examples)
        ]
        input_grid = format_board(task["input"], self.board_format_opts)
        return self._prompt_templates.format(examples="".join(examples), input_grid=input_grid)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single ReArc task.

        The item is fully determined by ``self.seed + idx``. A PSO difficulty range is drawn
        once. Generators are then sampled, each with a freshly drawn RNG difficulty range,
        until a task's PSO difficulty lies inside that range (both bounds inclusive).
        """
        rng = Random(self.seed + idx)
        pso_difficulty_range = rng.choices(
            self.config.pso_difficulty_ranges, weights=self.config.pso_difficulty_weights, k=1
        )[0]
        task_ids = list(self._generators)  # loop-invariant; same order as .keys()

        # NOTE(review): there is no attempt limit, so a PSO range no generator can reach loops forever.
        while True:
            # Start each attempt with an empty sample list so the reported RNG difficulty
            # reflects only the accepted task, not the rejected attempts before it.
            rng.difficulty_samples = []
            task_id = rng.choice(task_ids)
            generator = self._generators[task_id]
            difficulty_range = rng.choices(
                self.config.rng_difficulty_ranges, weights=self.config.rng_difficulty_weights, k=1
            )[0]
            task = generator(rng, difficulty_range[0], difficulty_range[1])
            pso_difficulty = self.get_pso_difficulty(task)
            if pso_difficulty_range[0] <= pso_difficulty <= pso_difficulty_range[1]:
                break

        rng_difficulty = self.get_rng_difficulty(rng)
        input_prompt = self.format_rearc_input(rng, task, generator)
        answer = format_board(task["output"], self.board_format_opts)

        return {
            "question": input_prompt,
            "answer": answer,
            "metadata": {
                "input": task["input"],
                "output": task["output"],
                "task_id": task_id,
                "difficulty": {
                    "rng": rng_difficulty,
                    "pso": pso_difficulty,
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score a model answer against the expected output grid in ``entry["metadata"]["output"]``.

        Returns:
            1.0 for an exact grid match; 0.05 partial credit for a board that parses but differs;
            0.0 when ``answer`` is None or cannot be parsed.
        """
        metadata = entry["metadata"]
        if answer is None:
            return 0.0
        try:
            answer_board = parse_board(answer, self.board_format_opts)
            matches = answer_board == metadata["output"]
        except Exception:
            # Malformed answers earn nothing; unlike a bare except, this does not swallow
            # KeyboardInterrupt / SystemExit.
            return 0.0
        return 1.0 if matches else 0.05
class ReArcCurriculum(BaseCurriculum):
    """Curriculum that restricts sampling to a single PSO and a single RNG difficulty range per level."""

    def __init__(self):
        super().__init__(ReArcCurriculum.__name__, ReArcConfig)

        def one_hot_levels(num_ranges: int) -> list[list[int]]:
            # Level i puts all weight on range i. The vectors are derived from the range
            # count so their length always equals the population given to rng.choices;
            # a length mismatch there raises ValueError.
            return [[int(j == i) for j in range(num_ranges)] for i in range(num_ranges)]

        self._define_attributes(
            ScalarAttributeDefinition(
                name="pso_difficulty",
                field_name="pso_difficulty_weights",
                description="The range of PSO difficulty for the Arc problem",
                default_level=0,
                # level 0: only the easiest PSO range ... last level: only the hardest PSO range
                levels=one_hot_levels(len(PSO_DIFFICULTY_RANGES)),
            ),
            ScalarAttributeDefinition(
                name="rng_difficulty",
                field_name="rng_difficulty_weights",
                description="The range of RNG difficulty for the Arc problem",
                default_level=0,
                # level 0: only the easiest RNG range ... last level: only the hardest RNG range
                levels=one_hot_levels(len(RNG_DIFFICULTY_RANGES)),
            ),
        )
# Register ReArcDataset, its config class and its curriculum with the dataset factory under the name "rearc".
register_dataset("rearc", ReArcDataset, ReArcConfig, ReArcCurriculum)