diff --git a/reasoning_gym/cognition/rubiks_cube.py b/reasoning_gym/cognition/rubiks_cube.py index 17f45315..072bd0de 100644 --- a/reasoning_gym/cognition/rubiks_cube.py +++ b/reasoning_gym/cognition/rubiks_cube.py @@ -137,20 +137,21 @@ class RubiksCubeDataset(ProceduralDataset): return ansi_escape.sub("", line) def expand_moves(self, move_str): - moves = move_str.split() - expanded = [] - for move in moves: - # Split the move into the base part and any trailing digits - match = re.fullmatch(r"^([^\d]*)(\d*)$", move) - if match: - base, num_part = match.groups() - if num_part: - # Append two copies of the base if there was a number. I don't think F3 is a valid signmaster notation etc - expanded.append(base) - expanded.append(base) - else: - expanded.append(base) - return " ".join(expanded).strip() + try: + moves = move_str.split() + expanded = [] + for move in moves: + # Split the move into the base part and any trailing digits + match = re.fullmatch(r"^([^\d]*)(\d*)$", move) + if match: + base, num_part = match.groups() + if num_part: + expanded.extend([base] * int(num_part)) + else: + expanded.append(base) + return " ".join(expanded).strip() + except Exception as e: + return move_str # Register the dataset diff --git a/tests/test_rubiks_cube.py b/tests/test_rubiks_cube.py index 781d2e69..f615d700 100644 --- a/tests/test_rubiks_cube.py +++ b/tests/test_rubiks_cube.py @@ -55,5 +55,7 @@ def test_rubikscube_items(): if item["metadata"]["example_correct_answer"] != "R": assert dataset.score_answer(answer="R", entry=item) == 0.05 + assert dataset.score_answer(answer="R2 R3 R4 R5 R'2 R'3", entry=item) == 0.05 + if len(item["metadata"]["example_correct_answer"]) > 0: assert dataset.score_answer(answer="", entry=item) == 0.01