mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
updated algorithmics dataset (#269)
* updated algorithmic datasets * added changes to symbolic and power * updated power function test
This commit is contained in:
parent
f426db90ec
commit
d9638df79c
5 changed files with 57 additions and 29 deletions
|
|
@ -1,5 +1,6 @@
|
|||
"""GSM Symblic dataset generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Callable, Optional
|
||||
|
|
@ -149,8 +150,28 @@ class GSMSymbolicDataset(ProceduralDataset):
|
|||
generator_idx = self.task_indices[idx]
|
||||
generator = self.generators[generator_idx]
|
||||
example = generator(rng, self.config.difficulty)
|
||||
example["question"] += " Give only the result as your final answer."
|
||||
example["question"] += " Give the result as your final answer. Do not include units."
|
||||
return example
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
if answer is None:
|
||||
return reward
|
||||
try:
|
||||
# Extract number using regex with search
|
||||
match = re.search(r"\b-?\d+(?:\.\d+)?\b", answer)
|
||||
if not match:
|
||||
return reward
|
||||
|
||||
answer_value = float(match.group(0))
|
||||
expected_answer = float(entry["answer"])
|
||||
if answer_value == expected_answer:
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.01
|
||||
except Exception:
|
||||
return reward
|
||||
return reward
|
||||
|
||||
|
||||
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue