diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index e9a02d66..24afda1a 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -19,7 +19,6 @@ class ChainSumConfig: def validate(self) -> None: """Validate configuration parameters""" - assert self.difficulty > 0, "difficulty must be positive" assert self.size > 0, "size must be positive" """Validate configuration parameters""" assert self.min_terms > 0, "min_terms must be positive" diff --git a/reasoning_gym/cognition/rubiks_cube.py b/reasoning_gym/cognition/rubiks_cube.py index 264eef46..f504d92d 100644 --- a/reasoning_gym/cognition/rubiks_cube.py +++ b/reasoning_gym/cognition/rubiks_cube.py @@ -15,6 +15,7 @@ class RubiksCubeConfig: scramble_steps: int = 3 # Number of random steps from initial state cube_size: int = 3 # Default to a standard 3x3x3 cube + remove_ansi: bool = True seed: Optional[int] = None size: int = 500 @@ -78,7 +79,12 @@ class RubiksCubeDataset(ProceduralDataset): cube = Cube(self.config.cube_size) scramble_moves = self._generate_random_moves(rng, cube, num_steps=self.config.scramble_steps) cube.rotate(scramble_moves) - cube_render = self.remove_ansi(str(cube)) + + # render cube + if self.config.remove_ansi: + cube_render = self.remove_ansi(str(cube)) + else: + cube_render = str(cube) if self.config.cube_size == 3: solver = BasicSolver(cube) @@ -102,10 +108,8 @@ class RubiksCubeDataset(ProceduralDataset): def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: """Determine if the solution provided solves the cube""" - answer = answer.strip() - - reward = 0.0 - if answer is not None and len(answer) > 0: + reward = 0.0 # default reward + if answer is not None: # Reconstruct the test cube eval_cube = Cube(entry["metadata"]["cube_size"]) eval_cube.rotate(entry["metadata"]["scramble_moves"]) @@ -117,8 +121,10 @@ class RubiksCubeDataset(ProceduralDataset): if solved: reward = 1.0 + elif len(answer.strip()) > 0: # encourage non-empty answers + reward = 0.05 # Incorrect, but rotate could parse the answer else: - reward = 0.1 # Incorrect, but rotate could parse the answer + reward = 0.01 except: reward = 0.01 # At least you tried diff --git a/tests/test_rubiks_cube.py b/tests/test_rubiks_cube.py index 96e7d41d..781d2e69 100644 --- a/tests/test_rubiks_cube.py +++ b/tests/test_rubiks_cube.py @@ -49,5 +49,11 @@ def test_rubikscube_items(): assert "example_correct_answer" in item["metadata"] assert dataset.score_answer(answer=item["metadata"]["example_correct_answer"], entry=item) == 1.0 - assert dataset.score_answer(answer="R", entry=item) == 0.01 + assert dataset.score_answer(answer="a wrong solution", entry=item) == 0.01 assert dataset.score_answer(answer=None, entry=item) == 0.0 + + if item["metadata"]["example_correct_answer"] != "R": + assert dataset.score_answer(answer="R", entry=item) == 0.05 + + if len(item["metadata"]["example_correct_answer"]) > 0: + assert dataset.score_answer(answer="", entry=item) == 0.01