diff --git a/reasoning_gym/algorithmic/spiral_matrix.py b/reasoning_gym/algorithmic/spiral_matrix.py index 2fc99666..fe65cc87 100644 --- a/reasoning_gym/algorithmic/spiral_matrix.py +++ b/reasoning_gym/algorithmic/spiral_matrix.py @@ -6,20 +6,25 @@ https://leetcode.com/problems/spiral-matrix/description/ from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Dict, Optional from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of elements in spiral order, starting from the top-left element. Example: - -Input: +- Input: For the matrix below, what is the list of elements in spiral order? 1 2 3 4 5 6 7 8 9 - -Output: 1 2 3 6 9 8 7 4 5 +- Output: 1 2 3 6 9 8 7 4 5 +- Explanation: + - We start from the top-left element (1) and move right until we reach the end of the row: 1 2 3 + - Then, we move down until we reach the last row: 1 2 3 6 9 + - Next, we move left until we reach the first column: 1 2 3 6 9 8 7 + - Then, we move up until we reach the second row (i.e. one below the previously traversed row): 1 2 3 6 9 8 7 4 + - Finally, we move right until we reach the second to last column: 1 2 3 6 9 8 7 4 5 + - The output format is a space-separated list of elements in spiral order (as opposed to a python list) For the matrix below, what is the list of elements in spiral order? 
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
    """Score a model answer against the oracle spiral-order string.

    Args:
        answer: The raw model output, or ``None`` if nothing was produced.
        entry: Dataset item; ``entry["answer"]`` holds the oracle answer as a
            space-separated string of elements.

    Returns:
        1.0  if the answer equals the oracle after trimming whitespace;
        0.5  if the answer is a Python literal (e.g. ``[1, 2, 3]``) whose
             items, joined by single spaces, equal the oracle;
        0.01 for any other non-empty answer;
        0.0  if the answer is ``None`` or the empty string.
    """
    # Local import keeps this method self-contained (no new file-level import).
    import ast

    # NOTE: the `any` in the signature is the builtin, kept as-is for
    # signature compatibility; `typing.Any` is what is meant.
    oracle_answer = entry["answer"].strip()

    # Guard clause: nothing to score.
    if not answer:
        return 0.0

    candidate = answer.strip()

    # Exact match
    if candidate == oracle_answer:
        return 1.0

    # The model may have answered with a Python list instead of the requested
    # space-separated format. The answer is untrusted text, so it is parsed
    # with ast.literal_eval (literals only) rather than eval(), which would
    # execute arbitrary expressions embedded in the answer.
    try:
        parsed = ast.literal_eval(candidate)
        # Iterating a non-iterable literal (e.g. a bare int) raises TypeError.
        normalized = " ".join(str(item) for item in parsed)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return 0.01

    # Right elements, wrong output format: partial credit.
    return 0.5 if normalized == oracle_answer else 0.01
config = SpiralMatrixConfig(seed=42) dataset = SpiralMatrixDataset(config) - # One element - matrix = [[0]] - assert dataset._get_spiral(matrix) == [0] - - # One row - matrix = [[0, 1, 2]] - assert dataset._get_spiral(matrix) == [0, 1, 2] - - # One column - matrix = [[0], [1], [2]] - assert dataset._get_spiral(matrix) == [0, 1, 2] - # 2D grid matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5] + + # Answer is identical (up to trimming) + entry = {"answer": "1 2 3 6 9 8 7 4 5"} + answer = "\n\n1 2 3 6 9 8 7 4 5\n" + assert dataset.score_answer(answer, entry) == 1.0 + + # Score answer in list format (partially correct) + entry = {"answer": "1 2 3 6 9 8 7 4 5"} + answer = "[1, 2, 3, 6, 9, 8, 7, 4, 5]" + assert dataset.score_answer(answer, entry) == 0.5 + + # Answer is incorrect + entry = {"answer": "1 2 3 6 9 8 7 4 5"} + answer = "1 2 3" + assert dataset.score_answer(answer, entry) == 0.01 + + # Answer is none + entry = {"answer": "1 2 3 6 9 8 7 4 5"} + answer = None + assert dataset.score_answer(answer, entry) == 0.0