Fix bug in normalize_answer method (#444)

This commit is contained in:
Adefioye 2025-06-02 01:58:54 -05:00 committed by GitHub
parent c0e98f93b4
commit 9053009dbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 1 deletions

View file

@ -49,7 +49,10 @@ class PrimeFactorizationDataset(ProceduralDataset):
def _normalize_answer(self, answer: str) -> list[int]: def _normalize_answer(self, answer: str) -> list[int]:
"""Parse and sort factors from a string""" """Parse and sort factors from a string"""
return sorted([int(factor.strip()) for factor in answer.split("×")]) if not answer or answer.strip() == "":
return []
return sorted([int(factor.strip()) for factor in answer.split("×") if factor.strip() != ""])
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"] oracle_answer = entry["answer"]

View file

@ -119,6 +119,10 @@ def test_prime_factorization_score_answer():
answer = None answer = None
assert dataset.score_answer(answer, item) == 0.0 assert dataset.score_answer(answer, item) == 0.0
# Answer is empty string
answer = ""
assert dataset.score_answer(answer, item) == 0.01
def is_prime(n: int) -> bool: def is_prime(n: int) -> bool:
"""Helper function to check if a number is prime""" """Helper function to check if a number is prime"""