updated algorithmic datasets (#269)

* updated algorithmic datasets
* added answer-scoring and prompt changes to the GSM symbolic and power function datasets
* updated power function test
This commit is contained in:
joesharratt1229 2025-03-05 23:32:53 +01:00 committed by GitHub
parent f426db90ec
commit d9638df79c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 57 additions and 29 deletions

View file

@ -13,6 +13,7 @@ The stride is equal to the kernel size, meaning there is no overlap between the
Your output should be a matrix in the same format as the input matrix. Your output should be a matrix in the same format as the input matrix.
The output matrix is smaller than the input matrix when the kernel size is greater than 1, and its elements may be floating-point numbers. The output matrix is smaller than the input matrix when the kernel size is greater than 1, and its elements may be floating-point numbers.
Give elements in the output matrix correct to 2 decimal places.
Perform {pool_type} pooling on the following matrix with a kernel size of {pool_size}: Perform {pool_type} pooling on the following matrix with a kernel size of {pool_size}:
{matrix} {matrix}
@ -87,7 +88,7 @@ class PoolMatrixDataset(ProceduralDataset):
try: try:
oracle_answer = np.loadtxt(entry["answer"].splitlines(), dtype=np.float32) oracle_answer = np.loadtxt(entry["answer"].splitlines(), dtype=np.float32)
answer = np.loadtxt(answer.splitlines(), dtype=np.float32) answer = np.loadtxt(answer.splitlines(), dtype=np.float32)
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer): if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer, rtol=1e-2):
reward = 1.0 reward = 1.0
elif oracle_answer.shape == answer.shape: elif oracle_answer.shape == answer.shape:
reward = 0.1 reward = 0.1

View file

@ -1,5 +1,6 @@
"""GSM Symblic dataset generator""" """GSM Symblic dataset generator"""
import re
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
@ -149,8 +150,28 @@ class GSMSymbolicDataset(ProceduralDataset):
generator_idx = self.task_indices[idx] generator_idx = self.task_indices[idx]
generator = self.generators[generator_idx] generator = self.generators[generator_idx]
example = generator(rng, self.config.difficulty) example = generator(rng, self.config.difficulty)
example["question"] += " Give only the result as your final answer." example["question"] += " Give the result as your final answer. Do not include units."
return example return example
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score a free-text answer against the entry's numeric oracle answer.

    The first number found in ``answer`` is compared for exact equality with
    ``float(entry["answer"])``.

    Args:
        answer: The response to score, or ``None`` when no answer was given.
        entry: Dataset item; its ``"answer"`` value must be parseable by ``float``.

    Returns:
        ``1.0`` when the extracted number equals the oracle answer, ``0.01``
        when a number was found but differs, and ``0.0`` when ``answer`` is
        ``None``, contains no number, or the oracle answer cannot be read.
    """
    if answer is None:
        return 0.0
    try:
        # The optional sign must sit *outside* the leading word boundary:
        # there is no \b between start-of-string/whitespace and "-", so a
        # pattern of the form `\b-?\d+` silently drops the minus and reads
        # "-5" as 5.
        match = re.search(r"-?\b\d+(?:\.\d+)?\b", answer)
        if match is None:
            return 0.0
        expected_answer = float(entry["answer"])
    except (KeyError, TypeError, ValueError):
        # Non-string answer, missing key, or non-numeric oracle: no credit.
        return 0.0
    return 1.0 if float(match.group(0)) == expected_answer else 0.01
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig) register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)

View file

@ -1,6 +1,7 @@
"""Computhe the power of a number.""" """Computhe the power of a number."""
from dataclasses import dataclass from dataclasses import dataclass
from decimal import Decimal
from math import pow from math import pow
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
@ -9,7 +10,8 @@ from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number. QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
Compute {base}^{exponent} Compute {base}^{exponent}. Return your final answer correct to 3 significant figures.
Provide your answer in scientific notation using 'e' notation (e.g., 1.23e+4).
""" """
@ -33,19 +35,26 @@ class PowerFunctionDataset(ProceduralDataset):
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available.""" """Score the answer by checking if it matches the expected answer to 3 significant figures."""
oracle_answer = entry["answer"] oracle_answer = entry["answer"]
if answer is not None: if answer is not None:
try: try:
answer = round(float(answer), 4) user_answer = Decimal(answer)
oracle_answer = round(float(oracle_answer), 4) oracle_value = Decimal(oracle_answer)
difference = abs(float(answer) - float(oracle_answer))
if difference < 1e-4: if oracle_value == 0:
return 1.0 if user_answer == 0 else 0.01
user_sig_figs = f"{user_answer:.3g}"
oracle_sig_figs = f"{oracle_value:.3g}"
# Check if they match to 3 significant figures
if user_sig_figs == oracle_sig_figs:
return 1.0 return 1.0
elif difference < 1e-1: else:
return 0.5 return 0.01
except Exception: except Exception as e:
pass return 0.01
return 0.0 return 0.0
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:

View file

@ -90,3 +90,14 @@ def test_gsm_symbolic_generators():
print(f"ok: q={len(question_set)}, a={len(answer_set)}") print(f"ok: q={len(question_set)}, a={len(answer_set)}")
i += 1 i += 1
def test_gsm_symbolic_score_answer():
    """Each generated item must earn full credit when scored with its own oracle answer."""
    dataset = GSMSymbolicDataset(GSMSymbolicDatasetConfig(size=100, seed=42))

    for index in range(len(dataset)):
        entry = dataset[index]
        # Feeding the oracle answer back to the scorer is the best possible response.
        assert dataset.score_answer(entry["answer"], entry) == 1.0

View file

@ -59,20 +59,6 @@ def test_power_function_score_function():
config = PowerFunctionConfig(seed=42) config = PowerFunctionConfig(seed=42)
dataset = PowerFunctionDataset(config) dataset = PowerFunctionDataset(config)
item = dataset[0] for item in dataset:
answer = item["answer"]
# Answer is within 1e-6 of solution
answer = str(item["metadata"]["solution"] - 1e-7)
assert dataset.score_answer(answer, item) == 1.0 assert dataset.score_answer(answer, item) == 1.0
# Answer is within 1e-1 of solution
answer = str(item["metadata"]["solution"] - 1e-2)
assert dataset.score_answer(answer, item) == 0.5
# Answer is far from solution
answer = str(item["metadata"]["solution"] - 1)
assert dataset.score_answer(answer, item) == 0.0
# Answer is None
answer = None
assert dataset.score_answer(answer, item) == 0.0