use Decimal class for numeric comparison e.g. +0123.100 == 123.1

This commit is contained in:
Andreas Koepf 2025-02-21 15:33:42 +01:00
parent ff5b210106
commit 476e37e70b
4 changed files with 15 additions and 18 deletions

View file

@ -2,8 +2,6 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Literal, Optional
from reasoning_gym import utils
from ..factory import ProceduralDataset, register_dataset
@ -234,10 +232,6 @@ class BasicArithmeticDataset(ProceduralDataset):
template = rng.choice(templates)
return template.format(expression)
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"].strip()
return utils.compute_reward(answer, oracle_answer, allow_commas=False)
# Register the dataset
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)