diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py
index d3dae5aa..97a24745 100644
--- a/reasoning_gym/algorithmic/number_sorting.py
+++ b/reasoning_gym/algorithmic/number_sorting.py
@@ -1,8 +1,10 @@
 """Number sorting task generator"""
 
+import ast
+import json
 from dataclasses import dataclass
 from random import Random
-from typing import Optional
+from typing import Any, Optional
 
 from ..coaching import BaseCurriculum, RangeAttributeDefinition
 from ..factory import ProceduralDataset, register_dataset
@@ -100,6 +102,68 @@ Please follow the instruction below:
             },
         }
 
+    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
+        """Score the user's answer against the expected answer.
+
+        Args:
+            answer (Optional[str]): The user's answer string.
+            entry (dict[str, Any]): The original dataset entry with the correct answer.
+
+        Returns:
+            float: 1.0 for a correct answer, 0.0 for incorrect.
+        """
+        if answer is None:
+            return 0.0
+
+        try:
+            # Try to parse the user's answer as a JSON list first
+            try:
+                user_answer = json.loads(answer)
+            except json.JSONDecodeError:
+                # If JSON parsing fails, fall back to a safe Python-literal parse
+                user_answer = ast.literal_eval(answer)
+
+            if not isinstance(user_answer, list):
+                return 0.0
+
+            # Get the expected answer
+            try:
+                expected_answer = json.loads(entry["answer"])
+            except json.JSONDecodeError:
+                # Fall back to a safe Python-literal parse if necessary
+                expected_answer = ast.literal_eval(entry["answer"])
+
+            # Check if the lists have the same length
+            if len(user_answer) != len(expected_answer):
+                return 0.0
+
+            # Convert both answers to floats for comparison
+            user_floats = [float(num) for num in user_answer]
+            expected_floats = [float(num) for num in expected_answer]
+
+            # First, verify the user's answer is properly sorted
+            direction = entry["metadata"]["direction"]
+            is_correctly_sorted = False
+
+            if direction == "ascending":
+                is_correctly_sorted = user_floats == sorted(user_floats)
+            else:  # descending
+                is_correctly_sorted = user_floats == sorted(user_floats, reverse=True)
+
+            if not is_correctly_sorted:
+                return 0.0
+
+            # Check if the values are close enough (allowing for small rounding differences)
+            tolerance = 0.1  # Increased tolerance to handle decimal differences
+            for i in range(len(user_floats)):
+                if abs(user_floats[i] - expected_floats[i]) > tolerance:
+                    return 0.0
+
+            return 1.0
+        except Exception:
+            # Any parsing error means the answer is incorrect
+            return 0.0
+
 
 class NumberSortingCurriculum(BaseCurriculum):
     def __init__(self):
diff --git a/tests/test_number_sorting.py b/tests/test_number_sorting.py
index 531a3103..bd88345f 100644
--- a/tests/test_number_sorting.py
+++ b/tests/test_number_sorting.py
@@ -117,3 +117,59 @@ def test_number_sorting_curriculum():
     assert partially_decreased_cfg.min_numbers == 10 and partially_decreased_cfg.max_numbers == 100
     assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 4
     assert partially_decreased_cfg.min_value == -10_000 and partially_decreased_cfg.max_value == 10_000
+
+
+def test_number_sorting_score_answer():
+    """Test the score_answer method for correctly evaluating model responses."""
+    # Create a dataset instance
+    config = NumberSortingConfig(seed=42)
+    dataset = NumberSortingDataset(config)
+
+    # Create a mock entry similar to the example provided
+    mock_entry = {
+        "question": "Sort these numbers in ascending order: -16.5, -83.6, -95.7, -97.8, 61.5, 71.08, -92.85",
+        "answer": "['-97.8', '-95.7', '-92.8', '-83.6', '-16.5', '61.5', '71.1']",
+        "metadata": {
+            "direction": "ascending",
+            "original_numbers": ["-16.5", "-83.6", "-95.7", "-97.8", "61.5", "71.08", "-92.85"],
+            "sorted_numbers": ["-97.8", "-95.7", "-92.8", "-83.6", "-16.5", "61.5", "71.1"],
+        },
+    }
+
+    # Test case 1: Exact match should score 1.0
+    exact_match = "['-97.8', '-95.7', '-92.8', '-83.6', '-16.5', '61.5', '71.1']"
+    assert dataset.score_answer(exact_match, mock_entry) == 1.0
+
+    # Test case 2: Answer with small numerical differences but correct order should score 1.0
+    close_match = "['-97.8', '-95.7', '-92.85', '-83.6', '-16.5', '61.5', '71.08']"
+    assert dataset.score_answer(close_match, mock_entry) == 1.0
+
+    # Test case 3: Incorrectly sorted answer should score 0.0
+    wrong_order = "['-16.5', '-83.6', '-92.85', '-95.7', '-97.8', '61.5', '71.08']"
+    assert dataset.score_answer(wrong_order, mock_entry) == 0.0
+
+    # Test case 4: Answer with wrong length should score 0.0
+    wrong_length = "['-97.8', '-95.7', '-92.85', '-83.6', '-16.5', '61.5']"
+    assert dataset.score_answer(wrong_length, mock_entry) == 0.0
+
+    # Test case 5: Non-list answer should score 0.0
+    non_list = "'-97.8', '-95.7', '-92.85', '-83.6', '-16.5', '61.5', '71.08'"
+    assert dataset.score_answer(non_list, mock_entry) == 0.0
+
+    # Test case 6: None answer should score 0.0
+    assert dataset.score_answer(None, mock_entry) == 0.0
+
+    # Test case 7: Correctly sorted but with larger numerical differences (beyond tolerance)
+    beyond_tolerance = "['-97.8', '-95.7', '-91.0', '-83.6', '-16.5', '61.5', '72.0']"
+    assert dataset.score_answer(beyond_tolerance, mock_entry) == 0.0
+
+    # Test case 8: Descending order test
+    descending_entry = {
+        "answer": "['71.1', '61.5', '-16.5', '-83.6', '-92.8', '-95.7', '-97.8']",
+        "metadata": {
+            "direction": "descending",
+            "sorted_numbers": ["71.1", "61.5", "-16.5", "-83.6", "-92.8", "-95.7", "-97.8"],
+        },
+    }
+    descending_match = "['71.08', '61.5', '-16.5', '-83.6', '-92.85', '-95.7', '-97.8']"
+    assert dataset.score_answer(descending_match, descending_entry) == 1.0