From 229086131a8a9a00a6c1138af4ef3378439a3a8a Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Wed, 26 Feb 2025 12:54:40 +0100 Subject: [PATCH] fix CCC scoring --- reasoning_gym/cognition/color_cube_rotation.py | 9 ++------- tests/test_color_cube_rotation.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/reasoning_gym/cognition/color_cube_rotation.py b/reasoning_gym/cognition/color_cube_rotation.py index fe7feeea..59e3e3a1 100644 --- a/reasoning_gym/cognition/color_cube_rotation.py +++ b/reasoning_gym/cognition/color_cube_rotation.py @@ -191,17 +191,12 @@ class ColorCubeRotationDataset(ProceduralDataset): def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: reward = 0.0 - metadata = entry["metadata"] if answer is not None: try: - answer_formatted = answer.lower() - solved = answer_formatted == metadata["answer"] + answer_formatted = answer.strip().lower() + solved = answer_formatted == entry["answer"].strip().lower() if solved: reward = 1.0 - elif metadata["answer"] in answer_formatted: - reward = 0.25 - elif len(answer.strip()) > 0: - reward = 0.05 else: reward = 0.01 except: diff --git a/tests/test_color_cube_rotation.py b/tests/test_color_cube_rotation.py index a554afdd..87ecd8c6 100644 --- a/tests/test_color_cube_rotation.py +++ b/tests/test_color_cube_rotation.py @@ -49,6 +49,16 @@ def test_deterministic_generation(): assert dataset1[i]["question"] == dataset2[i]["question"] assert dataset1[i]["answer"] == dataset2[i]["answer"] + for item in dataset1: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset1.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset1.score_answer(answer=None, entry=item) == 0.0 + def test_cube_rotations(): # Test individual rotation operations