update binary matrix prompt

This commit is contained in:
Zafir Stojanovski 2025-02-15 15:34:38 +01:00
parent b7b8e90d04
commit 87ae218328

View file

@ -7,23 +7,28 @@ https://leetcode.com/problems/01-matrix/description/
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import Optional
from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab distance of the nearest 0 for each cell.
QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab (Manhattan) distance of the nearest 0 for each cell.
Example:
Input: Find the distance to the nearest 0 for each cell in the matrix below:
- Input: Find the distance to the nearest 0 for each cell in the matrix below:
0 0 0
0 1 0
1 1 1
Output:
- Output:
0 0 0
0 1 0
1 2 1
- Explanation
- Each cell with a 0 has a distance of 0 to itself.
- The cell at (1, 1) has a distance of 1 to the nearest 0 (any of the three 0's at (1, 0), (0, 1), (1, 2)).
- The cell at (2, 0) has a distance of 1 to the nearest 0 (the 0 at (1, 0)).
- The cell at (2, 1) has a distance of 2 to the nearest 0 (any of the two 0's at (1, 0), (1, 2))
- The cell at (2, 2) has a distance of 1 to the nearest 0 (the 0 at (1, 2)).
- Hence, the final answer is the matrix is the output shown above, where each cell contains the distance to the nearest 0, in the same format as the input matrix.
Find the distance to the nearest 0 for each cell in the matrix below:
{matrix}
@ -34,6 +39,7 @@ Find the distance to the nearest 0 for each cell in the matrix below:
class BinaryMatrixConfig:
"""Configuration for Binary Matrix dataset generation"""
min_n: int = 3 # Minimum size of the matrix
max_n: int = 10 # Maximum size of the matrix
p_zero: float = 0.25 # Probability of a cell being 0
@ -42,7 +48,8 @@ class BinaryMatrixConfig:
def validate(self):
    """Check that the size bounds and zero-probability settings are usable.

    The checks run in order and the first one that fails raises an
    AssertionError carrying the corresponding message.
    """
    # Both size bounds must be positive.
    assert self.max_n >= 1, "max_n must be at least 1"
    assert self.min_n >= 1, "min_n must be at least 1"
    # The bounds must describe a non-empty range.
    assert self.max_n >= self.min_n, "min_n must be less than or equal to max_n"
    # p_zero lies in the half-open interval (0, 1].
    assert 0 < self.p_zero and self.p_zero <= 1, "p_zero must be between 0 and 1"
@ -54,7 +61,7 @@ class BinaryMatrixDataset(ProceduralDataset):
def _get_binary_matrix(self, rng: Random) -> list[list[int]]:
"""Generate a random binary matrix"""
n = rng.randint(1, self.config.max_n)
n = rng.randint(self.config.min_n, self.config.max_n)
# Ensure at least one 0 in the matrix, so that a solution exists
numbers = [0] + [0 if rng.random() < self.config.p_zero else 1 for _ in range(n**2 - 1)]
rng.shuffle(numbers)
@ -105,6 +112,26 @@ class BinaryMatrixDataset(ProceduralDataset):
"""Get a string representation of the matrix"""
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
    """Score a candidate answer against the oracle answer stored in ``entry``.

    Args:
        answer: Candidate answer text, or None when no answer was produced.
        entry: Dataset entry; only ``entry["answer"]`` (the oracle string) is read.

    Returns:
        1.0  if ``answer`` equals the oracle string exactly;
        0.5  if ``answer`` is a Python literal (e.g. a list of lists) whose
             ``_matrix_to_str`` rendering equals the oracle string;
        0.01 for any other non-None answer;
        0.0  if ``answer`` is None.
    """
    # Function-scope import so the method does not depend on the module's import block.
    import ast

    oracle_answer = entry["answer"]
    if answer is None:
        return 0.0
    if answer == oracle_answer:
        return 1.0

    try:
        # SECURITY: the answer is free-form text supplied from outside this class.
        # It used to be passed to eval(), which executes arbitrary expressions
        # (and can touch this method's locals). ast.literal_eval only accepts
        # Python literals, which is all a "list of lists" answer needs.
        parsed = ast.literal_eval(answer)
        rendered = self._matrix_to_str(parsed)
    except Exception:
        # Deliberately best-effort: anything unparseable or not matrix-shaped
        # just earns the minimal non-zero reward instead of crashing scoring.
        return 0.01

    return 0.5 if rendered == oracle_answer else 0.01
def __getitem__(self, idx: int) -> dict:
"""Generate a single Binary Matrix question"""
rng = Random(self.seed + idx)