mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
update binary matrix prompt
This commit is contained in:
parent
b7b8e90d04
commit
87ae218328
1 changed files with 35 additions and 8 deletions
|
|
@ -7,23 +7,28 @@ https://leetcode.com/problems/01-matrix/description/
|
|||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab distance of the nearest 0 for each cell.
|
||||
QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab (Manhattan) distance of the nearest 0 for each cell.
|
||||
|
||||
Example:
|
||||
|
||||
Input: Find the distance to the nearest 0 for each cell in the matrix below:
|
||||
- Input: Find the distance to the nearest 0 for each cell in the matrix below:
|
||||
0 0 0
|
||||
0 1 0
|
||||
1 1 1
|
||||
|
||||
Output:
|
||||
- Output:
|
||||
0 0 0
|
||||
0 1 0
|
||||
1 2 1
|
||||
- Explanation
|
||||
- Each cell with a 0 has a distance of 0 to itself.
|
||||
- The cell at (1, 1) has a distance of 1 to the nearest 0 (any of the three 0's at (1, 0), (0, 1), (1, 2)).
|
||||
- The cell at (2, 0) has a distance of 1 to the nearest 0 (the 0 at (1, 0)).
|
||||
- The cell at (2, 1) has a distance of 2 to the nearest 0 (any of the two 0's at (1, 0), (1, 2))
|
||||
- The cell at (2, 2) has a distance of 1 to the nearest 0 (the 0 at (1, 2)).
|
||||
- Hence, the final answer is the matrix is the output shown above, where each cell contains the distance to the nearest 0, in the same format as the input matrix.
|
||||
|
||||
Find the distance to the nearest 0 for each cell in the matrix below:
|
||||
{matrix}
|
||||
|
|
@ -34,6 +39,7 @@ Find the distance to the nearest 0 for each cell in the matrix below:
|
|||
class BinaryMatrixConfig:
|
||||
"""Configuration for Binary Matrix dataset generation"""
|
||||
|
||||
min_n: int = 3 # Minimum size of the matrix
|
||||
max_n: int = 10 # Maximum size of the matrix
|
||||
p_zero: float = 0.25 # Probability of a cell being 0
|
||||
|
||||
|
|
@ -42,7 +48,8 @@ class BinaryMatrixConfig:
|
|||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 1 <= self.max_n, "max_n must be at least 1"
|
||||
assert 1 <= self.min_n, "min_n must be at least 1"
|
||||
assert self.min_n <= self.max_n, "min_n must be less than or equal to max_n"
|
||||
assert 0 < self.p_zero <= 1, "p_zero must be between 0 and 1"
|
||||
|
||||
|
||||
|
|
@ -54,7 +61,7 @@ class BinaryMatrixDataset(ProceduralDataset):
|
|||
|
||||
def _get_binary_matrix(self, rng: Random) -> list[list[int]]:
|
||||
"""Generate a random binary matrix"""
|
||||
n = rng.randint(1, self.config.max_n)
|
||||
n = rng.randint(self.config.min_n, self.config.max_n)
|
||||
# Ensure at least one 0 in the matrix, so that a solution exists
|
||||
numbers = [0] + [0 if rng.random() < self.config.p_zero else 1 for _ in range(n**2 - 1)]
|
||||
rng.shuffle(numbers)
|
||||
|
|
@ -105,6 +112,26 @@ class BinaryMatrixDataset(ProceduralDataset):
|
|||
"""Get a string representation of the matrix"""
|
||||
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"]
|
||||
reward = 0.0
|
||||
if answer is not None:
|
||||
if answer == oracle_answer:
|
||||
reward = 1.0
|
||||
else:
|
||||
try:
|
||||
# check if answer is python list of lists
|
||||
answer = self._matrix_to_str(eval(answer))
|
||||
if answer == oracle_answer:
|
||||
reward = 0.5
|
||||
else:
|
||||
reward = 0.01
|
||||
except Exception as e:
|
||||
reward = 0.01
|
||||
|
||||
return reward
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single Binary Matrix question"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue