diff --git a/reasoning_gym/algorithmic/binary_matrix.py b/reasoning_gym/algorithmic/binary_matrix.py index c3a09086..509b8042 100644 --- a/reasoning_gym/algorithmic/binary_matrix.py +++ b/reasoning_gym/algorithmic/binary_matrix.py @@ -115,22 +115,18 @@ class BinaryMatrixDataset(ProceduralDataset): def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: """Overwrite this method in derived classes if a single oracle answer is not available.""" oracle_answer = entry["answer"] - reward = 0.0 if answer is not None: if answer == oracle_answer: - reward = 1.0 + return 1.0 else: try: # check if answer is python list of lists answer = self._matrix_to_str(eval(answer)) if answer == oracle_answer: - reward = 0.5 - else: - reward = 0.01 + return 0.5 except Exception as e: - reward = 0.01 - - return reward + return 0.01 + return 0.0 def __getitem__(self, idx: int) -> dict: """Generate a single Binary Matrix question""" diff --git a/tests/test_binary_matrix.py b/tests/test_binary_matrix.py index d62db4a3..1f215cf6 100644 --- a/tests/test_binary_matrix.py +++ b/tests/test_binary_matrix.py @@ -106,3 +106,18 @@ def test_binary_matrix_answer(): # Empty matrix matrix = [[0, 0, 0], [0, 0, 0], [0, 0, 0]] assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + + # String representation of answer + answer = "0 0 0\n0 1 0\n1 2 1" + entry = {"answer": "0 0 0\n0 1 0\n1 2 1"} + assert dataset.score_answer(answer, entry) == 1.0 + + # Answer is a python list (partially correct answer) + answer = "[[0, 0, 0], [0, 1, 0], [1, 2, 1]]" + entry = {"answer": "0 0 0\n0 1 0\n1 2 1"} + assert dataset.score_answer(answer, entry) == 0.5 + + # Answer is null + answer = None + entry = {"answer": "0 0 0\n0 1 0\n1 2 1"} + assert dataset.score_answer(answer, entry) == 0.0 \ No newline at end of file