mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* update difficulty metadata for logic datasets * update difficulty metadata for graph datasets * update difficulty metadata for geometry datasets * update difficulty metadata for games datasets * update difficulty metadata for cognition datasets * update difficulty metadata for arithmetic datasets * update difficulty metadata for arc datasets * update difficulty metadata for algorithmic datasets * update difficulty metadata for algebra datasets * use tuples * update tests * update tests
181 lines
7.3 KiB
Python
181 lines
7.3 KiB
Python
from dataclasses import dataclass, field
|
|
from random import Random
|
|
from typing import Any, Callable, Optional
|
|
|
|
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
|
from ..factory import ProceduralDataset, register_dataset
|
|
from .board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board, format_board_pair, parse_board
|
|
|
|
RNG_DIFFICULTY_LEVELS = [0.0, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.2]
|
|
RNG_DIFFICULTY_RANGES = [
|
|
(RNG_DIFFICULTY_LEVELS[i], RNG_DIFFICULTY_LEVELS[i + 1]) for i in range(len(RNG_DIFFICULTY_LEVELS) - 1)
|
|
]
|
|
|
|
PSO_DIFFICULTY_LEVELS = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 1]
|
|
PSO_DIFFICULTY_RANGES = [
|
|
(PSO_DIFFICULTY_LEVELS[i], PSO_DIFFICULTY_LEVELS[i + 1]) for i in range(len(PSO_DIFFICULTY_LEVELS) - 1)
|
|
]
|
|
|
|
|
|
@dataclass
|
|
class ReArcConfig:
|
|
min_examples: int = 3 # minimum number of board pairs shown
|
|
max_examples: int = 5 # maximum number of board pairs shown
|
|
diff_lb: int = 0
|
|
diff_ub: int = 0.2
|
|
board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
|
|
seed: Optional[int] = None
|
|
size: int = 500
|
|
rng_difficulty_ranges: list[tuple[float, float]] = field(default_factory=lambda: RNG_DIFFICULTY_RANGES)
|
|
rng_difficulty_weights: list[float] = field(
|
|
default_factory=lambda: [1 / len(RNG_DIFFICULTY_RANGES)] * len(RNG_DIFFICULTY_RANGES)
|
|
)
|
|
pso_difficulty_ranges: list[tuple[float, float]] = field(default_factory=lambda: PSO_DIFFICULTY_RANGES)
|
|
pso_difficulty_weights: list[float] = field(
|
|
default_factory=lambda: [1 / len(PSO_DIFFICULTY_RANGES)] * len(PSO_DIFFICULTY_RANGES)
|
|
)
|
|
|
|
def validate(self):
|
|
assert self.min_examples > 0, "min_examples must be positive"
|
|
assert self.min_examples <= self.max_examples, "min_examples must be <= max_examples"
|
|
assert self.diff_lb <= self.diff_ub, "diff_lb must be <= diff_ub."
|
|
assert self.size > 0, "Size of dataset must be positive."
|
|
|
|
|
|
class ReArcDataset(ProceduralDataset):
|
|
def __init__(self, config: ReArcConfig):
|
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
|
self.board_format_opts = config.board_format_opts
|
|
self._prompt_templates = ARC_PROMPT_TEMPLATE
|
|
self.diff_lb = config.diff_lb
|
|
self.diff_ub = config.diff_ub
|
|
|
|
# lazy import of re-arc dsl & generators
|
|
from .rearc_utils import generators
|
|
from .rearc_utils.utils import get_generators, get_pso_difficulty
|
|
|
|
self._generators = get_generators(generators)
|
|
self.get_pso_difficulty = get_pso_difficulty
|
|
|
|
@staticmethod
|
|
def get_rng_difficulty(rng: Random) -> float:
|
|
if not hasattr(rng, "difficulty_samples"):
|
|
return 0.0
|
|
samples = rng.difficulty_samples
|
|
avg = sum(samples) / len(samples) if samples else 0.0
|
|
rng.difficulty_samples = []
|
|
return avg
|
|
|
|
def __len__(self) -> int:
|
|
return self.size
|
|
|
|
def format_rearc_input(self, rng: Random, task: dict, generator: Callable) -> str:
|
|
"""
|
|
Format a ReArc task input with multiple examples and test input.
|
|
"""
|
|
|
|
num_examples = rng.randint(self.config.min_examples, self.config.max_examples)
|
|
examples = [
|
|
format_board_pair(
|
|
i + 1, generator(rng, self.diff_lb, self.diff_ub), formatting_options=self.config.board_format_opts
|
|
)
|
|
for i in range(num_examples)
|
|
]
|
|
examples = "".join(examples)
|
|
input_grid = format_board(task["input"], self.board_format_opts)
|
|
|
|
return self._prompt_templates.format(examples=examples, input_grid=input_grid)
|
|
|
|
def __getitem__(self, idx: int) -> dict:
|
|
"""
|
|
Generate a single ReArc task
|
|
"""
|
|
rng = Random(self.seed + idx)
|
|
pso_difficulty_range = rng.choices(
|
|
self.config.pso_difficulty_ranges, weights=self.config.pso_difficulty_weights, k=1
|
|
)[0]
|
|
|
|
while True:
|
|
task_id = rng.choice(list(self._generators.keys()))
|
|
generator = self._generators[task_id]
|
|
difficulty_range = rng.choices(
|
|
self.config.rng_difficulty_ranges, weights=self.config.rng_difficulty_weights, k=1
|
|
)[0]
|
|
task = generator(rng, difficulty_range[0], difficulty_range[1])
|
|
pso_difficulty = self.get_pso_difficulty(task)
|
|
if (pso_difficulty_range[0] <= pso_difficulty) and (pso_difficulty <= pso_difficulty_range[1]):
|
|
break
|
|
|
|
rng_difficulty = self.get_rng_difficulty(rng)
|
|
input_prompt = self.format_rearc_input(rng, task, generator)
|
|
answer = format_board(task["output"], self.board_format_opts)
|
|
|
|
return {
|
|
"question": input_prompt,
|
|
"answer": answer,
|
|
"metadata": {
|
|
"input": task["input"],
|
|
"output": task["output"],
|
|
"task_id": task_id,
|
|
"rng": rng_difficulty,
|
|
"pso": pso_difficulty,
|
|
"difficulty": {
|
|
"rng_difficulty": self.config.rng_difficulty_weights,
|
|
"pso_difficulty": self.config.pso_difficulty_weights,
|
|
},
|
|
},
|
|
}
|
|
|
|
def score_answer(self, answer: str, entry: dict[str, Any]) -> float:
|
|
reward = 0.0
|
|
metadata = entry["metadata"]
|
|
if answer is not None:
|
|
try:
|
|
answer_board = parse_board(answer, self.board_format_opts)
|
|
if answer_board == metadata["output"]:
|
|
reward = 1.0
|
|
else:
|
|
reward = 0.05
|
|
except:
|
|
reward = 0.0
|
|
return reward
|
|
|
|
|
|
class ReArcCurriculum(BaseCurriculum):
|
|
def __init__(self):
|
|
super().__init__(ReArcCurriculum.__name__, ReArcConfig)
|
|
self._define_attributes(
|
|
ScalarAttributeDefinition(
|
|
name="pso_difficulty",
|
|
field_name="pso_difficulty_weights",
|
|
description="The range of PSO difficulty for the Arc problem",
|
|
levels=[
|
|
[1, 0, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs PSO difficulty
|
|
[0, 1, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 1, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 1, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 1, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 1, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 1, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 1],
|
|
], # only sample/generate the hardest tasks PSO difficulty
|
|
),
|
|
ScalarAttributeDefinition(
|
|
name="rng_difficulty",
|
|
field_name="rng_difficulty_weights",
|
|
description="The range of RNG difficulty for the Arc problem",
|
|
levels=[
|
|
[1, 0, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs RNG difficulty
|
|
[0, 1, 0, 0, 0, 0, 0, 0],
|
|
[0, 0, 1, 0, 0, 0, 0, 0],
|
|
[0, 0, 0, 1, 0, 0, 0, 0],
|
|
[0, 0, 0, 0, 1, 0, 0, 0],
|
|
[0, 0, 0, 0, 0, 1, 0, 0],
|
|
[0, 0, 0, 0, 0, 0, 1, 0],
|
|
[0, 0, 0, 0, 0, 0, 0, 1],
|
|
], # only sample/generate the hardest tasks wrs RNG difficulty
|
|
),
|
|
)
|
|
|
|
|
|
register_dataset("rearc", ReArcDataset, ReArcConfig, ReArcCurriculum)
|