fix score_answer of pool_matrix (if -> elif), remove print

This commit is contained in:
Andreas Koepf 2025-02-25 23:36:11 +01:00
parent f9e8f8b064
commit bba128ffd0

View file

@ -97,19 +97,21 @@ class PoolMatrixDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score the answer based on the metadata"""
if not answer:
return 0.0
reward = 0.0
try:
if answer is not None:
oracle_answer = np.array(entry["answer"])
answer = np.array(answer)
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer):
reward = 1.0
if oracle_answer.shape == answer.shape:
reward = 0.1
else:
reward = 0.01
except:
print("Error in scoring answer for Pool Matrix")
oracle_answer = np.array(entry["answer"])
answer = np.array(answer)
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer):
reward = 1.0
elif oracle_answer.shape == answer.shape:
reward = 0.1
else:
reward = 0.01
except Exception:
pass
return reward
def __getitem__(self, idx: int) -> dict: