mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* added curriculum * readapted readme * corrected small errors * Delete eval/eval/r1/algorithmic/word_sorting.json * removed redundant argument * added spell * removed duplicated fit * changed config * added composite changes * added composite changes * updated yaml * added spell backward * updated read me * added qwen2.5 * added * Add files via upload * updated missing trainer func * updated curr * updated spell back * updated correctness score func * updated configs * added local evals * added updates * updated datasets * added fsdp to hf utility * added algorithmic qwen 3b yaml * updated read me * updated configs * added preappend token * updated with thinking token * updated test score board * resolved comments * added evaluation scripts * removed results from pr * added config * added partial reward scoring * added evaluation composites * added training configs * added games eval * added rubriks cube * resolved merge cinflicts * added games config * added latest eval configs * updated strucutre * Delete training/evaluations/eval_graphs_composite.yaml --------- Co-authored-by: joesharratt1229 <joesharrat1229@gmail.com>
177 lines
5.9 KiB
Python
177 lines
5.9 KiB
Python
from dataclasses import dataclass
|
|
from random import Random
|
|
from typing import Any, Optional
|
|
|
|
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
|
from ..factory import ProceduralDataset, register_dataset
|
|
|
|
# Prompt shown to the solver; ``{puzzle}`` is filled with the rendered ASCII grid.
QUESTION_TEMPLATE = """Your task is to count how many rectangles are present in an ASCII grid.

Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'.

Your output should be a single number, representing the total count of rectangles.

Now, it's your turn. How many rectangles do you see in the grid below?
{puzzle}
"""

# Name passed to register_dataset and stored in each item's metadata as "source_dataset".
DATASET_NAME = "rectangle_count"
# Multiplier applied to the partial-credit score of a close-but-wrong answer in score_answer.
CONST_TERM = 0.8
# Distance from the correct count at or beyond which score_answer returns 0.0.
D = 5
|
|
|
|
|
|
def draw_rectangles_with_overlap(n: int, width: int, height: int, rng: Random) -> tuple[str, int]:
    """Draw up to ``n`` random rectangle outlines on a ``width`` x ``height`` grid.

    Every candidate rectangle is at least 3 cells wide and 3 cells tall. A
    candidate is rejected when any of its outline cells has already been drawn
    twice, so no cell is ever covered by more than two outlines.

    Args:
        n: Number of rectangles to try to place.
        width: Number of columns in the grid.
        height: Number of rows in the grid.
        rng: Random source providing ``randint(a, b)`` with inclusive bounds.

    Returns:
        A ``(text, count)`` tuple. ``text`` is the rendered grid with one
        newline-terminated line per row, using ``' '`` for untouched cells,
        ``'#'`` for cells drawn once and ``'█'`` for cells drawn twice.
        ``count`` is the number of rectangles actually placed; it can be lower
        than ``n`` when the attempt budget is exhausted.
    """
    # Per-cell counter of how many rectangle outlines pass through the cell.
    grid = [[0] * width for _ in range(height)]
    rectangles = []

    max_attempts = 100000  # Prevent infinite loops in case of a crowded grid
    attempts = 0

    while len(rectangles) < n and attempts < max_attempts:
        attempts += 1

        # Minimum size of 3x3: right >= left + 2 and bottom >= top + 2.
        # The four draws stay in this exact order so seeded output is stable.
        left = rng.randint(0, width - 3)
        right = rng.randint(left + 2, width - 1)
        top = rng.randint(0, height - 3)
        bottom = rng.randint(top + 2, height - 1)

        # Outline cells: full top and bottom edges, then the side edges
        # without the corners (already covered by the horizontal edges).
        outline = [(top, col) for col in range(left, right + 1)]
        outline += [(bottom, col) for col in range(left, right + 1)]
        outline += [(row, left) for row in range(top + 1, bottom)]
        outline += [(row, right) for row in range(top + 1, bottom)]

        # Skip the candidate if it would push any cell past two overlaps.
        if any(grid[r][c] >= 2 for r, c in outline):
            continue

        for r, c in outline:
            grid[r][c] += 1

        # Stored as (left, right, top, bottom).
        rectangles.append((left, right, top, bottom))

    if len(rectangles) < n:
        print(f"Only placed {len(rectangles)} rectangles after {attempts} attempts.")

    # Render in a single pass; joining avoids quadratic string concatenation.
    rendered = "".join(
        "".join(" " if count == 0 else ("#" if count == 1 else "█") for count in row) + "\n" for row in grid
    )
    return rendered, len(rectangles)
