reasoning-gym/tests/test_utils.py
Andreas Köpf 5d7fbac0ad
Minor question template & score_answer improvements (#261)
* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
2025-03-04 21:55:09 +01:00

53 lines
1.7 KiB
Python

import pytest
from reasoning_gym.utils import compute_decimal_reward, extract_answer, format_number
def test_extract_answer():
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"
# ignore whitespaces
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer", strip=True) == "1234"
def test_format_number():
# Test integers
assert format_number(42) == "42"
assert format_number(42.0) == "42"
# Test decimals
assert format_number(3.14) == "3.14"
assert format_number(3.10) == "3.1"
assert format_number(3.00) == "3"
# Test with max_decimals (rounding)
assert format_number(3.14159, max_decimals=4, round_if_needed=True) == "3.1416"
# Test with trailing zeros
assert format_number(5.5000) == "5.5"
# Test error cases
with pytest.raises(ValueError):
format_number(3.14159, max_decimals=2)
def test_compute_decimal_reward():
# Test exact matches
assert compute_decimal_reward("42", "42") == 1.0
assert compute_decimal_reward("3.14", "3.14") == 1.0
# Test with commas
assert compute_decimal_reward("1,000", "1000") == 1.0
assert compute_decimal_reward("1,000", "1000", strip_commas=False) < 1.0
# Test with sign, leading zeros, and trailing decimals
assert compute_decimal_reward("+0001,000.00", "1000") == 1.0
# Test partial matches
assert compute_decimal_reward("The answer is 42", "42") < 1.0
assert compute_decimal_reward("The answer is 42", "42") > 0.01
# Test invalid answers
assert compute_decimal_reward(None, "42") == 0.0
assert compute_decimal_reward("", "42") == 0.0
assert compute_decimal_reward("not a number", "42") == 0.0