diff --git a/.gitignore b/.gitignore
index 393a7522..16d2503b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -192,3 +192,9 @@
 uv.lock
 .DS_Store
 .aider*
+
+# Word Hunt artefacts
+environments/community/word_hunt/Dictionary.txt
+environments/community/word_hunt/debug_test.py
+environments/community/word_hunt/template_with_metadata.html
+environments/community/word_hunt/word_hunt_rollouts*.html
diff --git a/environments/community/word_hunt/README.md b/environments/community/word_hunt/README.md
new file mode 100644
index 00000000..c4f53789
--- /dev/null
+++ b/environments/community/word_hunt/README.md
@@ -0,0 +1,125 @@
+# Word Hunt Environment
+
+This Atropos environment is designed to train language models to play **Word Hunt**, a game where the goal is to trace through a 4x4 grid of letters to create as many words as possible within the time limit.
+
+Word Hunt combines multiple cognitive challenges: spatial reasoning (tracing paths through the grid), vocabulary knowledge (recognizing valid words), and strategic optimization (prioritizing longer, higher-scoring words within token output constraints).
+
+
+## Game Rules and Sample Prompt
+
+The model receives a 4x4 grid of letters and must find valid English words by tracing through adjacent letters (including diagonally). The key rules are:
+
+- Words must be **at least 3 letters long**
+- Letters must be **adjacent** (horizontally, vertically, or diagonally)
+- Each letter can only be **used once per word**
+- The board **does not wrap around** (edges are not connected)
+- Only **valid English words** count toward the score
+- **Duplicate words** don't count for extra points
+
+**Sample Prompt:**
+```
+Find English words on this 4x4 letter grid to maximize your score.
+Longer words are worth more points. 
You must adhere to the following rules:
+
+• Words must be AT LEAST 3 letters long and have to be formed by connecting
+  adjacent letters on the board (including diagonally).
+• The board does not wrap around; letters on opposite edges are not considered adjacent.
+• The whole word must have an adjacent path through it
+• Each letter can only be used once per word.
+• The word must be a valid word in the English language
+• Making the same word in multiple ways does not count for extra points -
+  each unique word only counts once.
+
+Provide your answer as a comma-separated list, like this: WORD, ANOTHER, EXAMPLE
+
+Scoring: 3-letter: 100pts, 4-letter: 400pts, 5-letter: 800pts
+
+Board:
+G O E L
+M I I E
+N G M C
+B S D T
+
+Found words:
+```
+
+## Features
+
+- **High-Performance Solver:** Uses a Trie-based recursive backtracking algorithm to efficiently find all possible words on any given board. This allows us to measure model performance against the optimal solution and provide accurate scoring.
+
+- **Scoring:** Points are awarded based on word length following the official Word Hunt scoring system:
+  - 3-letter words: 100 points
+  - 4-letter words: 400 points
+  - 5-letter words: 800 points
+  - 6+ letter words: 1400 + (400 × (length - 6)) points
+  - These scoring rules are based on the version of Word Hunt in GamePigeon, the iOS app
+
+  The default maximum output tokens a response will produce is 100 tokens, which simulates the time pressure and encourages strategic resource allocation.
+
+
+- **Reward Signal:** The environment provides normalized rewards (0-1 range) based on the model's score divided by the maximum possible score on each board, ensuring consistent training signals across different board difficulties.
+
+
+## Setup and Dependencies
+
+The environment relies on a dictionary file to validate words. You'll need to download the dictionary file before running the environment:
+
+1. 
**Download the dictionary file:**
+   Manually download `Dictionary.txt` from: https://github.com/Aboozle1/wordhuntsolver/blob/main/Dictionary.txt
+
+2. **Place it in the correct location:**
+   Save the file as: `environments/community/word_hunt/Dictionary.txt`
+
+If you wish to use a different dictionary, you can change the `dictionary_path` in `word_hunt_config.py`.
+
+## How to Run
+
+You can run the environment and generate a small sample of rollouts using the `process` command. This is a quick way to test that the environment is working correctly. The results will be saved to a `.jsonl` file and a corresponding `.html` report for easy viewing.
+
+```bash
+python environments/community/word_hunt/word_hunt_env.py process \
+    --env.total_steps 2 \
+    --env.data_path_to_save_groups word_hunt_sample_rollouts.jsonl \
+    --env.use_wandb false
+```
+
+This command will:
+- Run the model for **2 steps** (i.e., process 2 groups of boards).
+- Save the results to `word_hunt_sample_rollouts.jsonl` and `word_hunt_sample_rollouts.html`.
+
+You can also override other parameters, like the model endpoint:
+```bash
+python environments/community/word_hunt/word_hunt_env.py process \
+    --openai.base_url <BASE_URL> \
+    --openai.api_key <API_KEY> \
+    --openai.model_name <MODEL_NAME> \
+    --env.total_steps 2 \
+    --env.use_wandb false
+```
+
+### Full Example with Model Configuration
+
+Here is a complete, real-world example that specifies the model endpoint and API key. This is useful when you want to target a specific model that is not set as your default. 
+
+```bash
+python3 environments/community/word_hunt/word_hunt_env.py process \
+    --openai.base_url https://inference-api.nousresearch.com/v1 \
+    --openai.api_key <YOUR_API_KEY> \
+    --openai.model_name DeepHermes-3-Llama-3-8B-Preview \
+    --env.data_path_to_save_groups word_hunt_rollouts.jsonl \
+    --env.use_wandb false \
+    --env.total_steps 2
+```
+
+## Configuration
+
+The primary configuration for this environment is handled in `environments/community/word_hunt/word_hunt_config.py`. Key options include:
+
+- `prompt_style`: How the board is presented to the model. Can be `grid_visual`, `text_description`, or `both`.
+- `include_instructions`: Whether to include the game rules in the prompt.
+- `include_scoring_info`: Whether to show the scoring system in the prompt.
+- `board_size`: The dimensions of the game board (default is 4).
+- `dictionary_path`: The file path to the dictionary used for word validation.
+- `letter_frequencies`: The probability distribution used for generating the letter grid.
+- `scoring_system`: A dictionary mapping word lengths to point values.
+- `max_tokens_per_game`: The maximum number of tokens the model is allowed to generate for its response. This simulates a "time limit" and can be adjusted to make the task easier or harder (default: 100). 
diff --git a/environments/community/word_hunt/__init__.py b/environments/community/word_hunt/__init__.py new file mode 100644 index 00000000..e9d2565d --- /dev/null +++ b/environments/community/word_hunt/__init__.py @@ -0,0 +1,9 @@ +""" +Word Hunt Environment Package +""" + +from .word_hunt_config import WordHuntEnvConfig +from .word_hunt_env import WordHuntEnv +from .word_hunt_solver import WordHuntSolver + +__all__ = ["WordHuntEnv", "WordHuntEnvConfig", "WordHuntSolver"] diff --git a/environments/community/word_hunt/example_config.yaml b/environments/community/word_hunt/example_config.yaml new file mode 100644 index 00000000..9eda67aa --- /dev/null +++ b/environments/community/word_hunt/example_config.yaml @@ -0,0 +1,75 @@ +# Word Hunt Environment Example Configuration +# Copy this file and modify as needed for your training runs + +# Environment Configuration +env_name: "word_hunt_environment" +env_config: + # Core Training Parameters + group_size: 4 + max_token_length: 2048 + steps_per_eval: 100 + boards_per_epoch: 50 + + # Game Parameters + board_size: 4 + min_word_length: 3 + max_word_length: 16 + max_tokens_per_game: 100 + + # Board Generation + vowel_weight: 0.4 + common_letter_bias: true + board_generation_method: "random_weighted" + + # Scoring + use_official_scoring: true + normalize_scores: true + max_possible_score_estimate: 10000 + + # Validation + validate_words: true + validate_board_paths: true + + # Prompt Configuration + prompt_style: "grid_visual" # Options: grid_visual, text_description, both + include_instructions: true + include_scoring_info: true + + # Evaluation + eval_board_count: 10 + eval_metrics: ["accuracy", "total_score", "word_count", "avg_word_length"] + + # Logging + debug_mode: false + use_wandb: true + wandb_name: "word_hunt_training" + + # Advanced Options + shuffle_boards: true + save_board_images: false + +# Server Configuration (example) +server_configs: + - name: "openai" + server_type: "openai" + api_key: 
"${OPENAI_API_KEY}" + base_url: "https://api.openai.com/v1" + model: "gpt-4" + max_retries: 3 + timeout: 60 + +# Training Parameters +training: + total_steps: 10000 + batch_size: 512 + learning_rate: 1e-5 + warmup_steps: 100 + gradient_accumulation_steps: 1 + +# Logging Configuration +logging: + log_level: "INFO" + save_checkpoints: true + checkpoint_interval: 1000 + save_rollouts: false + rollout_save_interval: 500 diff --git a/environments/community/word_hunt/requirements.txt b/environments/community/word_hunt/requirements.txt new file mode 100644 index 00000000..3aeb5623 --- /dev/null +++ b/environments/community/word_hunt/requirements.txt @@ -0,0 +1,14 @@ +# Word Hunt Environment Dependencies +# Core Atropos dependencies are handled by the main project + +# Configuration and validation +pydantic>=2.0.0 + +# Tokenization and model handling +transformers>=4.30.0 + +# Optional: For enhanced word validation (if using external dictionary) +# nltk>=3.8.0 + +# Optional: For performance monitoring +# psutil>=5.9.0 diff --git a/environments/community/word_hunt/trie.py b/environments/community/word_hunt/trie.py new file mode 100644 index 00000000..a189e74c --- /dev/null +++ b/environments/community/word_hunt/trie.py @@ -0,0 +1,54 @@ +""" +Trie Data Structure for Word Hunt +""" + + +class TrieNode: + """A node in the Trie structure.""" + + def __init__(self): + self.children = {} # A dictionary of child nodes: {char: TrieNode} + self.is_end_of_word = False + + +class Trie: + """ + Trie data structure for efficient word and prefix lookups, optimized for the Word Hunt game. + """ + + def __init__(self): + """Initializes the Trie with an empty root node.""" + self.root = TrieNode() + + def insert(self, word: str): + """ + Inserts a word into the Trie. Assumes word is already uppercase. 
+ """ + node = self.root + for char in word: + if char not in node.children: + node.children[char] = TrieNode() + node = node.children[char] + node.is_end_of_word = True + + def is_word(self, word: str) -> bool: + """ + Searches for a complete word in the Trie. Assumes word is already uppercase. + """ + node = self.root + for char in word: + if char not in node.children: + return False + node = node.children[char] + return node.is_end_of_word + + def is_prefix(self, prefix: str) -> bool: + """ + Checks if a string is a prefix of any word in the Trie. Assumes prefix is already uppercase. + """ + node = self.root + for char in prefix: + if char not in node.children: + return False + node = node.children[char] + return True diff --git a/environments/community/word_hunt/word_hunt_config.py b/environments/community/word_hunt/word_hunt_config.py new file mode 100644 index 00000000..f548912d --- /dev/null +++ b/environments/community/word_hunt/word_hunt_config.py @@ -0,0 +1,259 @@ +""" +Custom configuration class for Word Hunt Environment +Defines all configurable parameters for board generation, scoring, and training +""" + +from typing import Dict, List + +from pydantic import Field + +from atroposlib.envs.base import BaseEnvConfig + + +class WordHuntEnvConfig(BaseEnvConfig): + """Configuration for Word Hunt Environment""" + + # Board Generation Parameters + board_size: int = Field(default=4, description="Size of the word hunt board (4x4)") + min_word_length: int = Field(default=3, description="Minimum word length to count") + max_word_length: int = Field(default=16, description="Maximum word length possible") + + # Letter Distribution Parameters + vowel_weight: float = Field( + default=0.4, description="Probability of generating vowels vs consonants" + ) + common_letter_bias: bool = Field( + default=True, description="Bias towards more common English letters" + ) + + # Token and Response Parameters + max_tokens_per_game: int = Field( + default=100, description="Maximum 
tokens model can use per game" + ) + use_official_scoring: bool = Field( + default=True, description="Use official Word Hunt scoring rules" + ) + normalize_scores: bool = Field( + default=True, description="Normalize scores between 0 and 1" + ) + + # Dictionary and solver settings + dictionary_path: str = Field( + default="environments/community/word_hunt/Dictionary.txt", + description="Path to the dictionary file for word validation", + ) + validate_words: bool = Field( + default=True, description="Validate words are in dictionary" + ) + validate_board_paths: bool = Field( + default=True, description="Validate words can be formed on board" + ) + + # Prompt settings + prompt_style: str = Field( + default="grid_visual", + description="How to present board: 'grid_visual', 'text_description', 'both'", + ) + include_instructions: bool = Field( + default=True, description="Include game instructions in prompt" + ) + include_scoring_info: bool = Field( + default=True, description="Include scoring information in prompt" + ) + + # Tokenizer settings + tokenizer_name: str = Field( + default="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + description="Name of the tokenizer to use for model generation", + ) + + # Training Parameters + boards_per_epoch: int = Field( + default=100, description="Number of boards to generate per training epoch" + ) + shuffle_boards: bool = Field( + default=True, description="Shuffle board order each epoch" + ) + + # Evaluation Parameters + eval_board_count: int = Field( + default=10, description="Number of boards for evaluation" + ) + eval_metrics: List[str] = Field( + default=["accuracy", "total_score", "word_count", "avg_word_length"], + description="Metrics to track during evaluation", + ) + + # Advanced Parameters + debug_mode: bool = Field(default=False, description="Enable debug logging") + save_board_images: bool = Field( + default=False, description="Save board visualizations for debugging" + ) + + # Override some BaseEnvConfig defaults 
for Word Hunt + group_size: int = Field( + default=16, description="Number of responses to generate for each board." + ) + max_token_length: int = Field( + default=1024 * 16, + description="Max tokens for model generation (matching working environments)", + ) + steps_per_eval: int = Field( + default=25, + description="Steps between evaluations (matching working environments)", + ) + inference_weight: float = Field( + default=1.0, description="Inference weight for training" + ) + min_batch_allocation: float = Field( + default=0.1, + description="Minimum batch allocation (matching working environments)", + ) + + def get_scoring_system(self) -> Dict[int, int]: + """Get the scoring system for word lengths""" + if self.use_official_scoring: + return { + 3: 100, # 3-letter words + 4: 400, # 4-letter words + 5: 800, # 5-letter words + # 6+ letter words: 1400 + (400 * (length - 6)) + } + else: + # Custom scoring: exponential growth + return {length: 2 ** (length - 2) for length in range(3, 17)} + + def get_letter_frequencies(self) -> Dict[str, float]: + """Get letter frequency distribution for board generation""" + if self.common_letter_bias: + return { + "E": 12.0, + "A": 8.2, + "R": 6.7, + "I": 6.3, + "O": 6.1, + "T": 5.9, + "N": 5.7, + "S": 5.3, + "L": 4.0, + "C": 3.8, + "U": 3.0, + "D": 2.8, + "P": 2.7, + "M": 2.4, + "H": 2.3, + "G": 2.0, + "B": 1.5, + "F": 1.4, + "Y": 1.4, + "W": 1.3, + "K": 0.8, + "V": 0.6, + "X": 0.2, + "Z": 0.1, + "J": 0.1, + "Q": 0.1, + } + else: + # Uniform distribution + return {chr(i): 1.0 for i in range(65, 91)} # A-Z + + def get_prompt_template(self) -> str: + """Get the prompt template for presenting boards to the model""" + if self.prompt_style == "grid_visual": + template = """ +You are playing Word Hunt! Find as many English words as possible on this 4x4 letter grid. 
+ +Rules: +- Words must be at least 3 letters long +- You can move in any direction (including diagonally) +- You cannot reuse the same letter in a single word +- Only real English words count + +{scoring_info} + +Here's your board: +{board_grid} + +Find all the words you can! Return them as a space-separated list. +""" + elif self.prompt_style == "text_description": + template = """ +You are playing Word Hunt! Find English words from these letters arranged in a 4x4 grid. + +Rules: +- Words must be at least 3 letters long +- You can move in any direction (including diagonally) +- You cannot reuse the same letter in a single word +- Only real English words count + +{scoring_info} + +Letters (reading left to right, top to bottom): {board_letters} + +Find all the words you can! Return them as a space-separated list. +""" + else: # both + template = """ +You are playing Word Hunt! Find as many English words as possible on this 4x4 letter grid. + +Rules: +- Words must be at least 3 letters long +- You can move in any direction (including diagonally) +- You cannot reuse the same letter in a single word +- Only real English words count + +{scoring_info} + +Here's your board: +{board_grid} + +Letters (reading left to right, top to bottom): {board_letters} + +Find all the words you can! Return them as a space-separated list. 
+""" + + return template.strip() + + def get_scoring_info(self) -> str: + """Get scoring information for the prompt""" + if not self.include_scoring_info: + return "" + + scoring = self.get_scoring_system() + info = "Scoring: " + info += ", ".join( + [ + f"{length}-letter: {score}pts" + for length, score in sorted(scoring.items())[:5] + ] + ) # Show first 5 + if len(scoring) > 5: + info += f", 6+ letters: {scoring[6]}+ pts" + + return info + + +# For testing the config +if __name__ == "__main__": + config = WordHuntEnvConfig() + + print("Word Hunt Environment Configuration:") + print(f"Board size: {config.board_size}x{config.board_size}") + print(f"Min word length: {config.min_word_length}") + print(f"Max tokens per game: {config.max_tokens_per_game}") + print(f"Prompt style: {config.prompt_style}") + + print("\nScoring system:") + scoring = config.get_scoring_system() + for length, score in sorted(scoring.items())[:8]: + print(f" {length}-letter words: {score} points") + + print("\nLetter frequencies (top 10):") + frequencies = config.get_letter_frequencies() + sorted_freq = sorted(frequencies.items(), key=lambda x: x[1], reverse=True)[:10] + for letter, freq in sorted_freq: + print(f" {letter}: {freq}") + + print("\nPrompt template:") + print(config.get_prompt_template()) diff --git a/environments/community/word_hunt/word_hunt_env.py b/environments/community/word_hunt/word_hunt_env.py new file mode 100644 index 00000000..68d665e5 --- /dev/null +++ b/environments/community/word_hunt/word_hunt_env.py @@ -0,0 +1,517 @@ +""" +Word Hunt Environment for Atropos +Trains models to find English words on 4x4 letter grids +""" + +import random +import uuid +from typing import List, Optional, Tuple + +from atroposlib.envs.base import APIServerConfig, BaseEnv, ScoredDataGroup +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + + +# Define a custom data group to include our reward metadata +class WordHuntScoredDataGroup(ScoredDataGroup): + pass + + +# 
Handle imports for both direct execution and module import +try: + from .word_hunt_config import WordHuntEnvConfig + from .word_hunt_solver import WordHuntSolver +except ImportError: + from word_hunt_config import WordHuntEnvConfig + from word_hunt_solver import WordHuntSolver + + +class WordHuntEnv(BaseEnv): + """Word Hunt Environment for training models to find words on 4x4 grids""" + + name = "word_hunt_environment" + + @classmethod + def config_init(cls) -> Tuple[WordHuntEnvConfig, List[APIServerConfig]]: + """Initializes the default configuration for the environment.""" + env_config = WordHuntEnvConfig() + server_configs = [APIServerConfig()] + return env_config, server_configs + + async def setup(self) -> None: + """ + Initialize environment, load solver, set up state + + This method: + 1. Initializes the WordHuntSolver with dictionary + 2. Sets up board generation parameters + 3. Initializes training statistics + 4. Prepares prompt templates + """ + # 1. Initialize WordHuntSolver with dictionary + self.solver = WordHuntSolver(self.config.dictionary_path) + print( + f"✅ WordHuntSolver initialized with dictionary: {self.config.dictionary_path}" + ) + + # Initialize tokenizer for scoring + try: + from transformers import AutoTokenizer + + print( + f"🔍 Debug: Initializing tokenizer with name: {self.config.tokenizer_name}" + ) + self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Set chat template if not present (GPT-2 doesn't have one by default) + if self.tokenizer.chat_template is None: + self.tokenizer.chat_template = ( + "{% for message in messages %}" + "{{ message['role'] + ': ' + message['content'] + '\\n' }}" + "{% endfor %}" + ) + print("✅ Set default chat template for GPT-2 tokenizer") + + print("✅ Tokenizer initialized") + except Exception as e: + print(f"❌ Failed to initialize tokenizer: {e}") + self.tokenizer = None + + # 2. 
Set up board generation parameters + self.letter_frequencies = self.config.get_letter_frequencies() + self.scoring_system = self.config.get_scoring_system() + self.prompt_template = self.config.get_prompt_template() + + # 3. Initialize training statistics + self.total_games = 0 + self.total_score = 0 + self.total_valid_words = 0 + self.total_invalid_words = 0 + self.current_board_index = 0 + self.boards_this_epoch = [] + self.current_item = None # Store current item for scoring + + # 4. Set up random state for reproducible board generation + random.seed(42) # Fixed seed for reproducibility + + # 5. Generate initial batch of boards for this epoch + await self._generate_epoch_boards() + + print("✅ Word Hunt Environment setup complete:") + print(f" - Board size: {self.config.board_size}x{self.config.board_size}") + print(f" - Boards per epoch: {self.config.boards_per_epoch}") + print(f" - Max tokens per game: {self.config.max_tokens_per_game}") + print(f" - Prompt style: {self.config.prompt_style}") + if self.config.use_official_scoring: + print(" - Scoring system: Official (3-5 letter scores + formula for 6+)") + else: + print( + f" - Scoring system: {len(self.scoring_system)} word lengths supported" + ) + + async def _generate_epoch_boards(self) -> None: + """Generate boards for the current epoch""" + self.boards_this_epoch = [] + + for i in range(self.config.boards_per_epoch): + board = self.solver.generate_random_board(self.letter_frequencies) + board_id = f"board_{self.total_games + i}_{uuid.uuid4().hex[:8]}" + + self.boards_this_epoch.append( + { + "board": board, + "board_id": board_id, + "max_tokens": self.config.max_tokens_per_game, + } + ) + + if self.config.shuffle_boards: + random.shuffle(self.boards_this_epoch) + + print(f"📋 Generated {len(self.boards_this_epoch)} boards for epoch") + + async def get_next_item(self) -> Optional[Tuple]: + """Get the next board for the model to solve (following Atropos standard format). 
+ + Returns: + Tuple of (prompt_messages, board_data) following Atropos format, or None if epoch is complete + """ + if self.current_board_index >= len(self.boards_this_epoch): + # Epoch complete - generate new boards for next epoch + await self._generate_epoch_boards() + self.current_board_index = 0 + + if self.current_board_index >= len(self.boards_this_epoch): + return None # No more boards available + + board_data = self.boards_this_epoch[self.current_board_index] + self.current_board_index += 1 + + # Format the board into a prompt + board = board_data["board"] + prompt_text = self._format_board_prompt(board) + + # Create prompt messages in Atropos standard format (frozenset tuples) + prompt_messages = [frozenset({"role": "user", "content": prompt_text}.items())] + + # Return tuple following Atropos standard: (prompt_messages, board_data) + return (tuple(prompt_messages), board_data) + + def _format_board_prompt(self, board: List[List[str]]) -> str: + """Format the board into a prompt for the model.""" + prompt_parts = [] + + # Add instructions if enabled + if self.config.include_instructions: + prompt_parts.append( + "Find English words on this 4x4 letter grid to maximize your score. " + "Longer words are worth more points. You must adhere to the following rules:" + ) + prompt_parts.append( + "Words must be AT LEAST 3 letters long and have to be formed by " + "connecting adjacent letters on the board (including diagonally)." + ) + prompt_parts.append( + "The board does not wrap around; letters on opposite edges are not " + "considered adjacent." + ) + prompt_parts.append("The whole word must have an adjacent path through it") + prompt_parts.append("Each letter can only be used once per word.") + prompt_parts.append("The word must be a valid word in the English language") + prompt_parts.append( + "Making the same word in multiple ways does not count for extra points - " + "each unique word only counts once." 
+ ) + prompt_parts.append( + "Provide your answer as a comma-separated list, like this: " + "WORD, ANOTHER, EXAMPLE" + ) + prompt_parts.append("") + + # Add scoring info if enabled + if self.config.include_scoring_info: + scoring_info = self.config.get_scoring_info() + prompt_parts.append(f"Scoring: {scoring_info}") + prompt_parts.append("") + + # Add the board based on prompt style + if self.config.prompt_style == "grid_visual": + prompt_parts.append("Board:") + for row in board: + prompt_parts.append(" ".join(row)) + elif self.config.prompt_style == "text_description": + letters = [] + for row in board: + letters.extend(row) + prompt_parts.append(f"Letters: {' '.join(letters)}") + elif self.config.prompt_style == "both": + prompt_parts.append("Board:") + for row in board: + prompt_parts.append(" ".join(row)) + letters = [] + for row in board: + letters.extend(row) + prompt_parts.append(f"Letters: {' '.join(letters)}") + + prompt_parts.append("") + prompt_parts.append("Found words:") + + return "\n".join(prompt_parts) + + async def collect_trajectories( + self, item + ) -> Tuple[Optional[WordHuntScoredDataGroup], List]: + """Collect trajectories with robust error handling and validation. 
+ + Args: + item: Tuple of (prompt_messages, board_data) following Atropos format + + Returns: + Tuple of (scored_data, backlog): + - scored_data: ScoredDataGroup with tokens, masks, and scores, or None if failed + - backlog: Empty list (no follow-up items) + """ + # Validate input structure + if not isinstance(item, tuple) or len(item) != 2: + print( + f"❌ Invalid item format: expected tuple of (messages, data), got {type(item)}" + ) + return None, [] + + prompt_messages, board_data = item + + if not isinstance(prompt_messages, tuple) or not prompt_messages: + print( + f"❌ Invalid prompt_messages: expected non-empty tuple, got {type(prompt_messages)}" + ) + return None, [] + + if not isinstance(board_data, dict) or "board" not in board_data: + print( + f"❌ Invalid board_data: expected dict with 'board' key, got {type(board_data)}" + ) + return None, [] + + # Extract and validate messages + try: + messages = [] + for role_dict in prompt_messages: + if not isinstance(role_dict, frozenset): + print( + f"❌ Invalid message format: expected frozenset, got {type(role_dict)}" + ) + return None, [] + messages.append(dict(role_dict)) + + if not messages: + print("❌ No valid messages found") + return None, [] + + except Exception as e: + print(f"❌ Failed to extract messages: {e}") + return None, [] + + # Store current item for scoring + self.current_item = board_data + + # Debug: Print the messages structure + print("🔍 Debug: Messages structure before chat template:") + for i, msg in enumerate(messages): + print(f" Message {i}: {msg}") + print() + + # Apply chat template with error handling + try: + prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + if not prompt or len(prompt.strip()) == 0: + print("❌ Generated prompt is empty") + return None, [] + + # Debug: Print the actual prompt being sent to the model + print("🔍 Debug: Actual prompt being sent to model:") + print("=" * 50) + print(prompt) + print("=" * 50) + + 
except Exception as e: + print(f"❌ Chat template application failed: {e}") + return None, [] + + # Get completions with timeout and validation + try: + print("🔍 Debug: About to call server.completion()") + print(f"🔍 Debug: Prompt length: {len(prompt)} chars") + print(f"🔍 Debug: Group size: {self.config.group_size}") + + # Rely on the server's built-in timeout and retry logic + completions = await self.server.completion( + prompt=prompt, + n=self.config.group_size, + max_tokens=self.config.max_tokens_per_game, + temperature=0.8, + ) + + print("🔍 Debug: API call completed successfully") + + if not completions or not completions.choices: + print("❌ No completions received from API") + return None, [] + + print(f"🔍 Debug: Got {len(completions.choices)} completions") + + # Debug: Print each completion response + print("🔍 Debug: Model responses:") + print("-" * 50) + for i, choice in enumerate(completions.choices): + print(f" Response {i+1}:") + print(f" {choice.text.strip()}") + print("-" * 50) + + except Exception as e: + print(f"❌ Model call failed: {e}") + import traceback + + traceback.print_exc() + return None, [] + + # Build trajectories efficiently + try: + # Pre-build base messages once + base_messages = [dict(role_dict) for role_dict in prompt_messages] + to_score = [] + + for completion_choice in completions.choices: + # Validate completion + if not completion_choice or not completion_choice.text: + print("⚠️ Skipping invalid completion choice") + continue + + # Create trajectory efficiently + trajectory_messages = base_messages + [ + {"role": "assistant", "content": completion_choice.text.strip()} + ] + + to_score.append((tuple(trajectory_messages), board_data)) + + if not to_score: + print("❌ No valid trajectories created") + return None, [] + + except Exception as e: + print(f"❌ Failed to build trajectories: {e}") + return None, [] + + # Score trajectories with error handling + try: + scored_data = await self.score(to_score) + + if scored_data is None: + 
print("❌ Scoring returned None") + return None, [] + + # Validate scored data structure + required_keys = ["tokens", "masks", "scores"] + if not all(key in scored_data for key in required_keys): + print(f"❌ Scored data missing required keys: {required_keys}") + return None, [] + + if not scored_data["tokens"] or not scored_data["scores"]: + print("❌ Scored data is empty") + return None, [] + + print(f"✅ Successfully scored {len(scored_data['scores'])} trajectories") + return scored_data, [] + + except Exception as e: + print(f"❌ Scoring failed: {e}") + import traceback + + traceback.print_exc() + return None, [] + + async def score( + self, rollout_group_data: List + ) -> Optional[WordHuntScoredDataGroup]: + """Score the collected trajectories (following Atropos standard pattern). + + Args: + rollout_group_data: List of tuples (trajectory_messages, board_data) from collect_trajectories + + Returns: + ScoredDataGroup with tokens, masks, and scores + """ + if not rollout_group_data or not self.current_item: + return None + + board = self.current_item["board"] + + # Initialize our custom data group with empty lists for each key. 
+ scores = WordHuntScoredDataGroup() + scores["tokens"] = [] + scores["masks"] = [] + scores["scores"] = [] + + for trajectory_tuple in rollout_group_data: + trajectory_messages, board_data = trajectory_tuple + + if not trajectory_messages or not isinstance(trajectory_messages, tuple): + continue + + # Convert frozenset tuples back to dict format for processing + trajectory_dicts = [] + for msg_frozenset in trajectory_messages: + trajectory_dicts.append(dict(msg_frozenset)) + + # Extract assistant response + assistant_messages = [ + msg + for msg in trajectory_dicts + if isinstance(msg, dict) and msg.get("role") == "assistant" + ] + + if not assistant_messages: + continue + + response = assistant_messages[-1]["content"] + + # Score the response using our solver + normalized_score, metadata = self.solver.score_word_hunt_response( + response, board, self.scoring_system + ) + + # Update training statistics + self.total_games += 1 + self.total_score += metadata["total_score"] + self.total_valid_words += metadata["num_valid_words"] + self.total_invalid_words += metadata["num_invalid_words"] + + # Tokenize the response (following Atropos standard) + tokenized = tokenize_for_trainer(self.tokenizer, trajectory_dicts) + tokens = tokenized["tokens"] + masks = tokenized["masks"] + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(normalized_score) + + return scores if scores["tokens"] else None + + async def evaluate(self) -> None: + """Run evaluation and log metrics. + + This method: + 1. Calculate and log training metrics + 2. Print summary statistics + 3. 
Reset statistics for next epoch + """ + if self.total_games == 0: + print("⚠️ No games played yet - skipping evaluation") + return + + # Calculate metrics + avg_score = self.total_score / self.total_games + avg_valid_words = self.total_valid_words / self.total_games + avg_invalid_words = self.total_invalid_words / self.total_games + total_words = self.total_valid_words + self.total_invalid_words + accuracy = self.total_valid_words / total_words if total_words > 0 else 0.0 + + # Log to wandb + metrics = { + "eval/total_games": self.total_games, + "eval/avg_score": avg_score, + "eval/avg_valid_words": avg_valid_words, + "eval/avg_invalid_words": avg_invalid_words, + "eval/word_accuracy": accuracy, + "eval/total_score": self.total_score, + "eval/total_valid_words": self.total_valid_words, + "eval/total_invalid_words": self.total_invalid_words, + } + + await self.wandb_log(metrics) + + # Print summary + print("\n📊 Word Hunt Evaluation Summary:") + print(f" Games played: {self.total_games}") + print(f" Average score: {avg_score:.1f}") + print(f" Average valid words: {avg_valid_words:.1f}") + print(f" Average invalid words: {avg_invalid_words:.1f}") + print(f" Word accuracy: {accuracy:.1%}") + print(f" Total score: {self.total_score}") + print(f" Total valid words: {self.total_valid_words}") + print(f" Total invalid words: {self.total_invalid_words}") + + # Reset statistics for next epoch + self.total_games = 0 + self.total_score = 0 + self.total_valid_words = 0 + self.total_invalid_words = 0 + self.current_board_index = 0 + + +if __name__ == "__main__": + WordHuntEnv.cli() diff --git a/environments/community/word_hunt/word_hunt_solver.py b/environments/community/word_hunt/word_hunt_solver.py new file mode 100644 index 00000000..8267f5cb --- /dev/null +++ b/environments/community/word_hunt/word_hunt_solver.py @@ -0,0 +1,153 @@ +""" +Word Hunt Solver for Atropos +""" + +import random +import re +from typing import Dict, List, Optional, Set, Tuple + +try: + from .trie 
import Trie +except ImportError: + from trie import Trie + + +class WordHuntSolver: + """ + Solves a 4x4 Word Hunt game by finding all valid words on a given board. + + This solver uses a Trie data structure for efficient dictionary lookups and a + recursive backtracking algorithm (Depth-First Search) to find words. + """ + + def __init__(self, dictionary_path: Optional[str] = None): + """ + Initializes the solver, loading the dictionary into a Trie. + + Args: + dictionary_path: The path to the dictionary file. + """ + self.trie = self._load_dictionary(dictionary_path) + + def _load_dictionary(self, dictionary_path: Optional[str]) -> Trie: + """Loads words from a file into the Trie, filtering by length.""" + trie = Trie() + if not dictionary_path: + print("⚠️ No dictionary path provided.") + return trie + try: + with open(dictionary_path, "r") as f: + for word in f: + clean_word = word.strip().upper() + if len(clean_word) >= 3: + trie.insert(clean_word) + print(f"✅ Dictionary loaded from {dictionary_path}") + except FileNotFoundError: + print(f"❌ Dictionary file not found at {dictionary_path}.") + return trie + + def generate_random_board( + self, letter_frequencies: Dict[str, float], board_size: int = 4 + ) -> List[List[str]]: + """ + Generates a random 4x4 board based on letter frequencies. + + Args: + letter_frequencies: A dictionary mapping letters to their frequencies. + board_size: The dimension of the square board (default is 4). + + Returns: + A 4x4 list of lists representing the board. 
+ """ + letters = list(letter_frequencies.keys()) + weights = list(letter_frequencies.values()) + return [ + random.choices(letters, weights=weights, k=board_size) + for _ in range(board_size) + ] + + def solve_board(self, board: List[List[str]]) -> Set[str]: + """Finds all valid words on the board using a Trie-based DFS.""" + found_words = set() + board_size = len(board) + for r in range(board_size): + for c in range(board_size): + self._solve_dfs(board, self.trie.root, r, c, "", set(), found_words) + return found_words + + def _solve_dfs(self, board, node, r, c, path_str, visited, found_words): + board_size = len(board) + if not (0 <= r < board_size and 0 <= c < board_size) or (r, c) in visited: + return + + char = board[r][c] + if char not in node.children: + return + + # Move to the next node in the trie + node = node.children[char] + + # Update path and visited set + path_str += char + visited.add((r, c)) + + # Check if the current path forms a valid word + if node.is_end_of_word and len(path_str) >= 3: + found_words.add(path_str) + + # Recurse on all 8 neighbors + for dr in [-1, 0, 1]: + for dc in [-1, 0, 1]: + if dr == 0 and dc == 0: + continue + self._solve_dfs( + board, node, r + dr, c + dc, path_str, visited, found_words + ) + + # Backtrack: remove the current cell from the visited set for other paths + visited.remove((r, c)) + + def score_word_hunt_response( + self, response: str, board: List[List[str]], scoring_system: Dict[int, int] + ) -> Tuple[float, Dict]: + """ + Scores a model's response by finding all valid words on the board and checking + the response against them. This is a more robust method than checking each + word individually. + """ + all_possible_words = self.solve_board(board) + # Use regex to find all alphabetic words, making parsing more robust. 
+ submitted_words = {word.upper() for word in re.findall(r"[a-zA-Z]+", response)} + + valid_words = submitted_words.intersection(all_possible_words) + invalid_words = submitted_words.difference(all_possible_words) + + total_score = 0 + for word in valid_words: + word_len = len(word) + if word_len in scoring_system: + total_score += scoring_system[word_len] + elif word_len >= 6: # Official scoring for 6+ letter words + total_score += 1400 + (400 * (word_len - 6)) + + max_possible_score = 0 + for word in all_possible_words: + word_len = len(word) + if word_len in scoring_system: + max_possible_score += scoring_system[word_len] + elif word_len >= 6: # Official scoring for 6+ letter words + max_possible_score += 1400 + (400 * (word_len - 6)) + + normalized_score = ( + (total_score / max_possible_score) if max_possible_score > 0 else 0.0 + ) + + metadata = { + "total_score": total_score, + "valid_words": sorted(list(valid_words)), + "invalid_words": sorted(list(invalid_words)), + "num_valid_words": len(valid_words), + "num_invalid_words": len(invalid_words), + } + + return normalized_score, metadata