mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
Remove strip from ProceduralDataset::core score_answer() (#250)
* remove strip from ProceduralDataset::core score_answer(), strip in extract answer (optional, default=True)
* test: Move test_extract_answer() from test_dataset.py to test_utils.py
* refactor: Improve decimal reward computation with more flexible comparison
* fix: Implement rounding for format_number when round_if_needed is True
* test: Add test case for compute_decimal_reward with sign and zeros
This commit is contained in:
parent
16a4ea1193
commit
ece6990709
6 changed files with 80 additions and 26 deletions
53
tests/test_utils.py
Normal file
53
tests/test_utils.py
Normal file
|
|
@@ -0,0 +1,53 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.utils import compute_decimal_reward, extract_answer, format_number
|
||||
|
||||
|
||||
def test_extract_answer():
    """Check that extract_answer returns the text enclosed by the requested tag."""
    # A custom tag name is honoured and the enclosed text is returned.
    custom_tagged = "This is a text. <final_answer>1234</final_answer>"
    assert extract_answer(custom_tagged, tag_name="final_answer") == "1234"

    # With strip=True, whitespace surrounding the answer is ignored.
    padded = "This is a text. <answer>\n1234 </answer>"
    assert extract_answer(padded, tag_name="answer", strip=True) == "1234"
|
||||
|
||||
|
||||
def test_format_number():
    """Check format_number on integers, decimals, rounding and the error path."""
    # Integers and decimals with default arguments: integral floats print
    # without a fractional part and trailing zeros are dropped.
    default_cases = (
        (42, "42"),
        (42.0, "42"),
        (3.14, "3.14"),
        (3.10, "3.1"),
        (3.00, "3"),
    )
    for number, expected in default_cases:
        assert format_number(number) == expected

    # With round_if_needed=True the value is rounded to max_decimals places.
    rounded = format_number(3.14159, max_decimals=4, round_if_needed=True)
    assert rounded == "3.1416"

    # Trailing zeros are removed.
    assert format_number(5.5000) == "5.5"

    # Too many decimals for max_decimals, without rounding enabled, is an error.
    with pytest.raises(ValueError):
        format_number(3.14159, max_decimals=2)
|
||||
|
||||
|
||||
def test_compute_decimal_reward():
    """Check compute_decimal_reward for exact, formatted, partial and invalid answers."""
    # Exact matches earn the full reward.
    for answer, oracle in (("42", "42"), ("3.14", "3.14")):
        assert compute_decimal_reward(answer, oracle) == 1.0

    # Thousands separators are accepted by default, but cost reward when
    # comma stripping is switched off.
    with_comma = "1,000"
    assert compute_decimal_reward(with_comma, "1000") == 1.0
    assert compute_decimal_reward(with_comma, "1000", strip_commas=False) < 1.0

    # An explicit sign, leading zeros and trailing decimal zeros still match.
    assert compute_decimal_reward("+0001,000.00", "1000") == 1.0

    # An answer that merely contains the number gets partial credit:
    # below the full reward, above the 0.01 given to the non-numeric case below.
    sentence = "The answer is 42"
    assert compute_decimal_reward(sentence, "42") < 1.0
    assert compute_decimal_reward(sentence, "42") > 0.01

    # Missing or empty answers score zero; a non-numeric answer scores 0.01.
    invalid_cases = (
        (None, 0.0),
        ("", 0.0),
        ("not a number", 0.01),
    )
    for answer, expected in invalid_cases:
        assert compute_decimal_reward(answer, "42") == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue