Use the Decimal class for numeric comparison, e.g. +0123.100 == 123.1

This commit is contained in:
Andreas Koepf 2025-02-21 15:33:42 +01:00
parent ff5b210106
commit 476e37e70b
4 changed files with 15 additions and 18 deletions

View file

@@ -111,8 +111,7 @@ class ChainSumDataset(ProceduralDataset):
return expression, result
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"].strip()
return utils.compute_reward(answer, oracle_answer)
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])
class ChainSumCurriculum(BaseCurriculum):