mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
ignore single whitespace at beginning and end of answer, use reward = len(oracle_answer) / len(answer)
This commit is contained in:
parent
979b6ba4ef
commit
0a660a3409
3 changed files with 19 additions and 2 deletions
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
|
||||
from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||
from reasoning_gym.dataset import ReseedingDataset
|
||||
from reasoning_gym.utils import extract_answer
|
||||
|
||||
|
||||
def test_reseeding_dataset_iteration():
|
||||
|
|
@ -38,3 +39,19 @@ def test_reseeding_dataset_iteration():
|
|||
test_item = next(iter(infinite_dataset))
|
||||
assert infinite_dataset.score_answer("wrong", test_item) == 0.01
|
||||
assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0
|
||||
|
||||
|
||||
def test_extract_answer():
|
||||
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"
|
||||
|
||||
# ignore single whitespae
|
||||
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer") == "1234"
|
||||
|
||||
config = BasicArithmeticDatasetConfig(
|
||||
min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10
|
||||
)
|
||||
|
||||
base_dataset = BasicArithmeticDataset(config)
|
||||
item = base_dataset[0]
|
||||
assert base_dataset.score_answer(item["answer"] + " + x", item) > 0.1
|
||||
assert base_dataset.score_answer(item["answer"], item) == 1.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue