From bba128ffd086bba25492b30e519744d425e9dbc5 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Tue, 25 Feb 2025 23:36:11 +0100 Subject: [PATCH] fix score_answer of pool_matrix (if -> elif), remove print --- reasoning_gym/algorithmic/pool_matrix.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/reasoning_gym/algorithmic/pool_matrix.py b/reasoning_gym/algorithmic/pool_matrix.py index 4d14c9cf..41b9117d 100644 --- a/reasoning_gym/algorithmic/pool_matrix.py +++ b/reasoning_gym/algorithmic/pool_matrix.py @@ -97,19 +97,21 @@ class PoolMatrixDataset(ProceduralDataset): def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Score the answer based on the metadata""" + if not answer: + return 0.0 + reward = 0.0 try: - if answer is not None: - oracle_answer = np.array(entry["answer"]) - answer = np.array(answer) - if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer): - reward = 1.0 - if oracle_answer.shape == answer.shape: - reward = 0.1 - else: - reward = 0.01 - except: - print("Error in scoring answer for Pool Matrix") + oracle_answer = np.array(entry["answer"]) + answer = np.array(answer) + if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer): + reward = 1.0 + elif oracle_answer.shape == answer.shape: + reward = 0.1 + else: + reward = 0.01 + except Exception: + pass return reward def __getitem__(self, idx: int) -> dict: