mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fix score_answer of pool_matrix (if -> elif), remove print
This commit is contained in:
parent
f9e8f8b064
commit
bba128ffd0
1 changed files with 13 additions and 11 deletions
|
|
@ -97,19 +97,21 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score the answer based on the metadata"""
|
||||
|
||||
if not answer:
|
||||
return 0.0
|
||||
|
||||
reward = 0.0
|
||||
try:
|
||||
if answer is not None:
|
||||
oracle_answer = np.array(entry["answer"])
|
||||
answer = np.array(answer)
|
||||
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer):
|
||||
reward = 1.0
|
||||
if oracle_answer.shape == answer.shape:
|
||||
reward = 0.1
|
||||
else:
|
||||
reward = 0.01
|
||||
except:
|
||||
print("Error in scoring answer for Pool Matrix")
|
||||
oracle_answer = np.array(entry["answer"])
|
||||
answer = np.array(answer)
|
||||
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer):
|
||||
reward = 1.0
|
||||
elif oracle_answer.shape == answer.shape:
|
||||
reward = 0.1
|
||||
else:
|
||||
reward = 0.01
|
||||
except Exception:
|
||||
pass
|
||||
return reward
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue