Merge branch 'main' of https://github.com/open-thought/reasoning-gym into env/spiral-matrix

This commit is contained in:
Zafir Stojanovski 2025-02-08 18:52:45 +01:00
commit 3f5cfeed95
27 changed files with 28427 additions and 125 deletions

View file

@ -15,6 +15,8 @@ from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeDataset
from .ransom_note import RansomNoteConfig, RansomNoteDataset
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
@ -52,6 +54,10 @@ __all__ = [
"GroupAnagramsDataset",
"SpiralMatrixConfig",
"SpiralMatrixDataset",
"RansomNoteConfig",
"RansomNoteDataset",
"IsomorphicStringsConfig",
"IsomorphicStringsDataset",
"RotateMatrixConfig",
"RotateMatrixDataset",
]

View file

@ -0,0 +1,99 @@
"""Check if you can construct a ransom note from letters in a magazine.
A popular Leetcode problem:
https://leetcode.com/problems/ransom-note/description/
"""
from collections import defaultdict
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
# Hard upper bounds that RansomNoteConfig.validate() enforces on the configured lengths
MAX_NOTE_LENGTH = 100_000
MAX_MAGAZINE_LENGTH = 100_001
# Prompt text for one exercise; {ransom_note} and {magazine} are filled in per item
QUESTION_TEMPLATE = """Given two strings representing a ransom note and a magazine, return True if you can construct the ransom note using the letters in the magazine, and False otherwise.
Each letter in the magazine string can only be used once in your ransom note.
Ransom note: {ransom_note}
Magazine: {magazine}
"""
@dataclass
class RansomNoteConfig:
    """Settings that control how Ransom Note exercises are generated."""

    # Upper bound on the number of letters in a generated ransom note
    max_note_length: int = 10
    # Upper bound on the number of letters in a generated magazine
    max_magazine_length: int = 30
    # Chance that a generated note can be built from its magazine
    p_solvable: float = 0.5
    # Number of items exposed by the virtual dataset
    size: int = 500
    seed: Optional[int] = None

    def validate(self):
        """Assert that every parameter lies within its permitted range."""
        note_limit = self.max_note_length
        assert 1 <= note_limit <= MAX_NOTE_LENGTH, "max_note_length must be between 1 and MAX_NOTE_LENGTH"

        magazine_limit = self.max_magazine_length
        assert 2 <= magazine_limit <= MAX_MAGAZINE_LENGTH, "max_magazine_length must be between 2 and MAX_MAGAZINE_LENGTH"

        # The magazine limit must leave room for at least one letter beyond the note
        assert note_limit < magazine_limit, "max_note_length must be less than max_magazine_length"

        probability = self.p_solvable
        assert 0 <= probability <= 1, "p_solvable must be between 0 and 1"
class RansomNoteDataset(ProceduralDataset):
    """Generates Ransom Note exercises with configurable difficulty"""

    def __init__(self, config: RansomNoteConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)
        # Lowercase ASCII alphabet. Kept as a set for backward compatibility;
        # sampling code must sort it first (see _get_inputs).
        self.letters = {chr(i) for i in range(ord("a"), ord("z") + 1)}

    def _get_inputs(self, rng: Random, solvable: bool) -> tuple[str, str]:
        """Generate a random ransom note and magazine.

        Args:
            rng: Random generator that drives every choice made here.
            solvable: When True the magazine contains every letter of the
                note; when False one occurrence of a note letter is withheld.

        Returns:
            A ``(ransom_note, magazine)`` pair of lowercase strings.
        """
        # Iteration order of a set of strings varies between interpreter runs
        # (string hash randomization), so sampling straight from the set would
        # make a seeded item differ across processes. Sort once to fix the order.
        alphabet = sorted(self.letters)

        ransom_note_len = rng.randint(1, self.config.max_note_length)
        ransom_note = [rng.choice(alphabet) for _ in range(ransom_note_len)]

        magazine_len = rng.randint(ransom_note_len, self.config.max_magazine_length)
        # Start from the note's own letters so the magazine is solvable by default
        magazine = ransom_note.copy()
        num_extra = magazine_len - ransom_note_len

        if solvable:
            magazine.extend(rng.choice(alphabet) for _ in range(num_extra))
        else:
            # Drop one occurrence of a note letter and never add that letter
            # back, so the magazine ends up exactly one copy short of it.
            remove_letter = rng.choice(magazine)
            magazine.remove(remove_letter)
            fillers = [letter for letter in alphabet if letter != remove_letter]
            # One additional filler compensates for the removed letter
            magazine.extend(rng.choice(fillers) for _ in range(num_extra + 1))

        rng.shuffle(ransom_note)
        rng.shuffle(magazine)
        return "".join(ransom_note), "".join(magazine)

    def _can_construct(self, ransom_note: str, magazine: str) -> bool:
        """Check if ransom note can be constructed from magazine.

        Each character of ``magazine`` may be consumed at most once.
        """
        count = defaultdict(int)
        for c in magazine:
            count[c] += 1
        for c in ransom_note:
            if count[c] <= 0:
                # Supply of this letter is exhausted
                return False
            count[c] -= 1
        return True

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Ransom Note question.

        Args:
            idx: Item index; combined with the dataset seed to seed the
                per-item random generator.

        Returns:
            A dict with the formatted ``question``, the ``answer`` as the
            string form of a bool, and ``metadata`` holding the raw inputs,
            the boolean solution and the requested solvability.
        """
        # NOTE(review): self.seed is presumably set by ProceduralDataset.__init__ - confirm
        rng = Random(self.seed + idx)
        solvable = rng.random() < self.config.p_solvable
        ransom_note, magazine = self._get_inputs(rng, solvable)
        # Recompute the answer from the generated strings instead of trusting the flag
        answer = self._can_construct(ransom_note, magazine)

        return {
            "question": QUESTION_TEMPLATE.format(ransom_note=ransom_note, magazine=magazine),
            "answer": str(answer),
            "metadata": {"ransom_note": ransom_note, "magazine": magazine, "solution": answer, "solvable": solvable},
        }
# NOTE(review): presumably makes this dataset/config pair available under the name "ransom_note" - confirm against ..factory
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig)

View file

@ -0,0 +1,103 @@
"""Rotate a square matrix clockwise.
A popular Leetcode problem:
https://leetcode.com/problems/rotate-image/description/
"""
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
# Prompt text for one exercise, including a worked 90-degree example;
# {degrees} and {matrix} are filled in per item
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it clockwise.
Example:
Input: Rotate the matrix below by 90 degrees clockwise:
1 2 3
4 5 6
7 8 9
Output:
7 4 1
8 5 2
9 6 3
Rotate the matrix below by {degrees} degrees clockwise:
{matrix}
"""
@dataclass
class RotateMatrixConfig:
    """Settings that control how Rotate Matrix exercises are generated."""

    # Largest side length n of a generated n x n matrix
    max_n: int = 10
    # Largest number of 90-degree turns a question may ask for
    max_rotations: int = 4
    # Number of items exposed by the virtual dataset
    size: int = 500
    seed: Optional[int] = None

    def validate(self):
        """Assert that the size and rotation limits are usable."""
        assert self.max_n >= 1, "max_n must be at least 1"
        assert self.max_rotations >= 0, "max_rotations must be at least 0"
class RotateMatrixDataset(ProceduralDataset):
    """Generates Rotate Matrix exercises with configurable difficulty"""

    def __init__(self, config: RotateMatrixConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _get_matrix(self, rng: Random) -> list[list[int]]:
        """Generate a random square matrix.

        The side length n is drawn from 1..max_n and the cells hold a
        shuffled permutation of 0..n*n-1, so every value is unique.
        """
        n = rng.randint(1, self.config.max_n)
        numbers = list(range(n**2))
        rng.shuffle(numbers)
        # Cut the flat permutation into n rows of n values each
        matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
        return matrix

    def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
        """Rotate the matrix K times by 90 degrees clockwise.

        Args:
            matrix: Rows of the matrix to rotate; it is not modified.
            num_rotations: Number of clockwise quarter-turns to apply.

        Returns:
            A new list of row lists holding the rotated matrix.
        """
        # Four quarter-turns restore the original, so only the remainder matters
        num_rotations %= 4
        output = deepcopy(matrix)
        for _ in range(num_rotations):
            # Reversing the row order and then transposing is exactly one
            # clockwise quarter-turn; this also works for non-square input.
            output = [list(column) for column in zip(*reversed(output))]
        return output

    def _matrix_to_str(self, matrix: list[list[int]]) -> str:
        """Get a string representation of the matrix.

        Values in a row are separated by single spaces, rows by newlines.
        """
        return "\n".join(" ".join(str(x) for x in row) for row in matrix)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Rotate Matrix question.

        Args:
            idx: Item index; combined with the dataset seed to seed the
                per-item random generator.

        Returns:
            A dict with the formatted ``question``, the rotated matrix as the
            ``answer`` string, and ``metadata`` holding the input matrix, the
            number of rotations and the rotated matrix.
        """
        # NOTE(review): self.seed is presumably set by ProceduralDataset.__init__ - confirm
        rng = Random(self.seed + idx)

        matrix = self._get_matrix(rng)
        num_rotations = rng.randint(0, self.config.max_rotations)
        matrix_str = self._matrix_to_str(matrix)

        answer = self._get_rotated(matrix, num_rotations)
        answer_str = self._matrix_to_str(answer)

        return {
            # The prompt states the rotation in degrees rather than quarter-turns
            "question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90),
            "answer": answer_str,
            "metadata": {"matrix": matrix, "num_rotations": num_rotations, "solution": answer},
        }
# NOTE(review): presumably makes this dataset/config pair available under the name "rotate_matrix" - confirm against ..factory
register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig)