mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fix: Rounding issues in score_answer and add unit tests (#462)
This commit is contained in:
parent
51c2afc1fc
commit
9e79fc84b6
2 changed files with 59 additions and 10 deletions
|
|
@ -1,4 +1,4 @@
|
|||
"""Computhe the power of a number."""
|
||||
"""Compute the power of a number."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
|
|
@ -37,22 +37,26 @@ class PowerFunctionDataset(ProceduralDataset):
|
|||
def __init__(self, config: PowerFunctionConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _format_sig_figs(self, x: Decimal, sig: int) -> Decimal:
    """Round a Decimal to exactly ``sig`` significant figures, keeping trailing zeros.

    Args:
        x: The value to round.
        sig: Number of significant figures to keep.

    Returns:
        A ``Decimal`` whose coefficient carries ``sig`` digits (for example
        ``Decimal("2.00")`` for ``x=2, sig=3``). Zero is returned as a
        ``Decimal`` with ``sig - 1`` decimal places (``Decimal("0.00")``).
    """
    if x.is_zero():
        # Always hand back a Decimal (never a str) so the declared return
        # type holds and callers can compare results numerically.
        return Decimal("0." + "0" * (sig - 1))

    # adjusted() is the exponent of the most significant digit, so `shift`
    # is the number of decimal places needed to keep `sig` digits.
    exp = x.adjusted()
    shift = sig - exp - 1
    rounded = x.quantize(Decimal("1e{}".format(-shift)))

    # Rounding may carry into a new leading digit (9.995 -> 10.00), which
    # would leave sig + 1 digits; drop one place to restore exactly `sig`.
    if rounded.adjusted() != exp:
        rounded = rounded.quantize(Decimal("1e{}".format(-(shift - 1))))

    return rounded
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score the answer by checking if it matches the expected answer to 3 significant figures."""
|
||||
oracle_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
try:
|
||||
user_answer = Decimal(answer)
|
||||
oracle_value = Decimal(oracle_answer)
|
||||
|
||||
if oracle_value == 0:
|
||||
return 1.0 if user_answer == 0 else 0.01
|
||||
|
||||
user_sig_figs = f"{user_answer:.3g}"
|
||||
oracle_sig_figs = f"{oracle_value:.3g}"
|
||||
user_answer = self._format_sig_figs(Decimal(answer), 3)
|
||||
oracle_answer = self._format_sig_figs(Decimal(oracle_answer), 3)
|
||||
|
||||
# Check if they match to 3 significant figures
|
||||
if user_sig_figs == oracle_sig_figs:
|
||||
if user_answer == oracle_answer:
|
||||
return 1.0
|
||||
else:
|
||||
return 0.01
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue