use Decimal class for numeric comparison e.g. +0123.100 == 123.1

This commit is contained in:
Andreas Koepf 2025-02-21 15:33:42 +01:00
parent ff5b210106
commit 476e37e70b
4 changed files with 15 additions and 18 deletions

View file

@ -103,8 +103,7 @@ class ProductsDataset(ProceduralDataset):
return expression, result
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"].strip()
return utils.compute_reward(answer, oracle_answer)
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])
class ProductsCurriculum(BaseCurriculum):