mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
fix manipulate matrix (#247)
This commit is contained in:
parent
39f151ad14
commit
f549909c3d
2 changed files with 116 additions and 69 deletions
|
|
@ -3,7 +3,9 @@
|
|||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -28,21 +30,22 @@ def num_cols(matrix: list[list[int]]) -> int:
|
|||
class ManipulateMatrixConfig:
|
||||
"""Configuration for Manipulate Matrix dataset generation"""
|
||||
|
||||
min_rows: int = 1 # Minimum number of rows
|
||||
min_cols: int = 1 # Minimum number of columns
|
||||
min_rows: int = 2 # Minimum number of rows
|
||||
min_cols: int = 2 # Minimum number of columns
|
||||
max_rows: int = 10 # Maximum number of rows
|
||||
max_cols: int = 10 # Maximum number of columns
|
||||
max_transforms: int = 5 # Maximum number of transformations to apply
|
||||
p_rotate: float = 0.2 # Probability of rotating the matrix
|
||||
p_hmirror: float = 0.2 # Probability of horizontally mirroring the matrix
|
||||
p_vmirror: float = 0.2 # Probability of vertically mirroring the matrix
|
||||
p_dmirror: float = 0.2 # Probability of mirroring along the diagonal
|
||||
p_cmirror: float = 0.2 # Probability of mirroring along the counterdiagonal
|
||||
p_map: float = 0.2 # Probability of mapping a certain value to another
|
||||
p_crop: float = 0.2 # Probability of cropping the matrix
|
||||
p_remove_every_nth_row: float = 0.2 # Probability of removing every nth row
|
||||
p_remove_every_nth_col: float = 0.2 # Probability of removing every nth column
|
||||
p_zero_divisible: float = 0.2 # Probability of setting elements divisible by some number to zero
|
||||
min_transforms: int = 1 # Minimum number of transformations to apply
|
||||
max_transforms: int = 10 # Maximum number of transformations to apply
|
||||
w_rotate: float = 1 # Weight of rotating the matrix
|
||||
w_hmirror: float = 1 # Weight of horizontally mirroring the matrix
|
||||
w_vmirror: float = 1 # Weight of vertically mirroring the matrix
|
||||
w_dmirror: float = 1 # Weight of mirroring along the diagonal
|
||||
w_cmirror: float = 1 # Weight of mirroring along the counterdiagonal
|
||||
w_map: float = 1 # Weight of mapping a certain value to another
|
||||
w_crop: float = 1 # Weight of cropping the matrix
|
||||
w_remove_every_nth_row: float = 1 # Weight of removing every nth row
|
||||
w_remove_every_nth_col: float = 1 # Weight of removing every nth column
|
||||
w_zero_divisible: float = 1 # Weight of setting elements divisible by some number to zero
|
||||
|
||||
size: int = 500 # Virtual dataset size
|
||||
seed: Optional[int] = None
|
||||
|
|
@ -53,17 +56,27 @@ class ManipulateMatrixConfig:
|
|||
assert 1 <= self.min_cols, "min_cols must be at least 1"
|
||||
assert self.min_rows <= self.max_rows, "max_rows must be at least min_rows"
|
||||
assert self.min_cols <= self.max_cols, "max_cols must be at least min_cols"
|
||||
assert 0 <= self.max_transforms, "max_transforms must be non-negative"
|
||||
assert 0 <= self.p_rotate <= 1, "p_rotate must be between 0 and 1"
|
||||
assert 0 <= self.p_hmirror <= 1, "p_hmirror must be between 0 and 1"
|
||||
assert 0 <= self.p_vmirror <= 1, "p_vmirror must be between 0 and 1"
|
||||
assert 0 <= self.p_dmirror <= 1, "p_dmirror must be between 0 and 1"
|
||||
assert 0 <= self.p_cmirror <= 1, "p_cmirror must be between 0 and 1"
|
||||
assert 0 <= self.p_map <= 1, "p_map must be between 0 and 1"
|
||||
assert 0 <= self.p_crop <= 1, "p_crop must be between 0 and 1"
|
||||
assert 0 <= self.p_remove_every_nth_row <= 1, "p_remove_every_nth_row must be between 0 and 1"
|
||||
assert 0 <= self.p_remove_every_nth_col <= 1, "p_remove_nth_col must be between 0 and 1"
|
||||
assert 0 <= self.p_zero_divisible <= 1, "p_zero_divisible must be between 0 and 1"
|
||||
assert 1 <= self.min_transforms, "min_transforms must be at least 1"
|
||||
assert self.min_transforms <= self.max_transforms, "max_transforms must be at least min_transforms"
|
||||
assert (
|
||||
np.sum(
|
||||
np.exp(
|
||||
[
|
||||
self.w_rotate,
|
||||
self.w_hmirror,
|
||||
self.w_vmirror,
|
||||
self.w_dmirror,
|
||||
self.w_cmirror,
|
||||
self.w_map,
|
||||
self.w_crop,
|
||||
self.w_remove_every_nth_row,
|
||||
self.w_remove_every_nth_col,
|
||||
self.w_zero_divisible,
|
||||
]
|
||||
)
|
||||
)
|
||||
> 0
|
||||
), "At least one weight must be non-zero"
|
||||
|
||||
|
||||
class ManipulateMatrixDataset(ProceduralDataset):
|
||||
|
|
@ -89,6 +102,21 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
"remove_every_nth_row",
|
||||
"remove_every_nth_col",
|
||||
]
|
||||
weights = np.array(
|
||||
[
|
||||
config.w_rotate,
|
||||
config.w_hmirror,
|
||||
config.w_vmirror,
|
||||
config.w_dmirror,
|
||||
config.w_cmirror,
|
||||
config.w_map,
|
||||
config.w_crop,
|
||||
config.w_remove_every_nth_row,
|
||||
config.w_remove_every_nth_col,
|
||||
config.w_zero_divisible,
|
||||
]
|
||||
)
|
||||
self._weights = np.exp(weights) / np.sum(np.exp(weights))
|
||||
|
||||
def _get_matrix(self, rng: Random) -> list[list[int]]:
|
||||
"""Generate a random matrix"""
|
||||
|
|
@ -102,6 +130,25 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
"""Get a string representation of the matrix"""
|
||||
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score a model answer against the expected matrix string.

    Args:
        answer: Raw answer text, or None when no answer was produced.
        entry: Dataset item; only ``entry["answer"]`` (the expected
            matrix rendered as text) is read.

    Returns:
        1.0 if the answer matches the expected string, either directly
        after trimming or after normalising the whitespace inside each row;
        ``len(expected) / len(answer)`` if the expected string is contained
        in a longer (normalised) answer; 0.01 for any other non-empty
        answer; 0.0 when the answer is missing, empty or whitespace-only.
    """
    oracle_answer = entry["answer"].strip()
    if answer is None:
        return 0.0

    # Strip before the emptiness check so that a whitespace-only answer is
    # treated exactly like an empty one instead of earning partial credit.
    answer = answer.strip()
    if not answer:
        return 0.0

    if answer == oracle_answer:
        return 1.0

    # Perhaps the model's answer has unnecessary spaces (e.g. after the last
    # row element): re-render it with single spaces between the elements.
    answer = self._matrix_to_str([row.strip().split() for row in answer.split("\n")]).strip()
    if answer == oracle_answer:
        return 1.0

    # Expected matrix embedded in extra text: partial credit that shrinks
    # as the surrounding text grows.
    if oracle_answer in answer:
        return len(oracle_answer) / len(answer)

    return 0.01
|
||||
|
||||
def _identity(self, matrix: list[list[int]]) -> list[list[int]]:
    """Identity transformation.

    Returns the very same ``matrix`` object that was passed in; no copy is
    made and nothing is modified.
    """
    return matrix
|
||||
|
|
@ -163,15 +210,16 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
matrix = self._get_matrix(rng)
|
||||
matrix_str = self._matrix_to_str(matrix)
|
||||
|
||||
num_transforms = rng.randint(0, self.config.max_transforms)
|
||||
transforms = rng.sample(self._all_transforms, num_transforms)
|
||||
num_transforms = rng.randint(self.config.min_transforms, self.config.max_transforms)
|
||||
operations = []
|
||||
|
||||
answer = deepcopy(matrix)
|
||||
|
||||
for transform in transforms:
|
||||
while len(operations) < num_transforms:
|
||||
# Choose a transform randomly, weighted by the probability of each transform
|
||||
transform = rng.choices(self._all_transforms, weights=self._weights, k=1)[0]
|
||||
|
||||
# Rotate
|
||||
if transform == "rotate" and rng.random() < self.config.p_rotate:
|
||||
if transform == "rotate":
|
||||
rotation = rng.choice(list(self._rotations.keys()))
|
||||
answer = self._rotations[rotation](answer)
|
||||
operations.append(
|
||||
|
|
@ -182,39 +230,39 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
}
|
||||
)
|
||||
# Horizontal mirror
|
||||
if transform == "hmirror" and rng.random() < self.config.p_hmirror:
|
||||
if transform == "hmirror":
|
||||
answer = self._hmirror(answer)
|
||||
operations.append({"transform": transform, "instruction": "- Horizontally mirror the matrix"})
|
||||
# Vertical mirror
|
||||
if transform == "vmirror" and rng.random() < self.config.p_vmirror:
|
||||
if transform == "vmirror":
|
||||
answer = self._vmirror(answer)
|
||||
operations.append({"transform": transform, "instruction": "- Vertically mirror the matrix"})
|
||||
# Diagonal mirror
|
||||
if transform == "dmirror" and rng.random() < self.config.p_dmirror:
|
||||
if transform == "dmirror":
|
||||
answer = self._dmirror(answer)
|
||||
operations.append({"transform": transform, "instruction": "- Mirror the matrix along the diagonal"})
|
||||
# Counterdiagonal mirror
|
||||
if transform == "cmirror" and rng.random() < self.config.p_cmirror:
|
||||
if transform == "cmirror":
|
||||
answer = self._cmirror(answer)
|
||||
operations.append(
|
||||
{"transform": transform, "instruction": "- Mirror the matrix along the counterdiagonal"}
|
||||
)
|
||||
# Map a value to another
|
||||
if transform == "map" and rng.random() < self.config.p_map:
|
||||
if transform == "map":
|
||||
a, b = rng.sample(range(10), 2)
|
||||
answer = self._map(answer, a, b)
|
||||
operations.append(
|
||||
{"transform": transform, "from": a, "to": b, "instruction": f"- Map each occurrence of {a} to {b}"}
|
||||
)
|
||||
# Set elements divisible by k to zero
|
||||
if transform == "zero_divisible" and rng.random() < self.config.p_zero_divisible:
|
||||
if transform == "zero_divisible":
|
||||
k = rng.randint(1, 9)
|
||||
answer = self._zero_divisible(answer, k)
|
||||
operations.append(
|
||||
{"transform": transform, "k": k, "instruction": f"- Set all elements divisible by {k} to zero"}
|
||||
)
|
||||
# Crop the matrix
|
||||
if transform == "crop" and rng.random() < self.config.p_crop:
|
||||
if transform == "crop":
|
||||
row_start = rng.randint(1, num_rows(answer))
|
||||
row_end = rng.randint(row_start, num_rows(answer))
|
||||
col_start = rng.randint(1, num_cols(answer))
|
||||
|
|
@ -231,11 +279,7 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
}
|
||||
)
|
||||
# Remove every nth row
|
||||
if (
|
||||
transform == "remove_every_nth_row"
|
||||
and rng.random() < self.config.p_remove_every_nth_row
|
||||
and num_rows(answer) > 1
|
||||
):
|
||||
if transform == "remove_every_nth_row" and num_rows(answer) > 1:
|
||||
n = rng.randint(2, num_rows(answer))
|
||||
answer = self._remove_every_nth_row(answer, n)
|
||||
formatting = "st" if n == 1 else "nd" if n == 2 else "th"
|
||||
|
|
@ -243,11 +287,7 @@ class ManipulateMatrixDataset(ProceduralDataset):
|
|||
{"transform": transform, "n": n, "instruction": f"- Remove every {n}-{formatting} row (1-indexed)"}
|
||||
)
|
||||
# Remove every nth column
|
||||
if (
|
||||
transform == "remove_every_nth_col"
|
||||
and rng.random() < self.config.p_remove_every_nth_col
|
||||
and num_cols(answer) > 1
|
||||
):
|
||||
if transform == "remove_every_nth_col" and num_cols(answer) > 1:
|
||||
n = rng.randint(2, num_cols(answer))
|
||||
answer = self._remove_every_nth_col(answer, n)
|
||||
formatting = "st" if n == 1 else "nd" if n == 2 else "th"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue