From bea9e6d96ae8a5303eb0e71e74caad81bf0acc86 Mon Sep 17 00:00:00 2001 From: Adefioye <47661641+Adefioye@users.noreply.github.com> Date: Mon, 10 Feb 2025 08:15:23 -0600 Subject: [PATCH] Add score_answer method to word_ladder (#93) * Add score_answer method to word_ladder * add unit test for WordLadderDataset::score_answer() --------- Co-authored-by: Andreas Koepf --- reasoning_gym/algorithmic/word_ladder.py | 66 ++++++++++++++++++------ tests/test_word_ladder.py | 44 +++++++++++++++- 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index a0b000c2..64c65326 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -5,8 +5,7 @@ from dataclasses import dataclass from random import Random from typing import Dict, List, Optional, Set, Tuple -from reasoning_gym.data import read_data_file - +from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset @@ -64,6 +63,7 @@ class WordLadderDataset(ProceduralDataset): self.config = config self.word_sets = {} self.word_graphs = {} + self._vocabulary = None # A large list of dictionary words to validate words against # Load words from CSV self.word_sets = self._load_words_from_csv( @@ -84,28 +84,24 @@ class WordLadderDataset(ProceduralDataset): assert 3 <= min_length <= max_length <= 5, "Word length must be between 3 and 5 inclusive" import csv - from io import StringIO word_sets = {} try: # Get CSV content as string - csv_content = read_data_file("words.csv") + with get_data_file_path("words.csv").open("r", encoding="utf-8") as csv_file: + reader = csv.DictReader(csv_file) - # Use StringIO to create a file-like object from the string - csv_file = StringIO(csv_content) - reader = csv.DictReader(csv_file) + for row in reader: + # Process each word length column using config range + for length in range(min_length, max_length + 1): + col_name = f"{length}_letter" + word = row.get(col_name, "") - for row in reader: - # Process each word length column using config range - for length in range(min_length, max_length + 1): - col_name = f"{length}_letter" - word = row.get(col_name, "") + if not word: # Skip empty entries + continue - if not word: # Skip empty entries - continue - - word_sets.setdefault(length, set()).add(word.upper()) + word_sets.setdefault(length, set()).add(word.upper()) except Exception as e: raise RuntimeError(f"Error processing words.csv content: {e}") from e @@ -220,5 +216,43 @@ class WordLadderDataset(ProceduralDataset): "metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)}, } + def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: + if answer is None: + return 0 + + answer_words = tuple(s.strip() for s in answer.upper().split(",")) + + metadata = entry["metadata"] + start_word = metadata["start_word"] + end_word = metadata["end_word"] + word_length = len(end_word) + known_words = self.word_sets[word_length] + + # Check conditions: + # 1. start and end word match question + # 2. all words have the correct length + # 3. every changed word is a single letter change from the previous word + # 4. all words are in our vocabulary + + if len(answer_words) < 2: + return 0 + + if answer_words[0] != start_word or answer_words[-1] != end_word: + return 0.01 + + if not all(len(w) == word_length for w in answer_words): + return 0.01 + + for i in range(1, len(answer_words)): + if sum(1 for a, b in zip(answer_words[i - 1], answer_words[i]) if a != b) != 1: + return 0.01 + + reward = 1.0 + for word in answer_words: + if not word in known_words: + reward *= 0.5 + + return reward + register_dataset("word_ladder", WordLadderDataset, WordLadderConfig) diff --git a/tests/test_word_ladder.py b/tests/test_word_ladder.py index d42108ea..1aba4cf3 100644 --- a/tests/test_word_ladder.py +++ b/tests/test_word_ladder.py @@ -355,5 +355,45 @@ def test_word_ladder_edge_cases(): assert max_length > 3, "No challenging word pairs generated" -if __name__ == "__main__": - pytest.main([__file__]) +def test_word_ladder_score_answer(): + """Test the score_answer method""" + config = WordLadderConfig(min_word_length=4, max_word_length=4) + dataset = WordLadderDataset(config) + + # Create a test entry + entry = { + "question": "Transform the word ladder 'COLD' to 'WARM' by changing one letter at a time.", + "answer": "COLD,CORD,CARD,WARD,WARM", + "metadata": {"start_word": "COLD", "end_word": "WARM", "word_length": 4, "chain_length": 5}, + } + + # Test perfect answer + assert dataset.score_answer("COLD,CORD,CARD,WARD,WARM", entry) == 1.0 + + # Test None answer + assert dataset.score_answer(None, entry) == 0.0 + + # Test empty answer + assert dataset.score_answer("", entry) == 0.0 + + # Test single word answer + assert dataset.score_answer("COLD", entry) == 0.0 + + # Test wrong start word + assert dataset.score_answer("BOLD,CORD,CARD,WARD,WARM", entry) == 0.01 + + # Test wrong end word + assert dataset.score_answer("COLD,CORD,CARD,WARD,WARP", entry) == 0.01 + + # Test wrong word length + assert dataset.score_answer("COLD,CORDS,CARDS,WARD,WARM", entry) == 0.01 + + # Test invalid transitions (more than one letter change) + assert dataset.score_answer("COLD,WARD,WARM", entry) == 0.01 + + # Test case insensitivity + assert dataset.score_answer("cold,cord,card,ward,warm", entry) == 1.0 + + # Test with unknown words (should return partial credit) + assert dataset.score_answer("COLD,COXD,CARD,WARD,WARM", entry) < 1.0 + assert dataset.score_answer("COLD,COXD,CARD,WARD,WARM", entry) > 0.0