def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
    """Score a candidate answer against the oracle answer in ``entry["answer"]``.

    Scoring tiers:
        * 1.0  -- ``answer`` equals the oracle string exactly.
        * 0.5  -- ``answer`` is the textual form of a Python literal (e.g. a
          list of characters) whose elements join to the oracle string.
        * 0.01 -- ``answer`` cannot be parsed as a Python literal, or the
          parsed value cannot be joined into a string.
        * 0.0  -- ``answer`` is ``None``, or it parses and joins but does not
          match the oracle.

    Args:
        answer: The candidate answer text, or ``None`` if no answer was given.
        entry: The dataset item; only its ``"answer"`` key is read.

    Returns:
        One of the scores listed above.

    Raises:
        KeyError: If ``entry`` has no ``"answer"`` key.
    """
    # Imported locally so this method does not depend on a file-level import.
    import ast

    oracle_answer = entry["answer"]
    if answer is None:
        return 0.0
    if answer == oracle_answer:
        return 1.0

    try:
        # SECURITY: `answer` is untrusted text, so it must never go through
        # eval(). ast.literal_eval only accepts Python literals (lists,
        # tuples, strings, ...) and cannot run arbitrary code.
        joined = "".join(ast.literal_eval(answer))
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a literal, or not an iterable of strings.
        return 0.01

    # Right characters in the wrong container format earn partial credit.
    if joined == oracle_answer:
        return 0.5
    return 0.0
b/tests/test_string_insertion.py index 12225954..9d815b15 100644 --- a/tests/test_string_insertion.py +++ b/tests/test_string_insertion.py @@ -92,3 +92,13 @@ def test_string_insertion_answer(): # No reuse of newly inserted characters assert dataset._get_answer("ABCDBCD") == "ABCDABCD" + + # Test score_answer with correct answer + answer = "AABCDAEEEEEEEBCDEBAAAAA" + entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"} + assert dataset.score_answer(answer, entry) == 1.0 + + # Test score_answer with correct answer as python list of characters (partial correct) + answer = "['A', 'A', 'B', 'C', 'D', 'A', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'B', 'C', 'D', 'E', 'B', 'A', 'A', 'A', 'A', 'A']" + entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"} + assert dataset.score_answer(answer, entry) == 0.5