diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py
index 0cb89240..8d126536 100644
--- a/reasoning_gym/dataset.py
+++ b/reasoning_gym/dataset.py
@@ -59,7 +59,7 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
if answer == oracle_answer:
reward = 1.0
elif oracle_answer in answer:
- reward = 0.5
+ reward = len(oracle_answer) / len(answer)
else:
reward = 0.01
diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py
index 457004ce..41f0779d 100644
--- a/reasoning_gym/utils.py
+++ b/reasoning_gym/utils.py
@@ -17,7 +17,7 @@ Once you have thought about the reasoning process, provide the answer in the fol
def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
- regex = f"<{tag_name}>(.*?){tag_name}>"
+ regex = f"<{tag_name}>\\s?(.*?)\\s?{tag_name}>"
matches = list(
re.finditer(
regex,
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 9ba6bcc3..f321ad2a 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -2,6 +2,7 @@ import pytest
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
from reasoning_gym.dataset import ReseedingDataset
+from reasoning_gym.utils import extract_answer
def test_reseeding_dataset_iteration():
@@ -38,3 +39,19 @@ def test_reseeding_dataset_iteration():
test_item = next(iter(infinite_dataset))
assert infinite_dataset.score_answer("wrong", test_item) == 0.01
assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0
+
+
+def test_extract_answer():
+ assert extract_answer("This is a text. 1234", tag_name="final_answer") == "1234"
+
+ # a single whitespace character on each side of the answer is ignored
+ assert extract_answer("This is a text. \n1234 ", tag_name="answer") == "1234"
+
+ config = BasicArithmeticDatasetConfig(
+ min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10
+ )
+ # scoring checks: an answer with extra trailing text must score above 0.1, the exact answer scores 1.0
+ base_dataset = BasicArithmeticDataset(config)
+ item = base_dataset[0]
+ assert base_dataset.score_answer(item["answer"] + " + x", item) > 0.1
+ assert base_dataset.score_answer(item["answer"], item) == 1.0