diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index 0cb89240..8d126536 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -59,7 +59,7 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): if answer == oracle_answer: reward = 1.0 elif oracle_answer in answer: - reward = 0.5 + reward = len(oracle_answer) / len(answer) else: reward = 0.01 diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py index 457004ce..41f0779d 100644 --- a/reasoning_gym/utils.py +++ b/reasoning_gym/utils.py @@ -17,7 +17,7 @@ Once you have thought about the reasoning process, provide the answer in the fol def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]: - regex = f"<{tag_name}>(.*?)" + regex = f"<{tag_name}>\\s?(.*?)\\s?" matches = list( re.finditer( regex, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9ba6bcc3..f321ad2a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -2,6 +2,7 @@ import pytest from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from reasoning_gym.dataset import ReseedingDataset +from reasoning_gym.utils import extract_answer def test_reseeding_dataset_iteration(): @@ -38,3 +39,19 @@ def test_reseeding_dataset_iteration(): test_item = next(iter(infinite_dataset)) assert infinite_dataset.score_answer("wrong", test_item) == 0.01 assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0 + + +def test_extract_answer(): + assert extract_answer("This is a text. 1234", tag_name="final_answer") == "1234" + + # ignore single whitespace + assert extract_answer("This is a text. 
\n1234 ", tag_name="answer") == "1234" + + config = BasicArithmeticDatasetConfig( + min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10 + ) + + base_dataset = BasicArithmeticDataset(config) + item = base_dataset[0] + assert base_dataset.score_answer(item["answer"] + " + x", item) > 0.1 + assert base_dataset.score_answer(item["answer"], item) == 1.0