fix manipulate matrix (#247)

This commit is contained in:
Zafir Stojanovski 2025-03-01 23:00:29 +01:00 committed by GitHub
parent 39f151ad14
commit f549909c3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 116 additions and 69 deletions

View file

@ -3,7 +3,9 @@
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Optional
from typing import Any, Optional
import numpy as np
from ..factory import ProceduralDataset, register_dataset
@ -28,21 +30,22 @@ def num_cols(matrix: list[list[int]]) -> int:
class ManipulateMatrixConfig:
"""Configuration for Manipulate Matrix dataset generation"""
min_rows: int = 1 # Minimum number of rows
min_cols: int = 1 # Minimum number of columns
min_rows: int = 2 # Minimum number of rows
min_cols: int = 2 # Minimum number of columns
max_rows: int = 10 # Maximum number of rows
max_cols: int = 10 # Maximum number of columns
max_transforms: int = 5 # Maximum number of transformations to apply
p_rotate: float = 0.2 # Probability of rotating the matrix
p_hmirror: float = 0.2 # Probability of horizontally mirroring the matrix
p_vmirror: float = 0.2 # Probability of vertically mirroring the matrix
p_dmirror: float = 0.2 # Probability of mirroring along the diagonal
p_cmirror: float = 0.2 # Probability of mirroring along the counterdiagonal
p_map: float = 0.2 # Probability of mapping a certain value to another
p_crop: float = 0.2 # Probability of cropping the matrix
p_remove_every_nth_row: float = 0.2 # Probability of removing every nth row
p_remove_every_nth_col: float = 0.2 # Probability of removing every nth column
p_zero_divisible: float = 0.2 # Probability of setting elements divisible by some number to zero
min_transforms: int = 1 # Minimum number of transformations to apply
max_transforms: int = 10 # Maximum number of transformations to apply
w_rotate: float = 1 # Weight of rotating the matrix
w_hmirror: float = 1 # Weight of horizontally mirroring the matrix
w_vmirror: float = 1 # Weight of vertically mirroring the matrix
w_dmirror: float = 1 # Weight of mirroring along the diagonal
w_cmirror: float = 1 # Weight of mirroring along the counterdiagonal
w_map: float = 1 # Weight of mapping a certain value to another
w_crop: float = 1 # Weight of cropping the matrix
w_remove_every_nth_row: float = 1 # Weight of removing every nth row
w_remove_every_nth_col: float = 1 # Weight of removing every nth column
w_zero_divisible: float = 1 # Weight of setting elements divisible by some number to zero
size: int = 500 # Virtual dataset size
seed: Optional[int] = None
@ -53,17 +56,27 @@ class ManipulateMatrixConfig:
assert 1 <= self.min_cols, "min_cols must be at least 1"
assert self.min_rows <= self.max_rows, "max_rows must be at least min_rows"
assert self.min_cols <= self.max_cols, "max_cols must be at least min_cols"
assert 0 <= self.max_transforms, "max_transforms must be non-negative"
assert 0 <= self.p_rotate <= 1, "p_rotate must be between 0 and 1"
assert 0 <= self.p_hmirror <= 1, "p_hmirror must be between 0 and 1"
assert 0 <= self.p_vmirror <= 1, "p_vmirror must be between 0 and 1"
assert 0 <= self.p_dmirror <= 1, "p_dmirror must be between 0 and 1"
assert 0 <= self.p_cmirror <= 1, "p_cmirror must be between 0 and 1"
assert 0 <= self.p_map <= 1, "p_map must be between 0 and 1"
assert 0 <= self.p_crop <= 1, "p_crop must be between 0 and 1"
assert 0 <= self.p_remove_every_nth_row <= 1, "p_remove_every_nth_row must be between 0 and 1"
assert 0 <= self.p_remove_every_nth_col <= 1, "p_remove_nth_col must be between 0 and 1"
assert 0 <= self.p_zero_divisible <= 1, "p_zero_divisible must be between 0 and 1"
assert 1 <= self.min_transforms, "min_transforms must be at least 1"
assert self.min_transforms <= self.max_transforms, "max_transforms must be at least min_transforms"
assert (
np.sum(
np.exp(
[
self.w_rotate,
self.w_hmirror,
self.w_vmirror,
self.w_dmirror,
self.w_cmirror,
self.w_map,
self.w_crop,
self.w_remove_every_nth_row,
self.w_remove_every_nth_col,
self.w_zero_divisible,
]
)
)
> 0
), "At least one weight must be non-zero"
class ManipulateMatrixDataset(ProceduralDataset):
@ -89,6 +102,21 @@ class ManipulateMatrixDataset(ProceduralDataset):
"remove_every_nth_row",
"remove_every_nth_col",
]
weights = np.array(
[
config.w_rotate,
config.w_hmirror,
config.w_vmirror,
config.w_dmirror,
config.w_cmirror,
config.w_map,
config.w_crop,
config.w_remove_every_nth_row,
config.w_remove_every_nth_col,
config.w_zero_divisible,
]
)
self._weights = np.exp(weights) / np.sum(np.exp(weights))
def _get_matrix(self, rng: Random) -> list[list[int]]:
"""Generate a random matrix"""
@ -102,6 +130,25 @@ class ManipulateMatrixDataset(ProceduralDataset):
"""Get a string representation of the matrix"""
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score a model answer against the oracle matrix string.

    Args:
        answer: The model's answer as text (one matrix row per line), or None.
        entry: Dataset entry; only ``entry["answer"]`` is read.

    Returns:
        1.0 when the answer matches the oracle, either verbatim after trimming
        or after re-rendering it through ``_matrix_to_str``;
        ``len(oracle) / len(answer)`` when the oracle appears as a substring of
        the re-rendered answer; 0.01 for any other non-blank answer; and 0.0
        when the answer is None, empty, or whitespace-only.
    """
    oracle_answer = entry["answer"].strip()

    # Trim before the emptiness check so that a whitespace-only answer is
    # scored like an empty one (0.0) instead of like a wrong one (0.01).
    candidate = answer.strip() if answer is not None else ""
    if not candidate:
        return 0.0
    if candidate == oracle_answer:
        return 1.0

    # perhaps the model's answer has unnecessary spaces (e.g. after last row element)
    normalized = self._matrix_to_str([row.split() for row in candidate.split("\n")]).strip()
    if normalized == oracle_answer:
        return 1.0

    # Partial credit: the oracle is embedded in a longer answer.
    if oracle_answer in normalized:
        return len(oracle_answer) / len(normalized)
    return 0.01
def _identity(self, matrix: list[list[int]]) -> list[list[int]]:
"""Identity transformation: return the given matrix object as-is, without copying or modifying it"""
return matrix
@ -163,15 +210,16 @@ class ManipulateMatrixDataset(ProceduralDataset):
matrix = self._get_matrix(rng)
matrix_str = self._matrix_to_str(matrix)
num_transforms = rng.randint(0, self.config.max_transforms)
transforms = rng.sample(self._all_transforms, num_transforms)
num_transforms = rng.randint(self.config.min_transforms, self.config.max_transforms)
operations = []
answer = deepcopy(matrix)
for transform in transforms:
while len(operations) < num_transforms:
# Choose a transform randomly, weighted by the probability of each transform
transform = rng.choices(self._all_transforms, weights=self._weights, k=1)[0]
# Rotate
if transform == "rotate" and rng.random() < self.config.p_rotate:
if transform == "rotate":
rotation = rng.choice(list(self._rotations.keys()))
answer = self._rotations[rotation](answer)
operations.append(
@ -182,39 +230,39 @@ class ManipulateMatrixDataset(ProceduralDataset):
}
)
# Horizontal mirror
if transform == "hmirror" and rng.random() < self.config.p_hmirror:
if transform == "hmirror":
answer = self._hmirror(answer)
operations.append({"transform": transform, "instruction": "- Horizontally mirror the matrix"})
# Vertical mirror
if transform == "vmirror" and rng.random() < self.config.p_vmirror:
if transform == "vmirror":
answer = self._vmirror(answer)
operations.append({"transform": transform, "instruction": "- Vertically mirror the matrix"})
# Diagonal mirror
if transform == "dmirror" and rng.random() < self.config.p_dmirror:
if transform == "dmirror":
answer = self._dmirror(answer)
operations.append({"transform": transform, "instruction": "- Mirror the matrix along the diagonal"})
# Counterdiagonal mirror
if transform == "cmirror" and rng.random() < self.config.p_cmirror:
if transform == "cmirror":
answer = self._cmirror(answer)
operations.append(
{"transform": transform, "instruction": "- Mirror the matrix along the counterdiagonal"}
)
# Map a value to another
if transform == "map" and rng.random() < self.config.p_map:
if transform == "map":
a, b = rng.sample(range(10), 2)
answer = self._map(answer, a, b)
operations.append(
{"transform": transform, "from": a, "to": b, "instruction": f"- Map each occurrence of {a} to {b}"}
)
# Set elements divisible by k to zero
if transform == "zero_divisible" and rng.random() < self.config.p_zero_divisible:
if transform == "zero_divisible":
k = rng.randint(1, 9)
answer = self._zero_divisible(answer, k)
operations.append(
{"transform": transform, "k": k, "instruction": f"- Set all elements divisible by {k} to zero"}
)
# Crop the matrix
if transform == "crop" and rng.random() < self.config.p_crop:
if transform == "crop":
row_start = rng.randint(1, num_rows(answer))
row_end = rng.randint(row_start, num_rows(answer))
col_start = rng.randint(1, num_cols(answer))
@ -231,11 +279,7 @@ class ManipulateMatrixDataset(ProceduralDataset):
}
)
# Remove every nth row
if (
transform == "remove_every_nth_row"
and rng.random() < self.config.p_remove_every_nth_row
and num_rows(answer) > 1
):
if transform == "remove_every_nth_row" and num_rows(answer) > 1:
n = rng.randint(2, num_rows(answer))
answer = self._remove_every_nth_row(answer, n)
formatting = "st" if n == 1 else "nd" if n == 2 else "th"
@ -243,11 +287,7 @@ class ManipulateMatrixDataset(ProceduralDataset):
{"transform": transform, "n": n, "instruction": f"- Remove every {n}-{formatting} row (1-indexed)"}
)
# Remove every nth column
if (
transform == "remove_every_nth_col"
and rng.random() < self.config.p_remove_every_nth_col
and num_cols(answer) > 1
):
if transform == "remove_every_nth_col" and num_cols(answer) > 1:
n = rng.randint(2, num_cols(answer))
answer = self._remove_every_nth_col(answer, n)
formatting = "st" if n == 1 else "nd" if n == 2 else "th"