fix chain_sum unit test

This commit is contained in:
Andreas Koepf 2025-01-30 10:57:55 +01:00
parent a6bf23e655
commit 5b35ea51a7
3 changed files with 19 additions and 8 deletions

View file

@ -15,6 +15,7 @@ class RubiksCubeConfig:
scramble_steps: int = 3 # Number of random steps from initial state
cube_size: int = 3 # Default to a standard 3x3x3 cube
remove_ansi: bool = True
seed: Optional[int] = None
size: int = 500
@ -78,7 +79,12 @@ class RubiksCubeDataset(ProceduralDataset):
cube = Cube(self.config.cube_size)
scramble_moves = self._generate_random_moves(rng, cube, num_steps=self.config.scramble_steps)
cube.rotate(scramble_moves)
cube_render = self.remove_ansi(str(cube))
# render cube
if self.config.remove_ansi:
cube_render = self.remove_ansi(str(cube))
else:
cube_render = str(cube)
if self.config.cube_size == 3:
solver = BasicSolver(cube)
@ -102,10 +108,8 @@ class RubiksCubeDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Determine if the solution provided solves the cube"""
answer = answer.strip()
reward = 0.0
if answer is not None and len(answer) > 0:
reward = 0.0 # default reward
if answer is not None:
# Reconstruct the test cube
eval_cube = Cube(entry["metadata"]["cube_size"])
eval_cube.rotate(entry["metadata"]["scramble_moves"])
@ -117,8 +121,10 @@ class RubiksCubeDataset(ProceduralDataset):
if solved:
reward = 1.0
elif len(answer.strip()) > 0: # encourage non-empty answers
reward = 0.05 # Incorrect, but rotate could parse the answer
else:
reward = 0.1 # Incorrect, but rotate could parse the answer
reward = 0.01
except:
reward = 0.01 # At least you tried