reasoning-gym/reasoning_gym/cognition/rectangle_count.py
joesharratt1229 d0ef136d5b
Feat/intragen experiments (#414)
* added curriculum

* readapted readme

* corrected small errors

* Delete eval/eval/r1/algorithmic/word_sorting.json

* removed redundant argument

* added spell

* removed duplicated fit

* changed config

* added composite changes

* added composite changes

* updated yaml

* added spell backward

* updated read me

* added qwen2.5

* added

* Add files via upload

* updated missing trainer func

* updated curr

* updated spell back

* updated correctness score func

* updated configs

* added local evals

* added updates

* updated datasets

* added fsdp to hf utility

* added algorithmic qwen 3b yaml

* updated read me

* updated configs

* added preappend token

* updated with thinking token

* updated test score board

* resolved comments

* added evaluation scripts

* removed results from pr

* added config

* added partial reward scoring

* added evaluation composites

* added training configs

* added games eval

* added Rubik's cube

* resolved merge conflicts

* added games config

* added latest eval configs

* updated structure

* Delete training/evaluations/eval_graphs_composite.yaml

---------

Co-authored-by: joesharratt1229 <joesharrat1229@gmail.com>
2025-04-16 08:04:52 +02:00

177 lines
5.9 KiB
Python

from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Prompt shown to the solver; ``{puzzle}`` is replaced by the rendered grid.
# The overlap glyph named here must match the one emitted by
# draw_rectangles_with_overlap below.
QUESTION_TEMPLATE = """Your task is to count how many rectangles are present in an ASCII grid.
Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'.
Your output should be a single number, representing the total count of rectangles.
Now, it's your turn. How many rectangles do you see in the grid below?
{puzzle}
"""
DATASET_NAME = "rectangle_count"
# Scaling factor applied to partial-credit scores (non-exact answers).
CONST_TERM = 0.8
# Distance from the correct count at which the score drops to zero.
D = 5


def draw_rectangles_with_overlap(n, width, height, rng):
    """Draw up to ``n`` random rectangle outlines on a ``width`` x ``height`` grid.

    Each rectangle is at least 3 cells wide and 3 cells tall. A candidate is
    rejected when any cell of its outline has already been drawn twice, so no
    cell is ever covered by more than two outlines.

    Args:
        n: Number of rectangles to try to place.
        width: Number of columns in the grid.
        height: Number of rows in the grid.
        rng: Random source exposing ``randint(a, b)`` (inclusive bounds),
            e.g. ``random.Random``.

    Returns:
        A ``(text, count)`` tuple. ``text`` is the rendered grid, one
        newline-terminated line per row, using ``' '`` for an untouched cell,
        ``'#'`` for a cell on one outline and ``'█'`` for a cell shared by two
        outlines. ``count`` is the number of rectangles actually placed, which
        can be lower than ``n`` if the attempt budget is exhausted.

    Note:
        ``width`` and ``height`` below 3 leave no room for a rectangle; with
        ``random.Random`` the resulting empty ``randint`` range raises
        ``ValueError``.
    """
    # Per-cell counter of how many outlines pass through the cell.
    grid = [[0] * width for _ in range(height)]
    placed = 0
    max_attempts = 100000  # Prevent infinite loops in case of a crowded grid
    attempts = 0

    while placed < n and attempts < max_attempts:
        attempts += 1

        # Minimum size of 3 in each direction: right >= left + 2 and
        # bottom >= top + 2. The draw order (left, right, top, bottom) is part
        # of the seeded output and must not be changed.
        left = rng.randint(0, width - 3)
        right = rng.randint(left + 2, width - 1)
        top = rng.randint(0, height - 3)
        bottom = rng.randint(top + 2, height - 1)

        # Outline cells: full top and bottom edges, then the side edges
        # without the corners (already covered by the horizontal edges).
        columns = range(left, right + 1)
        inner_rows = range(top + 1, bottom)
        cells_to_update = [(top, col) for col in columns]
        cells_to_update += [(bottom, col) for col in columns]
        cells_to_update += [(row, left) for row in inner_rows]
        cells_to_update += [(row, right) for row in inner_rows]

        # Reject the candidate if it would push any cell past two hits.
        if any(grid[r][c] >= 2 for r, c in cells_to_update):
            continue

        for r, c in cells_to_update:
            grid[r][c] += 1
        placed += 1

    if placed < n:
        print(f"Only placed {placed} rectangles after {attempts} attempts.")

    # Render: index the glyph table by hit count (0, 1 or 2). Every cell must
    # produce exactly one character so that columns stay aligned.
    glyphs = (" ", "#", "█")
    lines = "".join("".join(glyphs[count] for count in row) + "\n" for row in grid)
    return lines, placed
@dataclass
class RectangleCountConfig:
    """Configuration for RectangleCount puzzle generation.

    Attributes:
        max_rectangles: Upper bound (inclusive) on the number of rectangles
            requested for a single puzzle.
        width: Number of columns in the generated grid.
        height: Number of rows in the generated grid.
        seed: Base seed for puzzle generation, or ``None``.
        size: Number of items in the dataset.
    """

    max_rectangles: int = 10
    width: int = 80
    height: int = 80
    seed: Optional[int] = None
    size: int = 500

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: If ``max_rectangles`` is below 1, or ``width`` or
                ``height`` is below 10.
        """
        # The per-item target is drawn from randint(1, max_rectangles), which
        # has an empty range when max_rectangles < 1; reject it up front.
        assert self.max_rectangles >= 1, "max_rectangles must be gte 1"
        assert self.width >= 10, "width must be gte 10"
        assert self.height >= 10, "height must be gte 10"
class RectangleCountDataset(ProceduralDataset):
    """Procedural dataset of ASCII rectangle-counting puzzles."""

    def __init__(self, config: RectangleCountConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Build the RectangleCount item for position ``idx``.

        The item is generated from a ``Random`` seeded with ``self.seed + idx``.

        Returns:
            dict with keys:
                - question: str, the prompt including the rendered grid
                - answer: str, the number of rectangles drawn
                - metadata: dict with generation parameters
        """
        cfg = self.config
        item_rng = Random(self.seed + idx)
        # Requested count; the drawing routine reports how many it placed.
        requested = item_rng.randint(1, cfg.max_rectangles)
        grid_text, placed = draw_rectangles_with_overlap(requested, cfg.width, cfg.height, item_rng)

        metadata = {
            "source_dataset": DATASET_NAME,
            "source_index": idx,
            "puzzle": grid_text,
            "solution": placed,
            "num_rectangles": requested,
            "difficulty": {
                "max_rectangles": cfg.max_rectangles,
            },
        }
        return {
            "question": QUESTION_TEMPLATE.format(puzzle=grid_text),
            "answer": str(placed),
            "metadata": metadata,
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score a guessed rectangle count, with partial credit for near misses.

        An exact match scores 1.0. A guess that is off by ``D`` or more, or
        that cannot be parsed as an integer, scores 0.0. Otherwise the score
        falls linearly with the distance and is scaled by ``CONST_TERM``.

        Returns:
            float: A score between 0.0 and 1.0.
        """
        expected_text = entry["answer"].lower().replace("\n", "")
        try:
            expected = int(expected_text)
            guess = int(answer.strip())
        except (ValueError, TypeError, AttributeError):
            # Covers a None answer as well as non-numeric text.
            return 0.0

        gap = abs(guess - expected)
        if gap == 0:
            return 1.0
        if gap >= D:
            return 0.0

        partial = CONST_TERM * (1.0 - (gap / float(D)))
        return max(0.0, partial)
class RectangleCountCurriculum(BaseCurriculum):
    """Curriculum for RectangleCount with a single scalar attribute, ``max_rectangles``."""

    def __init__(self):
        super().__init__(RectangleCountCurriculum.__name__, RectangleCountConfig)

        # The only attribute: levels for the config's max_rectangles field.
        max_rectangles_attr = ScalarAttributeDefinition(
            name="max_rectangles",
            field_name="max_rectangles",
            levels=[5, 10, 15, 20, 25],
            description="Number of rectangles in the grid",
        )
        self._define_attributes(max_rectangles_attr)
# Register the dataset class under DATASET_NAME together with its config and curriculum classes.
register_dataset(DATASET_NAME, RectangleCountDataset, RectangleCountConfig, RectangleCountCurriculum)