mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-29 17:35:16 +00:00
feat: Add Arc1D dataset with comprehensive task generation and configuration
This commit is contained in:
parent
905ef7b89d
commit
b599d6e1a2
1 changed file with 169 additions and 1 deletion
|
|
@ -1,5 +1,8 @@
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Callable, Tuple
|
||||||
|
|
||||||
|
from ..dataset import ProceduralDataset
|
||||||
|
from ..factory import register_dataset
|
||||||
|
|
||||||
|
|
||||||
def gen_field(size: int, color: int = 0) -> List[int]:
|
def gen_field(size: int, color: int = 0) -> List[int]:
|
||||||
|
|
@ -1007,6 +1010,167 @@ def task_repeat_pattern_full(rng: Random, size: int) -> Optional[Dict[str, List[
|
||||||
return {"input": question, "output": answer}
|
return {"input": question, "output": answer}
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
@dataclass
class Arc1DConfig:
    """Settings that control how ARC 1D tasks are generated.

    Attributes:
        min_size: Smallest grid length that may be drawn for an item.
        max_size: Largest grid length that may be drawn for an item.
        num_train: How many worked examples accompany each test input.
        seed: Base seed handed to the dataset; ``None`` leaves it unset here.
        size: Number of items the dataset reports.
    """

    min_size: int = 10
    max_size: int = 30
    num_train: int = 3
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Check the settings for internal consistency.

        Raises:
            AssertionError: If any bound is non-positive or ``max_size`` is
                smaller than ``min_size``.
        """
        # The checks run in field order so the first violated rule is the one reported.
        assert self.min_size > 0, "min_size must be positive"
        assert self.max_size >= self.min_size, "max_size must be >= min_size"
        assert self.num_train > 0, "num_train must be positive"
        assert self.size > 0, "size must be positive"
|
||||||
|
|
||||||
|
|
||||||
|
# Registry of every ARC 1D task generator, keyed by task name.
# Each value is a ``(task_function, extra_keyword_arguments)`` pair; the
# keyword arguments are passed to the function after ``rng`` and ``size``.
# Insertion order matters: Arc1DDataset turns the keys into a list and picks
# from it with a seeded RNG, so reordering entries changes the generated items.
# NOTE(review): this literal is evaluated at import time, so every task_*
# function named here must already be defined. task_gravity_weighted_colors and
# task_color_left_half_blocks appear to be defined further down the file -
# confirm the module imports cleanly.
ARC_1D_TASKS = {
    # Move tasks: shifts of 1-4 pixels, solid blocks first, then colorful ones
    **{
        f"move_{n}pix_{label}": (task_move_n_pix, {"move_pix": n, "solid": is_solid})
        for label, is_solid in (("solid", True), ("colorful", False))
        for n in range(1, 5)
    },
    # Move wrapped tasks: same grid of variants, using the wrapped generator
    **{
        f"move_{n}pix_{label}_wrapped": (task_move_n_pix_wrapped, {"move_pix": n, "solid": is_solid})
        for label, is_solid in (("solid", True), ("colorful", False))
        for n in range(1, 5)
    },
    # Gravity tasks
    "gravity": (task_gravity, {}),
    "gravity_counting": (task_gravity_counting, {}),
    "gravity_antigravity": (task_gravity_antigravity, {}),
    "gravity_one_step": (task_gravity_one_step, {}),
    "gravity_weighted_colors": (task_gravity_weighted_colors, {}),
    # Block tasks
    "block_touch_dot": (task_block_touch_dot, {}),
    **{
        f"block_touch_dot_{n}pix": (task_block_touch_dot_n_pix, {"move_pix": n})
        for n in range(1, 5)
    },
    "block_scale_to_dot": (task_block_scale_to_dot, {}),
    "block_and_noise_remove": (task_block_and_noise_remove, {}),
    "block_and_noise_remove_inside": (task_block_and_noise_remove_inside, {}),
    "move_block_by_own_size": (task_move_block_by_own_size, {}),
    # Pattern tasks
    "two_points_and_fill": (task_two_points_and_fill, {}),
    "copy_block_to_dots": (task_copy_block_to_dots, {}),
    "copy_block_to_dots_colors": (task_copy_block_to_dots_colors, {}),
    "repeat_pattern_full": (task_repeat_pattern_full, {}),
    # Reflection tasks
    "reflect_block_with_border_pixel": (task_reflect_block_with_border_pixel, {}),
    "reflect_block_random": (task_reflect_block_with_border_pixel_random, {}),
    "reflect_block_around_dot": (task_reflect_block_around_dot, {}),
    # Color tasks
    "paint_biggest_block": (task_paint_biggest_block, {}),
    "recolor_blocks_by_size": (task_recolor_blocks_by_size, {}),
    "change_to_five": (task_change_to_five, {}),
    "recolor_blocks_from_palette": (task_recolor_blocks_from_palette, {}),
    "color_left_half_blocks": (task_color_left_half_blocks, {}),
    # Sorting tasks
    "sort_blocks_by_size": (task_sort_blocks_by_size, {}),
    "sort_complete_sequence": (task_sort_complete_sequence, {}),
    # Fill tasks
    "duplicate_block_from_seeds": (task_duplicate_block_from_seeds, {}),
    "fill_from_pixel": (task_fill_from_pixel, {}),
    "fill_until_collision": (task_fill_until_collision, {}),
    # Marking tasks
    "mark_size_two_blocks": (task_mark_size_two_blocks, {}),
}
|
||||||
|
|
||||||
|
|
||||||
|
class Arc1DDataset(ProceduralDataset):
    """Generates ARC 1D tasks by randomly selecting from available task generators.

    Each item picks one entry of ``ARC_1D_TASKS``, draws a grid size between
    ``config.min_size`` and ``config.max_size``, and produces
    ``config.num_train`` worked examples plus one test example with the same
    generator and size.
    """

    def __init__(self, config: Arc1DConfig):
        """Store the configuration and snapshot the available task names.

        Args:
            config: Generation settings; its ``seed`` and ``size`` are
                forwarded to ``ProceduralDataset``.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)
        # Fixed list so that a seeded ``choice`` selects by stable position.
        self.task_names = list(ARC_1D_TASKS.keys())

    @staticmethod
    def _format_grid(grid: List[int]) -> str:
        """Render a grid as its cell values separated by single spaces."""
        return " ".join(str(cell) for cell in grid)

    @staticmethod
    def _generate_example(rng: Random, task_func: Callable, size: int, task_kwargs: dict) -> Dict[str, List[int]]:
        """Call ``task_func`` repeatedly until it returns an example.

        Task generators signal a failed attempt by returning ``None``. There is
        no attempt limit, so a generator that can never succeed for ``size``
        keeps this loop running.
        """
        example = None
        while example is None:
            example = task_func(rng, size, **task_kwargs)
        return example

    def __getitem__(self, idx: int) -> dict:
        """Generate a single ARC 1D task with training examples.

        Args:
            idx: Index of the item to generate.

        Returns:
            dict with keys:
                - question: str, the task description, examples and test input
                - answer: str, the expected output grid as space-separated values
                - metadata: dict with the task name, grid size and raw examples
        """
        # Per-item RNG derived from the base seed and the index. The module only
        # imports the ``Random`` class (``from random import Random``), so it must
        # be used directly; ``random.Random`` raised NameError here.
        # NOTE(review): assumes the base class leaves an integer in ``self.seed``
        # even when ``config.seed`` is None - confirm against ProceduralDataset.
        item_rng = Random(self.seed + idx)

        # The order of RNG draws (task, size, examples) determines the output.
        task_name = item_rng.choice(self.task_names)
        task_func, task_kwargs = ARC_1D_TASKS[task_name]
        size = item_rng.randint(self.config.min_size, self.config.max_size)

        # Training examples first, then the held-out test example.
        train_examples = [
            self._generate_example(item_rng, task_func, size, task_kwargs)
            for _ in range(self.config.num_train)
        ]
        test_example = self._generate_example(item_rng, task_func, size, task_kwargs)

        # Assemble the prompt from parts and join once.
        parts = ["Find the common rule that maps an input grid to an output grid, given the examples below.\n\n"]
        for i, example in enumerate(train_examples, 1):
            parts.append(f"Example {i}:\n")
            parts.append("Input: " + self._format_grid(example["input"]) + "\n")
            parts.append("Output: " + self._format_grid(example["output"]) + "\n\n")

        parts.append(
            "Below is a test input grid. Predict the corresponding output grid by applying the rule you found. "
        )
        parts.append(
            "Describe how you derived the rule and your overall reasoning process in detail before you submit your answer. "
        )
        parts.append(
            "Your final answer must be placed in <output></output> tags and should be just be the text output grid itself.\n\n"
        )
        parts.append("Input:\n")
        parts.append(self._format_grid(test_example["input"]))
        question = "".join(parts)

        return {
            "question": question,
            "answer": self._format_grid(test_example["output"]),
            "metadata": {
                "task_name": task_name,
                "size": size,
                "train_examples": train_examples,
                "test_example": test_example,
            },
        }
|
||||||
|
|
||||||
def task_gravity_weighted_colors(rng: Random, size: int) -> Optional[Dict[str, List[int]]]:
|
def task_gravity_weighted_colors(rng: Random, size: int) -> Optional[Dict[str, List[int]]]:
|
||||||
"""Generate a task where color 2 is heavier than color 1 in gravity."""
|
"""Generate a task where color 2 is heavier than color 1 in gravity."""
|
||||||
# Generate random field with only colors 1 and 2
|
# Generate random field with only colors 1 and 2
|
||||||
|
|
@ -1055,6 +1219,10 @@ def task_identity(task_result: Optional[Dict[str, List[int]]]) -> Optional[Dict[
|
||||||
return task_result
|
return task_result
|
||||||
|
|
||||||
|
|
||||||
|
# Register the dataset
|
||||||
|
# Passes the name "arc_1d" together with the dataset and config classes to
# register_dataset. NOTE(review): presumably this makes the dataset creatable
# by name through the factory - confirm against ..factory.
register_dataset("arc_1d", Arc1DDataset, Arc1DConfig)
|
||||||
|
|
||||||
|
|
||||||
def task_color_left_half_blocks(rng: Random, size: int) -> Optional[Dict[str, List[int]]]:
|
def task_color_left_half_blocks(rng: Random, size: int) -> Optional[Dict[str, List[int]]]:
|
||||||
"""Generate a task where left half of blocks are colored differently."""
|
"""Generate a task where left half of blocks are colored differently."""
|
||||||
pos = 0
|
pos = 0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue