diff --git a/reasoning_gym/algorithmic/word_sorting.py b/reasoning_gym/algorithmic/word_sorting.py index bc20177c..34951c0c 100644 --- a/reasoning_gym/algorithmic/word_sorting.py +++ b/reasoning_gym/algorithmic/word_sorting.py @@ -19,6 +19,23 @@ class TextTransformation(StrEnum): RANDOMCASE = "randomcase" +QUESTION_TEMPLATE = """Your task is to sort words in ascending or descending order using ASCII/Unicode ordering. + +Example: +- Input: Sort these words in ascending order (using ASCII/Unicode ordering) and return them as a comma-separated list: freely, idea, indemnify, last, END, solving +- Output: END, freely, idea, indemnify, last, solving +- Explanation: + - Uppercase letters come before lowercase letters, hence why "END" comes first. + - "freely" comes before "idea" because "f" comes before "i". + - "idea" comes before "indemnify" because even though they both start with "i", "d" comes before "n". + - "indemnify" comes before "last" because "i" comes before "l". + - "last" comes before "solving" because "l" comes before "s". + - Finally, the output is provided as a comma separated list of the sorted words. + +Now, sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list: {words} +""" + + @dataclass class WordSortingConfig: """Configuration for word sorting task generation""" @@ -94,7 +111,7 @@ class WordSortingDataset(ProceduralDataset): answer = asc_words if is_ascending else desc_words return { - "question": f"Sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list:\n{', '.join(transformed_words)}", + "question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)), "answer": ", ".join(answer), "metadata": { "original_words": original_words, @@ -106,26 +123,17 @@ class WordSortingDataset(ProceduralDataset): } def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: - """Determine if the solution provided solves this task. + oracle_answer = entry["metadata"]["sorted_words"] + if answer is not None and len(answer) > 0: + parsed_answer = [word.strip() for word in re.split(r",\s*", answer)] + if parsed_answer == oracle_answer: + return 1.0 + elif sorted(parsed_answer) == oracle_answer: + return 0.2 + else: + return 0.01 - The function awards 1.0 for a correct answer. - - Args: - answer (Optional[str]): The user's answer. - entry (Dict[str, any]): The original dataset entry containing the correct answer. - - Returns: - float: The computed score between 0.0 and 1.0. - """ - - if answer == None: - return 0.0 - - s_answer = answer.strip().replace(" ", "") - if not s_answer == entry["answer"].strip().replace(" ", ""): - return 0.01 - else: - return 1.0 + return 0.0 register_dataset("word_sorting", WordSortingDataset, WordSortingConfig) diff --git a/tests/test_word_sorting.py b/tests/test_word_sorting.py index ea66b86f..8920e814 100644 --- a/tests/test_word_sorting.py +++ b/tests/test_word_sorting.py @@ -116,3 +116,35 @@ def test_word_sorting_dataset_iteration(): # Test multiple iterations yield same items assert items == list(dataset) + + +def test_word_sorting_scoring(): + """Test scoring function""" + config = WordSortingConfig(size=1, seed=42) + dataset = WordSortingDataset(config) + + item = { + "metadata": { + "sorted_words": ["apple", "banana", "cherry"], + } + } + + # Correct answer + answer = "apple, banana, cherry" + assert dataset.score_answer(answer, item) == 1.0 + + # Correct answer, with incorrect spaces + answer = "apple,banana, cherry" + assert dataset.score_answer(answer, item) == 1.0 + + # All words present, but not sorted + answer = "banana, cherry, apple" + assert dataset.score_answer(answer, item) == 0.2 + + # Garbage + answer = "gibberish" + assert dataset.score_answer(answer, item) == 0.01 + + # Empty answer + answer = None + assert dataset.score_answer(answer, item) == 0.0