|
|
|
|
|
|
@dataclass
class RectangleCountConfig:
    """Configuration for RectangleCount puzzle generation"""

    # Upper bound (inclusive) of the per-item target number of rectangles.
    max_rectangles: int = 10
    # Grid width in cells, i.e. characters per rendered line.
    width: int = 80
    # Grid height in cells, i.e. number of rendered lines.
    height: int = 80
    # Passed to the dataset base class as its seed.
    seed: Optional[int] = None
    # Passed to the dataset base class as its size.
    size: int = 500

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: If the grid is smaller than 10x10 or if
                ``max_rectangles`` is below 1.
        """
        assert self.width >= 10, "width must be gte 10"
        assert self.height >= 10, "height must be gte 10"
        # The per-item target is drawn with randint(1, max_rectangles), which
        # fails on an empty range, so reject such values here instead.
        assert self.max_rectangles >= 1, "max_rectangles must be gte 1"
|
|
|
|
|
|
class RectangleCountDataset(ProceduralDataset):
    """Procedural dataset of ASCII rectangle-counting puzzles."""

    def __init__(self, config: RectangleCountConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Build the puzzle for position ``idx``.

        The item is generated from a ``Random`` seeded with ``self.seed + idx``.

        Returns:
            dict with keys:
                - question: str, the prompt containing the rendered grid
                - answer: str, the number of rectangles that were placed
                - metadata: dict with the grid, the solution and the
                  generation parameters
        """
        cfg = self.config
        item_rng = Random(self.seed + idx)

        # Draw the target first, then let the same RNG drive the placement.
        requested = item_rng.randint(1, cfg.max_rectangles)
        grid_text, placed = draw_rectangles_with_overlap(requested, cfg.width, cfg.height, item_rng)

        metadata = {
            "source_dataset": DATASET_NAME,
            "source_index": idx,
            "puzzle": grid_text,
            "solution": placed,
            "num_rectangles": requested,
            "difficulty": {
                "max_rectangles": cfg.max_rectangles,
            },
        }
        return {
            "question": QUESTION_TEMPLATE.format(puzzle=grid_text),
            "answer": str(placed),
            "metadata": metadata,
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score a guessed rectangle count, giving partial credit to near misses.

        An exact match scores 1.0. A guess that is off by ``D`` or more, or
        that cannot be parsed as an integer, scores 0.0. Anything in between
        scores ``CONST_TERM * (1 - distance / D)``.

        Returns:
            float: A score between 0.0 and 1.0.
        """
        expected_text = entry["answer"].lower().replace("\n", "")

        try:
            expected = int(expected_text)
            guess = int(answer.strip())
        except (ValueError, TypeError, AttributeError):
            # Covers a missing answer (None) as well as non-numeric text.
            return 0.0

        gap = abs(guess - expected)
        if gap == 0:
            return 1.0
        if gap >= D:
            return 0.0

        partial = CONST_TERM * (1.0 - (gap / float(D)))
        return max(0.0, partial)
|
|
|
|
|
|
class RectangleCountCurriculum(BaseCurriculum):
    """Curriculum with a single attribute, ``max_rectangles``, in five levels."""

    def __init__(self):
        super().__init__(RectangleCountCurriculum.__name__, RectangleCountConfig)

        # The only attribute: levels are written to the config's max_rectangles field.
        rectangles_attribute = ScalarAttributeDefinition(
            name="max_rectangles",
            levels=[5, 10, 15, 20, 25],
            description="Number of rectangles in the grid",
            field_name="max_rectangles",
        )
        self._define_attributes(rectangles_attribute)
|
|
|
|
|
|
# Hand the dataset name with its dataset, config and curriculum classes to the factory's register_dataset.
register_dataset(DATASET_NAME, RectangleCountDataset, RectangleCountConfig, RectangleCountCurriculum)
|