mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
largest island curriculum (#270)
This commit is contained in:
parent
9bb6d028a3
commit
5bac641650
3 changed files with 143 additions and 37 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
|
||||
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
|
||||
from .largest_island import LargestIslandDataset
|
||||
from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
|
||||
from .quantum_lock import QuantumLockConfig, QuantumLockDataset
|
||||
from .shortest_path import ShortestPathConfig, ShortestPathDataset
|
||||
|
||||
|
|
@ -10,6 +10,8 @@ __all__ = [
|
|||
"QuantumLockConfig",
|
||||
"QuantumLockDataset",
|
||||
"LargestIslandDataset",
|
||||
"LargestIslandConfig",
|
||||
"LargestIslandCurriculum",
|
||||
"CourseScheduleDataset",
|
||||
"CourseScheduleConfig",
|
||||
"CourseScheduleCurriculum",
|
||||
|
|
|
|||
|
|
@ -9,10 +9,9 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
MIN_MAP_DIM = 1
|
||||
|
||||
QUESTION_TEMPLATE = """You are given the following {rows} x {cols} binary matrix grid:
|
||||
{grid}
|
||||
|
||||
|
|
@ -29,11 +28,15 @@ Return the maximum area of an island in grid. If there is no island, return 0.
|
|||
class LargestIslandConfig:
|
||||
"""Configuration for Largest Island dataset generation"""
|
||||
|
||||
rows: int = 10 # Number of rows in the grid
|
||||
cols: int = 10 # Number of columns in the grid
|
||||
min_rows: int = 5 # Minimum number of rows in the grid
|
||||
max_rows: int = 10 # Maximum number of rows in the grid
|
||||
min_cols: int = 5 # Minimum number of columns in the grid
|
||||
max_cols: int = 10 # Maximum number of columns in the grid
|
||||
min_num_islands: int = 0
|
||||
max_num_islands: int = (
|
||||
5 # Maximum number of islands (actual max might be smaller due to merging of islands during random walk)
|
||||
)
|
||||
min_island_size: int = 0
|
||||
max_island_size: int = (
|
||||
10 # Maximum size of an island (actual max might be larger due to merging of islands during random walk)
|
||||
)
|
||||
|
|
@ -43,10 +46,10 @@ class LargestIslandConfig:
|
|||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert MIN_MAP_DIM <= self.rows, f"rows must be between larger than {MIN_MAP_DIM}"
|
||||
assert MIN_MAP_DIM <= self.cols, f"cols must be between larger than {MIN_MAP_DIM}"
|
||||
assert 0 <= self.max_num_islands, "max_num_islands must be non-negative"
|
||||
assert 0 <= self.max_island_size, "max_island_size must be non-negative"
|
||||
assert 1 <= self.min_rows <= self.max_rows, "Invalid rows range"
|
||||
assert 1 <= self.min_cols <= self.max_cols, "Invalid cols range"
|
||||
assert 0 <= self.min_num_islands <= self.max_num_islands, "Invalid num_islands range"
|
||||
assert 0 <= self.min_island_size <= self.max_island_size, "Invalid island_size range"
|
||||
|
||||
|
||||
class LargestIslandDataset(ProceduralDataset):
|
||||
|
|
@ -55,27 +58,27 @@ class LargestIslandDataset(ProceduralDataset):
|
|||
def __init__(self, config: LargestIslandConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _is_valid_cell(self, r: int, c: int) -> bool:
|
||||
return 0 <= r < self.config.rows and 0 <= c < self.config.cols
|
||||
def _is_valid_cell(self, r: int, c: int, rows: int, cols: int) -> bool:
|
||||
return 0 <= r < rows and 0 <= c < cols
|
||||
|
||||
def _create_grid(self, rng: Random) -> list[list[int]]:
|
||||
def _create_grid(self, rng: Random, rows: int, cols: int) -> list[list[int]]:
|
||||
"""Create a random grid of islands using a random walk algorithm"""
|
||||
grid = [[0] * self.config.cols for _ in range(self.config.rows)]
|
||||
grid = [[0] * cols for _ in range(rows)]
|
||||
directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right
|
||||
|
||||
def create_island():
|
||||
r, c = rng.randint(0, self.config.rows - 1), rng.randint(0, self.config.cols - 1)
|
||||
capped_size = min(rng.randint(0, self.config.max_island_size), self.config.rows * self.config.cols)
|
||||
r, c = rng.randint(0, rows - 1), rng.randint(0, cols - 1)
|
||||
capped_size = min(rng.randint(self.config.min_island_size, self.config.max_island_size), rows * cols)
|
||||
for _ in range(capped_size):
|
||||
grid[r][c] = 1
|
||||
rng.shuffle(directions)
|
||||
for dr, dc in directions:
|
||||
new_r, new_c = r + dr, c + dc
|
||||
if self._is_valid_cell(new_r, new_c) and grid[new_r][new_c] == 0:
|
||||
if self._is_valid_cell(new_r, new_c, rows, cols) and grid[new_r][new_c] == 0:
|
||||
r, c = new_r, new_c
|
||||
break
|
||||
|
||||
num_islands = rng.randint(0, self.config.max_num_islands)
|
||||
num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands)
|
||||
for _ in range(num_islands):
|
||||
create_island()
|
||||
|
||||
|
|
@ -83,6 +86,7 @@ class LargestIslandDataset(ProceduralDataset):
|
|||
|
||||
def _get_largest_island(self, grid: list[list[int]]) -> int:
|
||||
"""Find the largest island in the grid"""
|
||||
rows, cols = len(grid), len(grid[0])
|
||||
directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right
|
||||
visited = set()
|
||||
|
||||
|
|
@ -94,15 +98,19 @@ class LargestIslandDataset(ProceduralDataset):
|
|||
r, c = queue.popleft()
|
||||
for dr, dc in directions:
|
||||
new_r, new_c = r + dr, c + dc
|
||||
if self._is_valid_cell(new_r, new_c) and (new_r, new_c) not in visited and grid[new_r][new_c] == 1:
|
||||
if (
|
||||
self._is_valid_cell(new_r, new_c, rows, cols)
|
||||
and (new_r, new_c) not in visited
|
||||
and grid[new_r][new_c] == 1
|
||||
):
|
||||
area += 1
|
||||
visited.add((new_r, new_c))
|
||||
queue.append((new_r, new_c))
|
||||
return area
|
||||
|
||||
max_area = 0
|
||||
for r in range(self.config.rows):
|
||||
for c in range(self.config.cols):
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
if grid[r][c] == 1 and (r, c) not in visited:
|
||||
max_area = max(max_area, bfs(r, c))
|
||||
|
||||
|
|
@ -120,16 +128,67 @@ class LargestIslandDataset(ProceduralDataset):
|
|||
"""Generate a single Largest Island question"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
grid = self._create_grid(rng)
|
||||
rows = rng.randint(self.config.min_rows, self.config.max_rows)
|
||||
cols = rng.randint(self.config.min_cols, self.config.max_cols)
|
||||
grid = self._create_grid(rng, rows, cols)
|
||||
grid_str = self._grid_to_string(grid)
|
||||
|
||||
answer = self._get_largest_island(grid)
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(rows=self.config.rows, cols=self.config.cols, grid=grid_str),
|
||||
"question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str),
|
||||
"answer": str(answer),
|
||||
"metadata": {"grid": grid, "solution": answer},
|
||||
}
|
||||
|
||||
|
||||
class LargestIslandCurriculum(BaseCurriculum):
    """Curriculum for the Largest Island dataset.

    Registers four range attributes (``rows``, ``cols``, ``num_islands`` and
    ``island_size``). Each one names a lower/upper field pair
    (``min_*`` / ``max_*``) that exists on ``LargestIslandConfig``.
    """

    def __init__(self):
        # The curriculum is named after this class and is tied to the
        # LargestIslandConfig configuration type.
        super().__init__(LargestIslandCurriculum.__name__, LargestIslandConfig)

        # Define attributes
        # NOTE(review): how `levels`, `default_level` and AttributeType.APPEND
        # are turned into concrete min/max config values is decided by the
        # coaching module, which is not visible here -- confirm there.
        self._define_attributes(
            RangeAttributeDefinition(
                name="rows",
                levels=[5, 10, 50, 100],
                default_level=0,
                description="Number of rows in the grid",
                attr_type=AttributeType.APPEND,
                min_value=1,
                lower_field_name="min_rows",
                upper_field_name="max_rows",
            ),
            RangeAttributeDefinition(
                name="cols",
                levels=[5, 10, 50, 100],
                default_level=0,
                description="Number of columns in the grid",
                attr_type=AttributeType.APPEND,
                min_value=1,
                lower_field_name="min_cols",
                upper_field_name="max_cols",
            ),
            RangeAttributeDefinition(
                name="num_islands",
                levels=[2, 5, 10, 20],
                default_level=0,
                description="Number of islands in the grid",
                attr_type=AttributeType.APPEND,
                min_value=0,
                lower_field_name="min_num_islands",
                upper_field_name="max_num_islands",
            ),
            RangeAttributeDefinition(
                name="island_size",
                levels=[5, 10, 20, 30],
                default_level=0,
                description="Size of the islands in the grid",
                attr_type=AttributeType.APPEND,
                min_value=0,
                lower_field_name="min_island_size",
                upper_field_name="max_island_size",
            ),
        )
|
||||
|
||||
|
||||
register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue