updated algorithmics dataset (#269)

* updated algorithmic datasets
* added changes to symbolic and power
* updated power function test
This commit is contained in:
joesharratt1229 2025-03-05 23:32:53 +01:00 committed by GitHub
parent f426db90ec
commit d9638df79c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 57 additions and 29 deletions

View file

@ -1,6 +1,7 @@
"""Compute the power of a number."""
from dataclasses import dataclass
from decimal import Decimal
from math import pow
from random import Random
from typing import Any, Optional
@ -9,7 +10,8 @@ from ..factory import ProceduralDataset, register_dataset
# Prompt text for one exponentiation task; {base} and {exponent} are its only
# placeholders. It asks for 3 significant figures in 'e' notation, which is the
# precision score_answer compares at.
QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
Compute {base}^{exponent}. Return your final answer correct to 3 significant figures.
Provide your answer in scientific notation using 'e' notation (e.g., 1.23e+4).
"""
@ -33,19 +35,26 @@ class PowerFunctionDataset(ProceduralDataset):
super().__init__(config=config, seed=config.seed, size=config.size)
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score an answer by comparing it with the oracle to 3 significant figures.

    Args:
        answer: The submitted answer as text (plain or 'e' notation), or None.
        entry: Dataset entry; only ``entry["answer"]`` (the oracle) is read.

    Returns:
        1.0 if the answer equals the oracle after both are rounded to
        3 significant figures, 0.01 if an answer was given but is wrong or
        cannot be parsed as a number, and 0.0 if no answer was given.
    """
    oracle_answer = entry["answer"]
    if answer is None:
        return 0.0

    try:
        user_answer = Decimal(answer)
        oracle_value = Decimal(oracle_answer)

        if oracle_value == 0:
            return 1.0 if user_answer == 0 else 0.01

        # Round through the '.3g' format, then parse back to Decimal so the
        # comparison is numeric: Decimal formatting keeps trailing zeros, so
        # comparing the formatted strings would treat "8" and "8.0" as different.
        user_rounded = Decimal(f"{user_answer:.3g}")
        oracle_rounded = Decimal(f"{oracle_value:.3g}")
        return 1.0 if user_rounded == oracle_rounded else 0.01
    except (ArithmeticError, TypeError, ValueError):
        # decimal.InvalidOperation (an ArithmeticError) covers unparsable text;
        # TypeError/ValueError cover inputs Decimal cannot accept at all.
        return 0.01
def __getitem__(self, idx: int) -> dict: