mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
fix: add score_answer() to number_sorting (#380)
* fix: add score_answer() to number_sorting * chore: run pre-commit * fix: use json.loads() * fix: run isort()
This commit is contained in:
parent
1c6f2d01ee
commit
d6aad5a329
2 changed files with 120 additions and 1 deletions
|
|
@ -117,3 +117,59 @@ def test_number_sorting_curriculum():
|
|||
assert partially_decreased_cfg.min_numbers == 10 and partially_decreased_cfg.max_numbers == 100
|
||||
assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 4
|
||||
assert partially_decreased_cfg.min_value == -10_000 and partially_decreased_cfg.max_value == 10_000
|
||||
|
||||
|
||||
def test_number_sorting_score_answer():
|
||||
"""Test the score_answer method for correctly evaluating model responses."""
|
||||
# Create a dataset instance
|
||||
config = NumberSortingConfig(seed=42)
|
||||
dataset = NumberSortingDataset(config)
|
||||
|
||||
# Create a mock entry similar to the example provided
|
||||
mock_entry = {
|
||||
"question": "Sort these numbers in ascending order: -16.5, -83.6, -95.7, -97.8, 61.5, 71.08, -92.85",
|
||||
"answer": "['-97.8', '-95.7', '-92.8', '-83.6', '-16.5', '61.5', '71.1']",
|
||||
"metadata": {
|
||||
"direction": "ascending",
|
||||
"original_numbers": ["-16.5", "-83.6", "-95.7", "-97.8", "61.5", "71.08", "-92.85"],
|
||||
"sorted_numbers": ["-97.8", "-95.7", "-92.8", "-83.6", "-16.5", "61.5", "71.1"],
|
||||
},
|
||||
}
|
||||
|
||||
# Test case 1: Exact match should score 1.0
|
||||
exact_match = "['-97.8', '-95.7', '-92.8', '-83.6', '-16.5', '61.5', '71.1']"
|
||||
assert dataset.score_answer(exact_match, mock_entry) == 1.0
|
||||
|
||||
# Test case 2: Answer with small numerical differences but correct order should score 1.0
|
||||
close_match = "['-97.8', '-95.7', '-92.85', '-83.6', '-16.5', '61.5', '71.08']"
|
||||
assert dataset.score_answer(close_match, mock_entry) == 1.0
|
||||
|
||||
# Test case 3: Incorrectly sorted answer should score 0.0
|
||||
wrong_order = "['-16.5', '-83.6', '-92.85', '-95.7', '-97.8', '61.5', '71.08']"
|
||||
assert dataset.score_answer(wrong_order, mock_entry) == 0.0
|
||||
|
||||
# Test case 4: Answer with wrong length should score 0.0
|
||||
wrong_length = "['-97.8', '-95.7', '-92.85', '-83.6', '-16.5', '61.5']"
|
||||
assert dataset.score_answer(wrong_length, mock_entry) == 0.0
|
||||
|
||||
# Test case 5: Non-list answer should score 0.0
|
||||
non_list = "'-97.8', '-95.7', '-92.85', '-83.6', '-16.5', '61.5', '71.08'"
|
||||
assert dataset.score_answer(non_list, mock_entry) == 0.0
|
||||
|
||||
# Test case 6: None answer should score 0.0
|
||||
assert dataset.score_answer(None, mock_entry) == 0.0
|
||||
|
||||
# Test case 7: Correctly sorted but with larger numerical differences (beyond tolerance)
|
||||
beyond_tolerance = "['-97.8', '-95.7', '-91.0', '-83.6', '-16.5', '61.5', '72.0']"
|
||||
assert dataset.score_answer(beyond_tolerance, mock_entry) == 0.0
|
||||
|
||||
# Test case 8: Descending order test
|
||||
descending_entry = {
|
||||
"answer": "['71.1', '61.5', '-16.5', '-83.6', '-92.8', '-95.7', '-97.8']",
|
||||
"metadata": {
|
||||
"direction": "descending",
|
||||
"sorted_numbers": ["71.1", "61.5", "-16.5", "-83.6", "-92.8", "-95.7", "-97.8"],
|
||||
},
|
||||
}
|
||||
descending_match = "['71.08', '61.5', '-16.5', '-83.6', '-92.85', '-95.7', '-97.8']"
|
||||
assert dataset.score_answer(descending_match, descending_entry) == 1.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue