mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-30 17:40:45 +00:00
added rearc curr (#358)
This commit is contained in:
parent
099ea88402
commit
6354ca5d35
3 changed files with 136 additions and 8 deletions
|
|
@ -2,9 +2,20 @@ from dataclasses import dataclass, field
|
|||
from random import Random
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
from .board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board, format_board_pair, parse_board
|
||||
|
||||
RNG_DIFFICULTY_LEVELS = [0.0, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.2]
|
||||
RNG_DIFFICULTY_RANGES = [
|
||||
(RNG_DIFFICULTY_LEVELS[i], RNG_DIFFICULTY_LEVELS[i + 1]) for i in range(len(RNG_DIFFICULTY_LEVELS) - 1)
|
||||
]
|
||||
|
||||
PSO_DIFFICULTY_LEVELS = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 1]
|
||||
PSO_DIFFICULTY_RANGES = [
|
||||
(PSO_DIFFICULTY_LEVELS[i], PSO_DIFFICULTY_LEVELS[i + 1]) for i in range(len(PSO_DIFFICULTY_LEVELS) - 1)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReArcConfig:
|
||||
|
|
@ -15,6 +26,14 @@ class ReArcConfig:
|
|||
board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
rng_difficulty_ranges: list[tuple[float, float]] = field(default_factory=lambda: RNG_DIFFICULTY_RANGES)
|
||||
rng_difficulty_weights: list[float] = field(
|
||||
default_factory=lambda: [1 / len(RNG_DIFFICULTY_RANGES)] * len(RNG_DIFFICULTY_RANGES)
|
||||
)
|
||||
pso_difficulty_ranges: list[tuple[float, float]] = field(default_factory=lambda: PSO_DIFFICULTY_RANGES)
|
||||
pso_difficulty_weights: list[float] = field(
|
||||
default_factory=lambda: [1 / len(PSO_DIFFICULTY_RANGES)] * len(PSO_DIFFICULTY_RANGES)
|
||||
)
|
||||
|
||||
def validate(self):
|
||||
assert self.min_examples > 0, "min_examples must be positive"
|
||||
|
|
@ -72,12 +91,22 @@ class ReArcDataset(ProceduralDataset):
|
|||
Generate a single ReArc task
|
||||
"""
|
||||
rng = Random(self.seed + idx)
|
||||
task_id = rng.choice(list(self._generators.keys()))
|
||||
generator = self._generators[task_id]
|
||||
task = generator(rng, self.diff_lb, self.diff_ub)
|
||||
pso_difficulty_range = rng.choices(
|
||||
self.config.pso_difficulty_ranges, weights=self.config.pso_difficulty_weights, k=1
|
||||
)[0]
|
||||
|
||||
while True:
|
||||
task_id = rng.choice(list(self._generators.keys()))
|
||||
generator = self._generators[task_id]
|
||||
difficulty_range = rng.choices(
|
||||
self.config.rng_difficulty_ranges, weights=self.config.rng_difficulty_weights, k=1
|
||||
)[0]
|
||||
task = generator(rng, difficulty_range[0], difficulty_range[1])
|
||||
pso_difficulty = self.get_pso_difficulty(task)
|
||||
if (pso_difficulty_range[0] <= pso_difficulty) and (pso_difficulty <= pso_difficulty_range[1]):
|
||||
break
|
||||
|
||||
rng_difficulty = self.get_rng_difficulty(rng)
|
||||
pso_difficulty = self.get_pso_difficulty(task)
|
||||
input_prompt = self.format_rearc_input(rng, task, generator)
|
||||
answer = format_board(task["output"], self.board_format_opts)
|
||||
|
||||
|
|
@ -110,4 +139,43 @@ class ReArcDataset(ProceduralDataset):
|
|||
return reward
|
||||
|
||||
|
||||
register_dataset("rearc", ReArcDataset, ReArcConfig)
|
||||
class ReArcCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(ReArcCurriculum.__name__, ReArcConfig)
|
||||
self._define_attributes(
|
||||
ScalarAttributeDefinition(
|
||||
name="pso_difficulty",
|
||||
field_name="pso_difficulty_weights",
|
||||
description="The range of PSO difficulty for the Arc problem",
|
||||
default_level=0,
|
||||
levels=[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs PSO difficulty
|
||||
[0, 1, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
], # only sample/generate the hardest tasks PSO difficulty
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="rng_difficulty",
|
||||
field_name="rng_difficulty_weights",
|
||||
description="The range of RNG difficulty for the Arc problem",
|
||||
default_level=0,
|
||||
levels=[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs RNG difficulty
|
||||
[0, 1, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
], # only sample/generate the hardest tasks wrs RNG difficulty
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset("rearc", ReArcDataset, ReArcConfig, ReArcCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue