mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
ignore single whitespace at beginning and end of answer, use reward = len(oracle_answer) / len(answer)
This commit is contained in:
parent
979b6ba4ef
commit
0a660a3409
3 changed files with 19 additions and 2 deletions
|
|
@ -59,7 +59,7 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
||||||
if answer == oracle_answer:
|
if answer == oracle_answer:
|
||||||
reward = 1.0
|
reward = 1.0
|
||||||
elif oracle_answer in answer:
|
elif oracle_answer in answer:
|
||||||
reward = 0.5
|
reward = len(oracle_answer) / len(answer)
|
||||||
else:
|
else:
|
||||||
reward = 0.01
|
reward = 0.01
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ Once you have thought about the reasoning process, provide the answer in the fol
|
||||||
|
|
||||||
|
|
||||||
def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
|
def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
|
||||||
regex = f"<{tag_name}>(.*?)</{tag_name}>"
|
regex = f"<{tag_name}>\\s?(.*?)\\s?</{tag_name}>"
|
||||||
matches = list(
|
matches = list(
|
||||||
re.finditer(
|
re.finditer(
|
||||||
regex,
|
regex,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||||
from reasoning_gym.dataset import ReseedingDataset
|
from reasoning_gym.dataset import ReseedingDataset
|
||||||
|
from reasoning_gym.utils import extract_answer
|
||||||
|
|
||||||
|
|
||||||
def test_reseeding_dataset_iteration():
|
def test_reseeding_dataset_iteration():
|
||||||
|
|
@ -38,3 +39,19 @@ def test_reseeding_dataset_iteration():
|
||||||
test_item = next(iter(infinite_dataset))
|
test_item = next(iter(infinite_dataset))
|
||||||
assert infinite_dataset.score_answer("wrong", test_item) == 0.01
|
assert infinite_dataset.score_answer("wrong", test_item) == 0.01
|
||||||
assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0
|
assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_answer():
|
||||||
|
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"
|
||||||
|
|
||||||
|
# ignore single whitespae
|
||||||
|
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer") == "1234"
|
||||||
|
|
||||||
|
config = BasicArithmeticDatasetConfig(
|
||||||
|
min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10
|
||||||
|
)
|
||||||
|
|
||||||
|
base_dataset = BasicArithmeticDataset(config)
|
||||||
|
item = base_dataset[0]
|
||||||
|
assert base_dataset.score_answer(item["answer"] + " + x", item) > 0.1
|
||||||
|
assert base_dataset.score_answer(item["answer"], item) == 1.0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue