diff --git a/reasoning_gym/algorithmic/pool_matrix.py b/reasoning_gym/algorithmic/pool_matrix.py index 4d14c9cf..41b9117d 100644 --- a/reasoning_gym/algorithmic/pool_matrix.py +++ b/reasoning_gym/algorithmic/pool_matrix.py @@ -97,19 +97,21 @@ class PoolMatrixDataset(ProceduralDataset): def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Score the answer based on the metadata""" + if not answer: + return 0.0 + reward = 0.0 try: - if answer is not None: - oracle_answer = np.array(entry["answer"]) - answer = np.array(answer) - if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer): - reward = 1.0 - if oracle_answer.shape == answer.shape: - reward = 0.1 - else: - reward = 0.01 - except: - print("Error in scoring answer for Pool Matrix") + oracle_answer = np.array(entry["answer"]) + answer = np.array(answer) + if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer): + reward = 1.0 + elif oracle_answer.shape == answer.shape: + reward = 0.1 + else: + reward = 0.01 + except Exception: + pass return reward def __getitem__(self, idx: int) -> dict: