mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
Merge pull request #111 from open-thought/rich/rectanglecount
Add Rectangle Count Dataset
This commit is contained in:
commit
36e8228ff2
3 changed files with 157 additions and 0 deletions
|
|
@ -5,6 +5,7 @@ Cognition tasks for training reasoning capabilities.
|
|||
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
|
||||
from .figlet_fonts import FigletFontConfig, FigletFontDataset
|
||||
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
|
||||
from .rectangle_count import RectangleCountConfig, RectangleCountDataset
|
||||
from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -16,4 +17,6 @@ __all__ = [
|
|||
"NumberSequenceDataset",
|
||||
"RubiksCubeConfig",
|
||||
"RubiksCubeDataset",
|
||||
"RectangleCountConfig",
|
||||
"RectangleCountDataset",
|
||||
]
|
||||
|
|
|
|||
135
reasoning_gym/cognition/rectangle_count.py
Normal file
135
reasoning_gym/cognition/rectangle_count.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
def draw_rectangles_with_overlap(n, width, height, rng):
|
||||
# Create a grid that holds a count of how many times a cell is drawn.
|
||||
grid = [[0 for _ in range(width)] for _ in range(height)]
|
||||
rectangles = []
|
||||
|
||||
max_attempts = 100000 # Prevent infinite loops in case of a crowded grid
|
||||
attempts = 0
|
||||
|
||||
while len(rectangles) < n and attempts < max_attempts:
|
||||
attempts += 1
|
||||
# Ensure minimum width and height of 3.
|
||||
# For a rectangle to be at least 3 cells wide, right must be at least left + 2.
|
||||
# Similarly, bottom must be at least top + 2.
|
||||
left = rng.randint(0, width - 3)
|
||||
right = rng.randint(left + 2, width - 1)
|
||||
top = rng.randint(0, height - 3)
|
||||
bottom = rng.randint(top + 2, height - 1)
|
||||
|
||||
# Prepare a list of all the cells that would be updated.
|
||||
cells_to_update = []
|
||||
|
||||
# Top edge:
|
||||
for col in range(left, right + 1):
|
||||
cells_to_update.append((top, col))
|
||||
# Bottom edge:
|
||||
for col in range(left, right + 1):
|
||||
cells_to_update.append((bottom, col))
|
||||
# Left edge (excluding corners already drawn):
|
||||
for row in range(top + 1, bottom):
|
||||
cells_to_update.append((row, left))
|
||||
# Right edge (excluding corners already drawn):
|
||||
for row in range(top + 1, bottom):
|
||||
cells_to_update.append((row, right))
|
||||
|
||||
# Check if drawing this rectangle would cause any cell to exceed a count of 2.
|
||||
conflict = False
|
||||
for r, c in cells_to_update:
|
||||
if grid[r][c] >= 2:
|
||||
conflict = True
|
||||
break
|
||||
if conflict:
|
||||
continue # Skip this rectangle candidate
|
||||
|
||||
# No conflict: update the grid counts.
|
||||
for r, c in cells_to_update:
|
||||
grid[r][c] += 1
|
||||
|
||||
# Save the rectangle (stored as (left, right, top, bottom)).
|
||||
rectangles.append((left, right, top, bottom))
|
||||
|
||||
if len(rectangles) < n:
|
||||
print(f"Only placed {len(rectangles)} rectangles after {attempts} attempts.")
|
||||
|
||||
# Print the grid.
|
||||
# Use ' ' for an untouched cell, '#' for a single hit, and '█' for exactly two hits.
|
||||
lines = ""
|
||||
for row in grid:
|
||||
line = "".join(" " if count == 0 else ("#" if count == 1 else "█") for count in row)
|
||||
lines = lines + line + "\n"
|
||||
return lines, len(rectangles)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RectangleCountConfig:
|
||||
"""Configuration for RectangleCount puzzle generation"""
|
||||
|
||||
max_rectangles: int = 10
|
||||
width: int = 80
|
||||
height: int = 80
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert self.width >= 10, "width must be gte 10"
|
||||
assert self.height >= 10, "height must be gte 10"
|
||||
|
||||
|
||||
class RectangleCountDataset(ProceduralDataset):
|
||||
"""Generates [RectangleCount Puzzles](https://en.wikipedia.org/wiki/RectangleCount_Puzzle) with configurable parameters"""
|
||||
|
||||
def __init__(self, config: RectangleCountConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single RectangleCount task
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- question: str, the task description
|
||||
- answer: str, a solution string
|
||||
- metadata: dict with generation parameters
|
||||
"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
target = rng.randint(1, self.config.max_rectangles)
|
||||
puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng)
|
||||
|
||||
puzz = f"How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. \n\n {puzzle}"
|
||||
|
||||
return {
|
||||
"question": puzz,
|
||||
"answer": str(answer),
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
"""Determine if the solution provided solves the RectangleCount task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if answer == None:
|
||||
return 0.0
|
||||
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
|
||||
return 0.01
|
||||
else:
|
||||
return 1.0 # Yay
|
||||
|
||||
|
||||
register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig)
|
||||
19
tests/test_rectangle_count.py
Normal file
19
tests/test_rectangle_count.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.cognition.rectangle_count import RectangleCountConfig, RectangleCountDataset
|
||||
|
||||
|
||||
def test_dice():
|
||||
"""Test basic properties and solution of generated items"""
|
||||
config = RectangleCountConfig(seed=42, size=50, max_rectangles=15, width=40, height=40)
|
||||
dataset = RectangleCountDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
Loading…
Add table
Add a link
Reference in a new issue