import json import logging import os import random import re import uuid from typing import Dict, List, Optional, Tuple, Union import wandb from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, Item, ScoredDataGroup, ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer # Import NLTK words corpus for large-scale word list try: import nltk from nltk.corpus import words from nltk.tokenize import sent_tokenize # Download required NLTK data try: words.words() except LookupError: nltk.download("words") try: sent_tokenize("Test sentence.") except LookupError: nltk.download("punkt") except ImportError: print("Warning: NLTK not available. Please install with: pip install nltk") words = None sent_tokenize = None # Import datasets for OpenWebText try: from datasets import load_dataset except ImportError: print( "Warning: datasets library not available. Please install with: pip install datasets" ) load_dataset = None system_prompt = ( "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " "problem and deliberate with yourself via systematic reasoning processes to help come to a correct " "solution prior to answering. You should enclose your thoughts and internal monologue inside " " tags, and then provide your solution or response to the problem." ) class LetterCountingConfig(BaseEnvConfig): """Configuration class for Letter Counting Environment with custom parameters.""" # Word dataset configuration min_word_length: int = Field(3, description="Minimum word length to include") max_word_length: int = Field(30, description="Maximum word length to include") train_test_split: float = Field(0.95, description="Ratio for train/test split") # Letter selection configuration use_all_letters: bool = Field( True, description="Whether to use all 26 letters or custom set" ) custom_letters: str = Field( "aeiou", description="Custom letter set if use_all_letters=False" ) present_letter_bias: float = Field( 0.5, description="Probability of choosing letters that are present in the text (0.0-1.0)", ) # Generation configuration generation_temperature: float = Field( 1.0, description="Temperature for training generation" ) eval_temperature: float = Field( 0.2, description="Temperature for evaluation generation" ) max_generation_tokens: int = Field( 1024 * 15, description="Maximum tokens for model generation" ) # Evaluation configuration eval_sample_size: int = Field( 1000, description="Number of test words to evaluate on" ) # Reproducibility configuration random_seed: Optional[int] = Field( 42, description="Seed for reproducibility, None for random" ) # Random string generation configuration random_string_percentage: float = Field( 0.03, description="Percentage of dataset to be random strings (0.0-1.0)" ) random_string_min_length: int = Field( 3, description="Minimum length for random strings" ) random_string_max_length: int = Field( 25, description="Maximum length for random strings" ) # Word capitalization configuration uppercase_word_percentage: float = Field( 0.01, description="Percentage of real words to make uppercase (0.0-1.0)" ) capitalized_word_percentage: float = Field( 0.01, description="Percentage of real words to capitalize first letter (0.0-1.0)", ) # Text/passage configuration use_text_passages: bool = Field( False, description="Include text passages from OpenWebText in addition to words" ) text_passage_percentage: float = Field( 0.5, description="Percentage of dataset to be text passages when use_text_passages=True (0.0-1.0)", ) min_text_length: int = Field( 50, description="Minimum character length for text passages" ) max_text_length: int = Field( 500, description="Maximum character length for text passages" ) include_punctuation_in_count: bool = Field( True, description="Include punctuation in letter counting" ) include_spaces_in_count: bool = Field( False, description="Include spaces in letter counting" ) # Multi-letter counting configuration max_letters_to_count: int = Field( 1, description="Maximum number of different letters to count simultaneously (1 for single letter)", ) multi_letter_probability: float = Field( 0.2, description="Probability of asking for multiple letters (0.0-1.0)" ) # Difficulty and training thresholds max_group_average_for_training: float = Field( 1.0, description="Maximum group average to use for training (skip groups that are too easy)", ) # Logging and data dumping configuration debug_logging: bool = Field( True, description="Enable debug-level logging for more verbose output" ) suppress_base_env_logs: bool = Field( True, description="Suppress verbose base environment logs" ) dump_rollouts: bool = Field( False, description="Whether to dump successful rollouts to JSONL files" ) dump_batch_size: int = Field( 100, description="Number of groups to accumulate before saving to disk (1 = save immediately)", ) class LetterCountingEnv(BaseEnv): """ Letter Counting Environment for training models to count letters in words and sentences. This environment presents the model with questions like "How many 'a's are in the word 'banana'?" or "Count the occurrences of the letters 'e', 'o', and 't' in the following text: 'The quick brown fox jumps over the lazy dog'" and expects responses in the format 3 for single letters or {"e": 4, "o": 4, "t": 2} for multiple letters. The model should use tags for reasoning before providing the final answer. Features: - **Word Mode**: Uses NLTK's words corpus (236k+ English words) - **Mixed Mode**: Combines words and text passages from OpenWebText-10k dataset - **Text Passage Mode**: Uses OpenWebText-10k dataset with character-based text extraction - Optional random string generation (80% alphabetical) mixed with real words - Configurable word/string/passage length ranges and letter sets - Optional word capitalization (uppercase, title case) - **Multi-letter counting**: Configurable simultaneous counting of multiple letters with JSON responses - **Letter selection bias**: Configurable bias toward letters present in the text (reduces zero-count questions) - Training thresholds based on group average scores - Configurable punctuation and space handling for letter counting - Comprehensive logging and data dumping capabilities - Detailed metrics tracking (letter distribution, text lengths, error rates, group average scores) - Support for saving successful and failed rollouts for analysis Data Dumping: - Set dump_rollouts=True to save rollouts from groups with appropriate difficulty - Only saves groups where group average score ≤ max_group_average_for_training - Files saved to data_dumps/ directory with unique UUIDs - Rollouts include full conversations, scores, metadata - Configurable batch size (dump_batch_size) for efficient disk I/O Mixed Mode Configuration: - Set use_text_passages=True to enable mixed mode with both words and text passages - Configure text_passage_percentage to control the ratio (e.g., 0.3 = 30% passages, 70% words) - Configure min/max text passage character lengths (more reliable than word counts) - Set max_group_average_for_training to skip groups that are too easy Logging: - Set debug_logging=True for verbose per-item scoring details - Comprehensive WandB metrics including letter distribution entropy and group average scores - Progress tracking for data dumps and evaluation """ # noqa name = "letter_counting" env_config_cls = LetterCountingConfig def __init__( self, config: LetterCountingConfig, server_configs: List[APIServerConfig], slurm=True, testing=False, ): """ Initialize the Letter Counting environment. Args: config: Configuration for the base environment server_configs: List of server configurations for OpenAI API slurm: Whether to use Slurm for distributed training testing: Whether in testing mode """ # noqa: E501 super().__init__(config, server_configs, slurm, testing) # Initialize data dumping infrastructure first (needed for logging) self.run_uuid = str(uuid.uuid4()) self.rollouts_to_save_buffer: List[ Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]] ] = [] self.processed_item_count = 0 self.datadumps_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "data_dumps" ) self.save_file_batch_num = 0 # Additional metrics tracking self.letter_distribution_stats: Dict[str, int] = {} self.word_length_stats: Dict[int, int] = {} self.answer_format_errors = 0 self.think_format_errors = 0 # Initialize the logger self.logger = logging.getLogger(self.__class__.__name__) if not self.logger.handlers: # Add a basic stream handler if no handlers are configured _handler = logging.StreamHandler() _formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) _handler.setFormatter(_formatter) self.logger.addHandler(_handler) # Set logging level based on config log_level = logging.DEBUG if self.config.debug_logging else logging.INFO self.logger.setLevel(log_level) # Ensure the logger itself is enabled self.logger.disabled = False # Suppress base environment logs if requested if self.config.suppress_base_env_logs: # Set the base environment logger to WARNING level to suppress INFO logs base_logger = logging.getLogger("atroposlib.envs.base") base_logger.setLevel(logging.WARNING) # Log initialization completion self.logger.info( f"LetterCountingEnv initialized with run UUID: {self.run_uuid}" ) self.logger.info( f"Debug logging: {'enabled' if self.config.debug_logging else 'disabled'}" ) self.logger.info( f"Data dumping: rollouts={'enabled' if self.config.dump_rollouts else 'disabled'}" ) self.percent_correct_buffer = list() self.eval_metrics = list() self.rollouts_for_wandb: List[List[Tuple[str, float, str, str, str]]] = [] @classmethod def config_init(self) -> Tuple[LetterCountingConfig, List[APIServerConfig]]: env_config = LetterCountingConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=32, use_wandb=True, max_num_workers=128, rollout_server_url="http://localhost:8000", total_steps=250, batch_size=1024, steps_per_eval=20, max_token_length=1024 * 15, inference_weight=1.0, wandb_name="letter_counting_deep_thinking", data_path_to_save_groups=None, eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, # Letter counting specific configs min_word_length=3, max_word_length=30, train_test_split=0.95, eval_sample_size=1000, generation_temperature=1.0, eval_temperature=0.5, random_seed=42, use_all_letters=True, custom_letters="aeiou", present_letter_bias=0.5, max_generation_tokens=1024 * 15, # Random string generation random_string_percentage=0.0, random_string_min_length=3, random_string_max_length=15, # Word capitalization uppercase_word_percentage=0.01, capitalized_word_percentage=0.005, # Text passage configuration use_text_passages=True, text_passage_percentage=0.3, min_text_length=3, max_text_length=2000, include_punctuation_in_count=True, include_spaces_in_count=True, max_group_average_for_training=0.7, # Multi-letter counting max_letters_to_count=4, multi_letter_probability=0.2, debug_logging=True, dump_rollouts=True, dump_batch_size=100, ) server_configs = [ APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", num_max_requests_at_once=32, num_requests_for_eval=256, ) ] return env_config, server_configs async def setup(self): """ Set up the environment by loading and preparing the word/text dataset. """ if self.config.use_text_passages: await self._setup_mixed_dataset() else: await self._setup_word_dataset() # Initialize iteration counter self.iter = 0 async def _setup_mixed_dataset(self): """ Set up the environment using both words and text passages from OpenWebText dataset. """ if load_dataset is None: raise ImportError( "datasets library is required for text passage mode. Please install with: pip install datasets" ) if words is None: raise ImportError( "NLTK is required for this environment. Please install with: pip install nltk" ) # Set random seed for reproducibility if configured if self.config.random_seed is not None: random.seed(self.config.random_seed) # Validate configuration await self._validate_config() self.logger.info("Setting up mixed dataset with words and text passages...") # First, set up words (same as word-only mode) all_words = words.words() filtered_words = [ word.lower() for word in all_words if word.isalpha() and self.config.min_word_length <= len(word) <= self.config.max_word_length ] # Apply capitalization to real words if configured filtered_words = self._apply_word_capitalization(filtered_words) # Generate random strings if configured if self.config.random_string_percentage > 0.0: total_filtered_words = len(filtered_words) if self.config.random_string_percentage >= 1.0: num_random_strings = ( total_filtered_words if total_filtered_words > 0 else 1000 ) filtered_words = [] else: num_random_strings = int( (self.config.random_string_percentage * total_filtered_words) / (1.0 - self.config.random_string_percentage) ) random_strings = self._generate_random_strings(num_random_strings) all_word_strings = filtered_words + random_strings self.logger.info( f"Generated {num_random_strings} random strings ({self.config.random_string_percentage:.1%} of word dataset)" # noqa ) else: all_word_strings = filtered_words random_strings = [] self.logger.info( f"Prepared {len(all_word_strings)} word/string items ({len(filtered_words)} real words + {len(random_strings)} random strings)" # noqa ) # Now load and process text passages self.logger.info("Loading OpenWebText-10k dataset for text passages...") dataset = load_dataset("stas/openwebtext-10k", split="train") self.logger.info(f"Loaded {len(dataset)} text samples from OpenWebText") # Extract and filter text passages all_passages = [] processed_texts = 0 for item in dataset: text = item["text"] # Skip texts that are too long initially if ( len(text) > self.config.max_text_length * 3 ): # Allow some overhead for chunking continue # Extract passages from this text passages = self._extract_text_passages(text) all_passages.extend(passages) processed_texts += 1 # Log progress periodically if processed_texts % 1000 == 0: self.logger.info( f"Processed {processed_texts} texts, extracted {len(all_passages)} passages so far..." ) self.logger.info( f"Extracted {len(all_passages)} text passages from {processed_texts} texts" ) # Now mix words and passages according to text_passage_percentage if self.config.text_passage_percentage >= 1.0: # Special case: 100% text passages, no words all_mixed_items = all_passages final_word_count = 0 final_passage_count = len(all_passages) elif self.config.text_passage_percentage <= 0.0: # Special case: 0% text passages, only words all_mixed_items = all_word_strings final_word_count = len(all_word_strings) final_passage_count = 0 else: # Calculate how many passages to include # If we want X% passages, then passages / (words + passages) = X # Solving: passages_to_use = X * words / (1 - X) total_words = len(all_word_strings) num_passages_to_use = int( (self.config.text_passage_percentage * total_words) / (1.0 - self.config.text_passage_percentage) ) # Upsample passages if we don't have enough to reach the target percentage if num_passages_to_use > len(all_passages): self.logger.info( f"Upsampling passages: need {num_passages_to_use} but only have {len(all_passages)}. " f"Will repeat passages to reach target percentage." ) # Calculate how many times we need to repeat the passage list repeat_factor = (num_passages_to_use // len(all_passages)) + 1 upsampled_passages = all_passages * repeat_factor # Shuffle the upsampled passages to avoid patterns random.shuffle(upsampled_passages) # Take exactly the number we need passages_to_use = upsampled_passages[:num_passages_to_use] self.logger.info( f"Upsampled {len(all_passages)} unique passages {repeat_factor}x times, " f"then sampled {len(passages_to_use)} passages for the dataset." ) else: # We have enough passages, just sample what we need passages_to_use = random.sample(all_passages, num_passages_to_use) # Combine words and passages all_mixed_items = all_word_strings + passages_to_use final_word_count = len(all_word_strings) final_passage_count = len(passages_to_use) # Shuffle the mixed dataset random.shuffle(all_mixed_items) # Create train/test split split_point = int(self.config.train_test_split * len(all_mixed_items)) self.train_words = all_mixed_items[ :split_point ] # Reusing train_words for mixed items self.test_words = all_mixed_items[split_point:] # Calculate actual percentages actual_passage_percentage = ( final_passage_count / len(all_mixed_items) if len(all_mixed_items) > 0 else 0.0 ) # Calculate unique passages for logging unique_passages_used = ( len(set(passages_to_use)) if "passages_to_use" in locals() else 0 ) # Log dataset statistics self.logger.info("Mixed dataset created:") if unique_passages_used < final_passage_count: # Upsampling occurred self.logger.info( f" Total items: {len(all_mixed_items)} ({final_word_count} words/strings + {final_passage_count} passages)" # noqa ) self.logger.info( f" Unique passages used: {unique_passages_used} (repeated to create {final_passage_count} total passage instances)" # noqa ) else: # No upsampling self.logger.info( f" Total items: {len(all_mixed_items)} ({final_word_count} words/strings + {final_passage_count} passages)" # noqa ) self.logger.info( f" Actual passage percentage: {actual_passage_percentage:.1%} (target: {self.config.text_passage_percentage:.1%})" # noqa ) self.logger.info(f" Training items: {len(self.train_words)}") self.logger.info(f" Test items: {len(self.test_words)}") # Show examples of both types word_examples = [item for item in self.train_words[:10] if len(item) <= 50][:3] passage_examples = [ item[:100] + "..." for item in self.train_words[:50] if len(item) > 50 ][:3] if word_examples: self.logger.info(f" Example words: {word_examples}") if passage_examples: self.logger.info(f" Example passages: {passage_examples}") # If we have upsampled passages, show some statistics if unique_passages_used < final_passage_count and unique_passages_used > 0: avg_repetitions = final_passage_count / unique_passages_used self.logger.info(f" Average passage repetitions: {avg_repetitions:.1f}x") # Log configuration details self.logger.info( f"Word length range: {self.config.min_word_length}-{self.config.max_word_length}" ) self.logger.info( f"Passage length range: {self.config.min_text_length}-{self.config.max_text_length} characters" ) self.logger.info( f"Include punctuation: {self.config.include_punctuation_in_count}" ) self.logger.info(f"Include spaces: {self.config.include_spaces_in_count}") self.logger.info( f"Training threshold: {self.config.max_group_average_for_training}" ) self.logger.info("Mixed dataset setup complete") async def _setup_text_passage_dataset(self): """ Set up the environment using text passages from OpenWebText dataset. """ if load_dataset is None: raise ImportError( "datasets library is required for text passage mode. Please install with: pip install datasets" ) # Set random seed for reproducibility if configured if self.config.random_seed is not None: random.seed(self.config.random_seed) # Validate configuration await self._validate_config() self.logger.info("Loading OpenWebText-10k dataset...") # Load the dataset dataset = load_dataset("stas/openwebtext-10k", split="train") self.logger.info(f"Loaded {len(dataset)} text samples from OpenWebText") # Extract and filter text passages all_passages = [] processed_texts = 0 for item in dataset: text = item["text"] # Skip texts that are too long initially if ( len(text) > self.config.max_text_length * 3 ): # Allow some overhead for chunking continue # Extract passages from this text passages = self._extract_text_passages(text) all_passages.extend(passages) processed_texts += 1 # Log progress periodically if processed_texts % 1000 == 0: self.logger.info( f"Processed {processed_texts} texts, extracted {len(all_passages)} passages so far..." ) self.logger.info( f"Extracted {len(all_passages)} text passages from {processed_texts} texts" ) # Shuffle passages for randomness random.shuffle(all_passages) # Create train/test split split_point = int(self.config.train_test_split * len(all_passages)) self.train_words = all_passages[ :split_point ] # Reusing train_words for passages self.test_words = all_passages[split_point:] # Log dataset statistics self.logger.info(f"Training passages: {len(self.train_words)}") self.logger.info(f"Test passages: {len(self.test_words)}") self.logger.info( f"Example passages: {[p[:100] + '...' for p in self.train_words[:3]]}" ) # Log configuration details self.logger.info( f"Passage length range: {self.config.min_text_length}-{self.config.max_text_length} characters" ) self.logger.info( f"Include punctuation: {self.config.include_punctuation_in_count}" ) self.logger.info(f"Include spaces: {self.config.include_spaces_in_count}") self.logger.info( f"Training threshold: {self.config.max_group_average_for_training}" ) self.logger.info("Text passage dataset setup complete") async def _setup_word_dataset(self): """ Set up the environment using single words (original functionality). """ # Load the NLTK words corpus (contains 236,736 English words) if words is None: raise ImportError( "NLTK is required for this environment. Please install with: pip install nltk" ) # Set random seed for reproducibility if configured if self.config.random_seed is not None: random.seed(self.config.random_seed) # Validate configuration await self._validate_config() # Get all English words from NLTK all_words = words.words() # Filter words to ensure they contain only alphabetic characters # and are within the configured length range for reasonable difficulty filtered_words = [ word.lower() for word in all_words if word.isalpha() and self.config.min_word_length <= len(word) <= self.config.max_word_length ] # Apply capitalization to real words if configured filtered_words = self._apply_word_capitalization(filtered_words) # Generate random strings if configured if self.config.random_string_percentage > 0.0: # Calculate how many random strings to generate total_filtered_words = len(filtered_words) # If we want X% random strings, then random_strings / (words + random_strings) = X # Solving: random_strings = X * words / (1 - X) if self.config.random_string_percentage >= 1.0: # Special case: 100% random strings, no real words num_random_strings = ( total_filtered_words if total_filtered_words > 0 else 1000 ) filtered_words = [] else: num_random_strings = int( (self.config.random_string_percentage * total_filtered_words) / (1.0 - self.config.random_string_percentage) ) random_strings = self._generate_random_strings(num_random_strings) # Combine real words and random strings all_strings = filtered_words + random_strings self.logger.info( f"Generated {num_random_strings} random strings ({self.config.random_string_percentage:.1%} of dataset)" ) else: # No random strings, use only real words all_strings = filtered_words random_strings = [] # Shuffle all strings for randomness random.shuffle(all_strings) # Create train/test split using configured ratio split_point = int(self.config.train_test_split * len(all_strings)) self.train_words = all_strings[:split_point] self.test_words = all_strings[split_point:] # Log dataset statistics if self.config.random_string_percentage > 0.0: self.logger.info( f"Total dataset: {len(all_strings)} strings ({len(filtered_words)} real words + {len(random_strings)} random strings)" # noqa ) else: self.logger.info(f"Loaded {len(all_strings)} words total") self.logger.info(f"Training strings: {len(self.train_words)}") self.logger.info(f"Test strings: {len(self.test_words)}") self.logger.info(f"Example strings: {self.train_words[:10]}") # Log configuration details self.logger.info( f"Word length range: {self.config.min_word_length}-{self.config.max_word_length}" ) if self.config.random_string_percentage > 0.0: self.logger.info( f"Random string length range: {self.config.random_string_min_length}-{self.config.random_string_max_length}" # noqa ) self.logger.info(f"Train/test split: {self.config.train_test_split:.2%}") self.logger.info( f"Letter set: {'all 26 letters' if self.config.use_all_letters else f'custom ({self.config.custom_letters})'}" # noqa ) self.logger.info( f"Random strings: {self.config.random_string_percentage:.1%} of dataset" ) # Log capitalization settings if ( self.config.uppercase_word_percentage > 0.0 or self.config.capitalized_word_percentage > 0.0 ): self.logger.info( f"Word capitalization: {self.config.uppercase_word_percentage:.1%} uppercase, {self.config.capitalized_word_percentage:.1%} title case" # noqa ) self.logger.info(f"Random seed: {self.config.random_seed}") # Log letter selection bias configuration self.logger.info( f"Letter selection bias: {self.config.present_letter_bias:.1%} toward letters present in text" ) # Log data dumping configuration if self.config.dump_rollouts: self.logger.info( "Data dumping enabled - saving groups with appropriate difficulty for training" ) self.logger.info( f"Group difficulty threshold: group average ≤ {self.config.max_group_average_for_training}" ) self.logger.info(f"Data dumps directory: {self.datadumps_dir}") self.logger.info( f"Batch size: {self.config.dump_batch_size} groups per file" ) self.logger.info("Letter counting environment setup complete") async def _save_rollouts_to_jsonl(self): """Saves the buffered rollouts to a JSONL file in the datadumps directory.""" if not self.rollouts_to_save_buffer: self.logger.warning("_save_rollouts_to_jsonl called but buffer is empty!") return buffer_size = len(self.rollouts_to_save_buffer) self.logger.info(f"Starting save of {buffer_size} groups to JSONL file...") try: if not os.path.exists(self.datadumps_dir): os.makedirs(self.datadumps_dir) self.logger.info(f"Created directory: {self.datadumps_dir}") except OSError as e: self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}") return file_path = os.path.join( self.datadumps_dir, f"letter_counting_environment_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl", ) try: with open(file_path, "w") as f: for rollout_dict in self.rollouts_to_save_buffer: json.dump(rollout_dict, f) f.write("\n") self.logger.info( f"Successfully saved {buffer_size} groups to {file_path} " f"(batch #{self.save_file_batch_num})" ) self.rollouts_to_save_buffer.clear() self.save_file_batch_num += 1 self.logger.info( f"Buffer cleared. Next batch will be #{self.save_file_batch_num}" ) except IOError as e: self.logger.error(f"Error writing rollouts to {file_path}: {e}") except Exception as e: self.logger.error( f"An unexpected error occurred while saving rollouts to {file_path}: {e}" ) def _get_letter_set(self): """ Get the set of letters to choose from based on configuration. Returns: String containing letters to choose from (lowercase only for consistency) """ if not self.config.use_all_letters: return self.config.custom_letters.lower() else: return "abcdefghijklmnopqrstuvwxyz" def _select_target_letters(self, text: str, num_letters: int) -> List[str]: """ Select target letters with bias toward letters present in the text. Args: text: The text to analyze for present letters num_letters: Number of letters to select Returns: List of selected target letters """ available_letters = list(self._get_letter_set()) # Prepare text for counting to match the actual counting logic text_for_counting = self._prepare_text_for_counting(text) # Find letters that are present in the text (case-insensitive) present_letters = [] absent_letters = [] for letter in available_letters: if letter.lower() in text_for_counting.lower(): present_letters.append(letter) else: absent_letters.append(letter) # If we need more letters than are present, we'll need some absent ones too if num_letters > len(present_letters): # Select all present letters and fill the rest randomly from absent letters selected_letters = present_letters.copy() remaining_needed = num_letters - len(present_letters) if remaining_needed > 0 and absent_letters: selected_letters.extend( random.sample( absent_letters, min(remaining_needed, len(absent_letters)) ) ) return selected_letters[:num_letters] # We have enough present letters, so apply the bias selected_letters = [] for _ in range(num_letters): # Decide whether to pick from present or absent letters based on bias if present_letters and ( not absent_letters or random.random() < self.config.present_letter_bias ): # Choose from present letters chosen_letter = random.choice(present_letters) present_letters.remove(chosen_letter) elif absent_letters: # Choose from absent letters chosen_letter = random.choice(absent_letters) absent_letters.remove(chosen_letter) else: # Fallback: choose from any remaining available letters remaining_letters = [ l for l in available_letters if l not in selected_letters # noqa ] if remaining_letters: chosen_letter = random.choice(remaining_letters) else: break # No more letters available selected_letters.append(chosen_letter) return selected_letters def _generate_random_string(self, length: int) -> str: """ Generate a random string of specified length with at least 80% alphabetical characters. Args: length: Length of the string to generate Returns: Random string with mix of uppercase, lowercase, and some non-alphabetical chars """ # Ensure at least 80% alphabetical characters min_alpha_chars = max(1, int(length * 0.8)) # Character sets letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" non_alpha = "0123456789!@#$%^&*()-_=+[]{}|;:,.<>?" result = [] # First, add the required alphabetical characters for _ in range(min_alpha_chars): result.append(random.choice(letters)) # Fill the rest with either alphabetical or non-alphabetical characters remaining_length = length - min_alpha_chars for _ in range(remaining_length): # 90% chance of alphabetical even for remaining chars (to exceed 80% minimum) if random.random() < 0.9: result.append(random.choice(letters)) else: result.append(random.choice(non_alpha)) # Shuffle to avoid alphabetical chars being clustered at the beginning random.shuffle(result) return "".join(result) def _generate_random_strings(self, count: int) -> List[str]: """ Generate a list of random strings with lengths within the configured random string range. Args: count: Number of random strings to generate Returns: List of random strings """ random_strings = [] for _ in range(count): # Generate random length within configured random string range length = random.randint( self.config.random_string_min_length, self.config.random_string_max_length, ) random_string = self._generate_random_string(length) random_strings.append(random_string) return random_strings def _apply_word_capitalization(self, words: List[str]) -> List[str]: """ Apply capitalization transformations to real words based on configuration. Args: words: List of lowercase words Returns: List of words with applied capitalization """ if ( self.config.uppercase_word_percentage == 0.0 and self.config.capitalized_word_percentage == 0.0 ): return words result = [] for word in words: rand_val = random.random() if rand_val < self.config.uppercase_word_percentage: # Make uppercase result.append(word.upper()) elif ( rand_val < self.config.uppercase_word_percentage + self.config.capitalized_word_percentage ): # Capitalize first letter result.append(word.capitalize()) else: # Keep lowercase result.append(word) return result def _extract_text_passages(self, text: str) -> List[str]: """ Extract text passages from raw text based on character length. Args: text: Raw text to extract passages from Returns: List of filtered text passages """ # Clean the text - remove excessive whitespace and normalize text = re.sub(r"\s+", " ", text.strip()) # If text is shorter than min length, skip it if len(text) < self.config.min_text_length: return [] # If text is within range, use it as-is if len(text) <= self.config.max_text_length: return [text] # For longer texts, create overlapping chunks passages = [] chunk_size = self.config.max_text_length overlap = min(50, chunk_size // 4) # 25% overlap, max 50 chars start = 0 while start < len(text): end = start + chunk_size # If this would be the last chunk and it's too small, extend the previous chunk if end >= len(text): if len(text) - start >= self.config.min_text_length: passages.append(text[start:]) break # Try to break at a natural boundary (space, punctuation) chunk = text[start:end] # Look for a good break point in the last 20% of the chunk break_start = int(len(chunk) * 0.8) break_candidates = [] # Find sentence endings first for i in range(len(chunk) - 1, break_start - 1, -1): if chunk[i] in ".!?": break_candidates.append(i + 1) break # If no sentence ending, look for other punctuation if not break_candidates: for i in range(len(chunk) - 1, break_start - 1, -1): if chunk[i] in ",;:": break_candidates.append(i + 1) break # If no punctuation, look for spaces if not break_candidates: for i in range(len(chunk) - 1, break_start - 1, -1): if chunk[i] == " ": break_candidates.append(i) break # Use the break point if found, otherwise use the full chunk if break_candidates: actual_end = start + break_candidates[0] passage = text[start:actual_end].strip() else: passage = chunk.strip() actual_end = end # Only add if it meets minimum length if len(passage) >= self.config.min_text_length: passages.append(passage) # Move start position with overlap start = actual_end - overlap # Avoid infinite loops if start >= end - overlap: start = end return passages def _prepare_text_for_counting(self, text: str) -> str: """ Prepare text for letter counting based on configuration. Args: text: Original text Returns: Processed text for counting """ if not self.config.include_punctuation_in_count: # Remove punctuation but keep spaces and alphanumeric text = "".join(c for c in text if c.isalnum() or c.isspace()) if not self.config.include_spaces_in_count: # Remove spaces text = text.replace(" ", "") return text async def _validate_config(self): """Validate configuration parameters.""" if not (0.0 <= self.config.random_string_percentage <= 1.0): raise ValueError( f"random_string_percentage must be between 0.0 and 1.0, got {self.config.random_string_percentage}" ) # noqa if not (0.0 <= self.config.uppercase_word_percentage <= 1.0): raise ValueError( f"uppercase_word_percentage must be between 0.0 and 1.0, got {self.config.uppercase_word_percentage}" ) # noqa if not (0.0 <= self.config.capitalized_word_percentage <= 1.0): raise ValueError( f"capitalized_word_percentage must be between 0.0 and 1.0, got {self.config.capitalized_word_percentage}" # noqa ) # noqa if ( self.config.uppercase_word_percentage + self.config.capitalized_word_percentage > 1.0 ): raise ValueError( f"Sum of uppercase_word_percentage ({self.config.uppercase_word_percentage}) and capitalized_word_percentage ({self.config.capitalized_word_percentage}) cannot exceed 1.0" # noqa ) # noqa if self.config.random_string_min_length < 1: raise ValueError( f"random_string_min_length must be at least 1, got {self.config.random_string_min_length}" ) # noqa if self.config.random_string_max_length < self.config.random_string_min_length: raise ValueError( f"random_string_max_length ({self.config.random_string_max_length}) must be >= random_string_min_length ({self.config.random_string_min_length})" # noqa ) # noqa if self.config.use_text_passages: if not (0.0 <= self.config.text_passage_percentage <= 1.0): raise ValueError( f"text_passage_percentage must be between 0.0 and 1.0, got {self.config.text_passage_percentage}" # noqa ) if self.config.min_text_length < 1: raise ValueError( f"min_text_length must be at least 1, got {self.config.min_text_length}" # noqa ) if self.config.max_text_length < self.config.min_text_length: raise ValueError( f"max_text_length ({self.config.max_text_length}) must be >= min_text_length ({self.config.min_text_length})" # noqa ) if self.config.max_text_length < 10: raise ValueError( f"max_text_length must be at least 10 characters, got {self.config.max_text_length}" # noqa ) if self.config.max_letters_to_count < 1: raise ValueError( f"max_letters_to_count must be at least 1, got {self.config.max_letters_to_count}" ) if not (0.0 <= self.config.multi_letter_probability <= 1.0): raise ValueError( f"multi_letter_probability must be between 0.0 and 1.0, got {self.config.multi_letter_probability}" ) if self.config.max_letters_to_count > 26: raise ValueError( f"max_letters_to_count cannot exceed 26 (total letters), got {self.config.max_letters_to_count}" ) if not (0.0 <= self.config.present_letter_bias <= 1.0): raise ValueError( f"present_letter_bias must be between 0.0 and 1.0, got {self.config.present_letter_bias}" ) if self.config.dump_batch_size < 1: raise ValueError( f"dump_batch_size must be at least 1, got {self.config.dump_batch_size}" ) def save_checkpoint(self, step, data=None): """Save checkpoint including current iteration number, statistics, and data dumping state.""" if data is None: data = {} data["iter"] = self.iter data["processed_item_count"] = self.processed_item_count data["save_file_batch_num"] = self.save_file_batch_num data["letter_distribution_stats"] = self.letter_distribution_stats data["word_length_stats"] = self.word_length_stats data["answer_format_errors"] = self.answer_format_errors data["think_format_errors"] = self.think_format_errors super().save_checkpoint(step, data) def load_checkpoint(self): """Load checkpoint including iteration number, statistics, and data dumping state.""" # Call the base class method first to load the data super().load_checkpoint() # The base class loads data into attributes, so we can access them directly # if they were saved in save_checkpoint if hasattr(self, "iter"): # Restore statistics if available if ( hasattr(self, "letter_distribution_stats") and self.letter_distribution_stats ): total_letters = sum(self.letter_distribution_stats.values()) self.logger.info( f"Restored letter distribution stats with {total_letters} total letters" ) if hasattr(self, "word_length_stats") and self.word_length_stats: total_words = sum(self.word_length_stats.values()) self.logger.info( f"Restored word length stats with {total_words} total words" ) if hasattr(self, "answer_format_errors"): self.logger.info( f"Restored error counts: {self.answer_format_errors} answer format errors, {self.think_format_errors} think format errors" # noqa ) # noqa async def close(self): """Clean up and save any remaining rollouts before exiting.""" self.logger.info( "Closing LetterCountingEnv. Attempting to save any remaining rollouts..." ) if self.config.dump_rollouts and self.rollouts_to_save_buffer: self.logger.info( f"FINAL SAVE: Found {len(self.rollouts_to_save_buffer)} groups in buffer. " f"Saving final batch to disk (batch #{self.save_file_batch_num})..." ) await self._save_rollouts_to_jsonl() self.logger.info("Final save completed successfully.") elif self.config.dump_rollouts: self.logger.info("No rollouts in buffer to save upon closing.") else: self.logger.info("Data dumping is disabled - no rollouts to save.") # Call the superclass's close method if it exists if hasattr(super(), "close"): await super().close() self.logger.info("LetterCountingEnv closed.") async def get_next_item(self): """ Get the next training item from the dataset. Returns: A tuple containing prompt and expected answer """ # Get the next text from training set (could be a word, sentence, or random string) text = self.train_words[self.iter % len(self.train_words)] # Decide whether to use multiple letters use_multiple = ( self.config.max_letters_to_count > 1 and random.random() < self.config.multi_letter_probability ) if use_multiple: # Choose 2 to max_letters_to_count different letters num_letters = random.randint(2, self.config.max_letters_to_count) target_letters = self._select_target_letters(text, num_letters) else: # Single letter counting target_letters = self._select_target_letters(text, 1) # Prepare text for counting (handle punctuation/spaces based on config) text_for_counting = self._prepare_text_for_counting(text) # Count occurrences for each target letter (case-insensitive) correct_counts = {} for letter in target_letters: correct_counts[letter] = text_for_counting.lower().count(letter.lower()) # Determine if this is a text passage or word/string based on length and content is_text_passage = ( len(text) > 50 or " " in text or any(c in text for c in ".,!?;:") ) # Log item selection details for every item text_type = "passage" if is_text_passage else "word/string" text_preview = text[:50] + "..." if len(text) > 50 else text letters_str = ", ".join(target_letters) counts_str = ", ".join( f"{letter}:{correct_counts[letter]}" for letter in target_letters ) # Add information about present vs absent letters for bias analysis present_count = sum( 1 for letter in target_letters if correct_counts[letter] > 0 ) absent_count = len(target_letters) - present_count bias_info = f"Present: {present_count}, Absent: {absent_count}" self.logger.info( f"Selected {text_type}: '{text_preview}' | Letters: [{letters_str}] | Counts: [{counts_str}] | {bias_info} (iteration {self.iter})" # noqa ) self.iter += 1 # Create the question based on whether this item is a text passage or word/string and single/multiple letters if len(target_letters) == 1: # Single letter question target_letter = target_letters[0] if is_text_passage: question_text = ( f'How many {target_letter}s are in the following text: "{text}"?' ) else: question_text = f"How many {target_letter}s are in the string {text}?" # Add instruction for single letter answer format question_with_instruction = f"{question_text}\n\nProvide your answer in the format: {{number}}" # noqa else: # Multiple letters question letters_str = ( ", ".join(f"'{letter}'" for letter in target_letters[:-1]) + f", and '{target_letters[-1]}'" ) if is_text_passage: question_text = f'Count the occurrences of the letters {letters_str} in the following text: "{text}"' else: question_text = f"Count the occurrences of the letters {letters_str} in the string {text}" # Add instruction for multiple letter JSON answer format example_json = ( "{" + ", ".join(f'"{letter}": 0' for letter in target_letters) + "}" ) question_with_instruction = f"{question_text}\n\nProvide your answer as JSON in the format: {example_json}" # noqa # Create prompt tuple using frozensets as required prompt = [] # Add system prompt prompt.append(frozenset({"role": "system", "content": system_prompt}.items())) # Add user message with the question prompt.append( frozenset({"role": "user", "content": question_with_instruction}.items()) ) # Return the prompt, correct counts, text, and target letters return (tuple(prompt), correct_counts, text, target_letters) async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: """ Generate and collect model responses for scoring. Args: item: Input item containing prompt and expected answer Returns: Tuple of lists containing scored data groups and backlog """ # Extract messages from the item messages = [] for role_dict in item[0]: messages.append(dict(role_dict)) # Apply chat template to convert messages to a single string prompt = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=False ) # Get completions from the model completions = await self.server.completion( prompt=prompt, n=self.config.group_size, max_tokens=self.config.max_generation_tokens, temperature=self.config.generation_temperature, ) to_score = list() for i, completion_choice in enumerate(completions.choices): # Create a copy of the prompt messages trajectory_messages = [] for role_dict in item[0]: trajectory_messages.append(dict(role_dict)) # Add the model's response trajectory_messages.append( {"role": "assistant", "content": completion_choice.text} ) # Add to scoring queue with expected answer and metadata to_score.append( ( tuple(trajectory_messages), item[1], # correct_counts (dict) item[2], # text (word or sentence) item[3], # target_letters (list) completion_choice.finish_reason, # stop reason ) ) # Call score to get the scored data scored_data = await self.score(to_score) # Data dumping logic - save groups that meet the accuracy threshold BEFORE training filtering # This ensures we save high-quality data even if it's too easy for training if self.config.dump_rollouts: # First, we need to score the data to get the scores, but we'll do our own scoring for data dumping temp_scores = [] for score_item in to_score: # Extract the model's response and score it model_response = score_item[0][-1]["content"] stop_reason = score_item[4] expected_counts = score_item[1] target_letters = score_item[3] # Handle legacy format if isinstance(target_letters, str): target_letters = [target_letters] expected_counts = {target_letters[0]: expected_counts} elif isinstance(expected_counts, int): expected_counts = {target_letters[0]: expected_counts} # Score this individual response if stop_reason == "length": temp_scores.append(0.0) else: expected_format = "single" if len(target_letters) == 1 else "multi" model_answer = self._extract_answer(model_response, expected_format) if model_answer is None: temp_scores.append(0.0) else: if expected_format == "single": expected_single_count = expected_counts[target_letters[0]] temp_scores.append( 1.0 if model_answer == expected_single_count else 0.0 ) else: if set(model_answer.keys()) == set(target_letters) and all( model_answer.get(letter, -1) == expected_counts[letter] for letter in target_letters ): temp_scores.append(1.0) else: temp_scores.append(0.0) # Check if group has appropriate difficulty for training (group average within training range) group_average_score = ( sum(temp_scores) / len(temp_scores) if temp_scores else 0.0 ) # Skip groups where all scores are identical (no learning signal, including all 0.0) if temp_scores and all(temp_scores[0] == score for score in temp_scores): self.logger.debug( f"Skipping group save - all scores identical ({temp_scores[0]:.3f}) - no learning signal" ) elif group_average_score <= self.config.max_group_average_for_training: self.logger.debug( f"Saving group with scores: {[f'{s:.3f}' for s in temp_scores]} (group_avg: {group_average_score:.3f} <= {self.config.max_group_average_for_training})" # noqa ) rollouts_with_scores_to_save = [] for i, score_for_rollout in enumerate(temp_scores): conversation_messages = to_score[i][0] correct_counts = to_score[i][1] text = to_score[i][2] target_letters = to_score[i][3] stop_reason = to_score[i][4] rollouts_with_scores_to_save.append( { "conversation": conversation_messages, "score": score_for_rollout, "expected_counts": correct_counts, "text": text, "target_letters": target_letters, "stop_reason": stop_reason, "group_average_score": group_average_score, } ) if rollouts_with_scores_to_save: # Extract item info for logging correct_counts = item[1] text = item[2] target_letters = item[3] text_preview = ( text[:30].replace(" ", "_") if len(text) > 30 else text.replace(" ", "_") ) letters_str = "_".join(target_letters) counts_str = "_".join( str(correct_counts.get(letter, 0)) for letter in target_letters ) item_id = f"{text_preview}_{letters_str}_{counts_str}" item_data_to_save = { "item_id": item_id, "rollouts": rollouts_with_scores_to_save, } self.rollouts_to_save_buffer.append(item_data_to_save) self.processed_item_count += 1 # Log every single sample added to buffer self.logger.info( f"BUFFER ADD: Added item '{item_id}' to buffer. " f"Buffer now contains {len(self.rollouts_to_save_buffer)} groups " f"(target batch size: {self.config.dump_batch_size})" ) # Log progress every 10 items if self.processed_item_count % 10 == 0: self.logger.info( f"Data dump progress: {self.processed_item_count} items processed " f"(Buffer size: {len(self.rollouts_to_save_buffer)})" ) # Save in batches when buffer reaches the configured size if ( self.config.dump_rollouts and len(self.rollouts_to_save_buffer) >= self.config.dump_batch_size ): self.logger.info( f"Buffer reached batch size ({len(self.rollouts_to_save_buffer)}/{self.config.dump_batch_size}). " # noqa f"Saving batch to disk..." ) await self._save_rollouts_to_jsonl() # Safety mechanism: save every 50 items processed to prevent data loss elif ( self.config.dump_rollouts and self.processed_item_count % 50 == 0 and len(self.rollouts_to_save_buffer) > 0 ): self.logger.info( f"Safety save: {self.processed_item_count} items processed. " f"Saving {len(self.rollouts_to_save_buffer)} groups to prevent data loss..." ) await self._save_rollouts_to_jsonl() else: self.logger.debug( f"Skipping group save - group too easy for training (group_avg: {group_average_score:.3f} > {self.config.max_group_average_for_training})" # noqa ) to_backlog = [] return scored_data, to_backlog def _extract_answer(self, text, expected_format="single"): """ Extract the answer from model response (single number or JSON). Only allows one valid answer format - multiple answer formats result in a score of 0. Args: text: Text containing the model's response expected_format: "single" for number, "multi" for JSON Returns: Extracted answer (int for single, dict for multi) or None if invalid """ # Check for multiple tags - score as 0 if found think_tags = re.findall(r"", text, re.IGNORECASE) if len(think_tags) > 1: return None # Check if the think tag is properly opened - we need exactly one opening tag if len(think_tags) != 1: return None # Check for closing tags think_close_tags = re.findall(r"", text, re.IGNORECASE) if len(think_close_tags) != 1: return None # Must have exactly one closing tag # Split the text into thinking and answer sections parts = re.split(r"", text, flags=re.IGNORECASE, maxsplit=1) # If there's no tag or multiple sections, return None if len(parts) != 2: return None thinking_section, answer_section = parts # Validate thinking section # Make sure thinking section actually contains the opening tag if "" not in thinking_section.lower(): return None # Malformed thinking section # Check if there are any tags in the answer section (after the first ) if "" in answer_section.lower(): return None # Look for answer tags in the answer section if expected_format == "single": # Single number format answer_pattern = r"\s*(\d+)\s*" answer_matches = re.findall(answer_pattern, answer_section, re.IGNORECASE) # If no answers found or multiple answers found, return None if len(answer_matches) != 1: return None # Return the single found answer as an integer try: return int(answer_matches[0]) except ValueError: return None else: # Multi-letter JSON format answer_pattern = r"\s*(\{[^}]+\})\s*" answer_matches = re.findall(answer_pattern, answer_section, re.IGNORECASE) # If no answers found or multiple answers found, return None if len(answer_matches) != 1: return None # Try to parse the JSON try: import json answer_dict = json.loads(answer_matches[0]) # Validate that all values are integers if not isinstance(answer_dict, dict): return None for key, value in answer_dict.items(): if not isinstance(key, str) or not isinstance(value, int): return None return answer_dict except (json.JSONDecodeError, ValueError): return None async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: """ Score the generated model responses against expected letter counts. Args: rollout_group_data: List of generated responses with expected answers Returns: ScoredDataGroup with tokenized inputs and scores, or None if no valid scores """ scores = ScoredDataGroup() scores["tokens"] = list() scores["masks"] = list() scores["scores"] = list() if not rollout_group_data: return None # Get the expected answer from first item expected_counts = rollout_group_data[0][ 1 ] # correct counts (dict for multi, int for single - legacy) text = rollout_group_data[0][2] # text (word or sentence) target_letters = rollout_group_data[0][3] # target letters (list) # Handle legacy format (single letter as string, single count as int) if isinstance(target_letters, str): target_letters = [target_letters] expected_counts = {target_letters[0]: expected_counts} elif isinstance(expected_counts, int): # Legacy format with single letter expected_counts = {target_letters[0]: expected_counts} # Track statistics for all target letters for target_letter in target_letters: if target_letter not in self.letter_distribution_stats: self.letter_distribution_stats[target_letter] = 0 self.letter_distribution_stats[target_letter] += 1 text_len = len(text) if text_len not in self.word_length_stats: self.word_length_stats[text_len] = 0 self.word_length_stats[text_len] += 1 # Shuffle to avoid bias in selection random.shuffle(rollout_group_data) format_errors_in_group = 0 think_errors_in_group = 0 for item in rollout_group_data: # Extract the model's response model_response = item[0][-1]["content"] stop_reason = item[4] # Get the stop reason # If the response was cut off due to length, give it a score of 0 if stop_reason == "length": reward = 0 if self.config.debug_logging: letters_str = ", ".join(target_letters) self.logger.debug( f"Text '{text[:50]}...' letters '{letters_str}': Length cutoff, score=0" ) else: # Determine expected format and extract the answer expected_format = "single" if len(target_letters) == 1 else "multi" model_answer = self._extract_answer(model_response, expected_format) # Track metrics based on result if model_answer is None: reward = 0 # Invalid format gets 0 reward format_errors_in_group += 1 # Check if it's a think format error if ( "" not in model_response.lower() or "" not in model_response.lower() ): think_errors_in_group += 1 if self.config.debug_logging: letters_str = ", ".join(target_letters) self.logger.debug( f"Text '{text[:50]}...' letters '{letters_str}': Format error, score=0" ) else: # Check if answer matches expected counts if expected_format == "single": # Single letter: compare integer expected_single_count = expected_counts[target_letters[0]] if model_answer == expected_single_count: reward = 1 if self.config.debug_logging: self.logger.debug( f"Text '{text[:50]}...' letter '{target_letters[0]}': Correct answer {model_answer}, score=1" # noqa ) else: reward = 0 if self.config.debug_logging: self.logger.debug( f"Text '{text[:50]}...' letter '{target_letters[0]}': Wrong answer {model_answer} (expected {expected_single_count}), score=0" # noqa ) else: # Multiple letters: compare dictionaries # Check if all expected letters are present and counts match if set(model_answer.keys()) == set(target_letters) and all( model_answer.get(letter, -1) == expected_counts[letter] for letter in target_letters ): reward = 1 if self.config.debug_logging: self.logger.debug( f"Text '{text[:50]}...' letters {target_letters}: Correct answer {model_answer}, score=1" # noqa ) else: reward = 0 if self.config.debug_logging: self.logger.debug( f"Text '{text[:50]}...' letters {target_letters}: Wrong answer {model_answer} (expected {expected_counts}), score=0" # noqa ) # Tokenize the conversation for learning out_dict = tokenize_for_trainer(self.tokenizer, item[0]) tokens = out_dict["tokens"] masks = out_dict["masks"] # Remove examples with insufficient context if len([1 for i in masks if i != -100]) < 10: continue scores["tokens"].append(tokens) scores["masks"].append(masks) scores["scores"].append(1.0 if reward else 0.0) # Break once we have enough examples if len(scores["tokens"]) >= self.config.group_size: break if not scores["tokens"]: letters_str = ", ".join(target_letters) self.logger.warning( f"No valid items were scored for text '{text[:50]}...' letters '{letters_str}' - all items had insufficient context" # noqa ) return None # Update global error counters self.answer_format_errors += format_errors_in_group self.think_format_errors += think_errors_in_group # Record success rate metrics for wandb logging for score in scores["scores"]: self.percent_correct_buffer.append(score) # Calculate and log average score for the current group current_scores = scores.get("scores", []) if current_scores: average_score = sum(current_scores) / len(current_scores) # Create log message with appropriate text preview text_preview = text[:50] + "..." if len(text) > 50 else text letters_str = ", ".join(target_letters) expected_str = ( str(expected_counts) if len(target_letters) > 1 else str(expected_counts[target_letters[0]]) ) log_message_main = ( f"Text: '{text_preview}' | Letters: '{letters_str}' | Expected: {expected_str} | " f"Group average score: {average_score:.4f}" ) if all(s == 1.0 for s in current_scores): self.logger.info(f"{log_message_main} (All correct in this group!)") elif all(s == 0.0 for s in current_scores): self.logger.info( f"{log_message_main} (All failed - format/answer errors!)" ) else: self.logger.info(log_message_main) # Check training threshold - if group is too easy, skip it for training if average_score > self.config.max_group_average_for_training: self.logger.debug( f"Skipping group for training - too easy (avg: {average_score:.3f} > threshold: {self.config.max_group_average_for_training})" # noqa ) return None # Check if all scores are the same (no learning signal) if all(scores["scores"][0] == score for score in scores["scores"]): self.logger.debug( f"All scores in group are identical ({scores['scores'][0]:.4f}) - no learning signal, skipping group" ) return None return scores async def rollout_and_score_eval(self, test_text): """ Generate and score model responses for a single test text. Args: test_text: Test text from dataset (could be word, sentence, or random string) Returns: Score (1 for correct, 0 for incorrect) """ # Decide whether to use multiple letters (same logic as get_next_item) use_multiple = ( self.config.max_letters_to_count > 1 and random.random() < self.config.multi_letter_probability ) if use_multiple: # Choose 2 to max_letters_to_count different letters num_letters = random.randint(2, self.config.max_letters_to_count) target_letters = self._select_target_letters(test_text, num_letters) else: # Single letter counting target_letters = self._select_target_letters(test_text, 1) # Prepare text for counting (handle punctuation/spaces based on config) text_for_counting = self._prepare_text_for_counting(test_text) # Count occurrences for each target letter (case-insensitive) expected_counts = {} for letter in target_letters: expected_counts[letter] = text_for_counting.lower().count(letter.lower()) # Determine if this is a text passage or word/string based on length and content is_text_passage = ( len(test_text) > 50 or " " in test_text or any(c in test_text for c in ".,!?;:") ) # Create the question based on whether this item is a text passage or word/string and single/multiple letters if len(target_letters) == 1: # Single letter question target_letter = target_letters[0] if is_text_passage: question_text = f'How many {target_letter}s are in the following text: "{test_text}"?' else: question_text = ( f"How many {target_letter}s are in the string {test_text}?" ) # Add instruction for single letter answer format question_with_instruction = f"{question_text}\n\nProvide your answer in the format: {{number}}" # noqa else: # Multiple letters question letters_str = ( ", ".join(f"'{letter}'" for letter in target_letters[:-1]) + f", and '{target_letters[-1]}'" # noqa ) if is_text_passage: question_text = f'Count the occurrences of the letters {letters_str} in the following text: "{test_text}"' # noqa else: question_text = f"Count the occurrences of the letters {letters_str} in the string {test_text}" # noqa # Add instruction for multiple letter JSON answer format example_json = ( "{" + ", ".join(f'"{letter}": 0' for letter in target_letters) + "}" ) question_with_instruction = f"{question_text}\n\nProvide your answer as JSON in the format: {example_json}" # noqa # Create messages for model messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": question_with_instruction}, ] # Apply chat template to convert messages to a single string prompt = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=False ) # Get model completion completion = await self.server.completion( prompt=prompt, n=1, max_tokens=self.config.max_generation_tokens, temperature=self.config.eval_temperature, split="eval", ) # Extract the model's response from the completion model_response = completion.choices[0].text # Determine expected format and extract the answer expected_format = "single" if len(target_letters) == 1 else "multi" model_answer = self._extract_answer(model_response, expected_format) # Score 1 if the answers match, 0 otherwise if model_answer is None: score = 0 elif expected_format == "single": # Single letter: compare integer expected_single_count = expected_counts[target_letters[0]] score = 1 if model_answer == expected_single_count else 0 else: # Multiple letters: compare dictionaries score = ( 1 if ( set(model_answer.keys()) == set(target_letters) and all( model_answer.get(letter, -1) == expected_counts[letter] for letter in target_letters ) ) else 0 ) return score async def evaluate(self, *args, **kwargs): """ Evaluate the model on test data. """ self.logger.info("Starting evaluation...") if not self.test_words: self.logger.warning("No test texts available for evaluation. Skipping.") self.eval_metrics.append(("eval/percent_correct", 0.0)) return eval_tasks = [] # Sample a subset of test texts for evaluation to keep it manageable eval_texts = random.sample( self.test_words, min(len(self.test_words), self.config.eval_sample_size) ) text_type = ( "mixed items (words and passages)" if self.config.use_text_passages else "strings" ) self.logger.info( f"Starting evaluation on {len(eval_texts)} test {text_type}..." ) for test_text in eval_texts: eval_tasks.append(self.rollout_and_score_eval(test_text)) # Run evaluation scores = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating") if not scores: percent_correct = 0.0 else: percent_correct = sum(scores) / len(scores) self.eval_metrics.append(("eval/percent_correct", percent_correct)) self.logger.info(f"Evaluation finished. Percent correct: {percent_correct:.4f}") async def add_rollouts_for_wandb( self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], item: Item = None, ): if item is None or scored_data is None or not scored_data.get("tokens"): return expected_counts = item[1] # correct counts (dict) text = item[2] # text (word or sentence) target_letters = item[3] # target letters (list) # Handle legacy format if isinstance(target_letters, str): target_letters = [target_letters] expected_counts = {target_letters[0]: expected_counts} elif isinstance(expected_counts, int): expected_counts = {target_letters[0]: expected_counts} # save rollout to trajectory num_keep = self.config.num_rollouts_per_group_for_logging if num_keep == -1: num_keep = self.config.group_size # Make sure there's data to log num_keep = min(num_keep, len(scored_data["tokens"])) if num_keep == 0: return # Calculate group average score group_scores = scored_data.get("scores", []) group_average_score = ( sum(group_scores) / len(group_scores) if group_scores else 0.0 ) current_rollouts = [] for i in range(num_keep): # Ensure tokens and scores have the same length if i < len(scored_data["tokens"]) and i < len(scored_data["scores"]): # Decode the full trajectory including prompt and model response full_text = self.tokenizer.decode( scored_data["tokens"][i], skip_special_tokens=True ) score_val = scored_data["scores"][i] expected_str = ( str(expected_counts) if len(target_letters) > 1 else str(expected_counts[target_letters[0]]) ) letters_str = ", ".join(target_letters) current_rollouts.append( ( full_text, score_val, expected_str, text[:100], letters_str, group_average_score, ) ) else: self.logger.warning( f"Mismatch in lengths of tokens/scores for wandb logging at index {i}." ) self.rollouts_for_wandb.append(current_rollouts) if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: self.rollouts_for_wandb.pop(0) async def create_rollout_table(self, wandb_metrics): if len(self.rollouts_for_wandb) > 0: table = wandb.Table( columns=[ "full_text", "score", "expected_counts", "text", "target_letters", "group_average_score", ] ) for group in self.rollouts_for_wandb: for item in group: # Handle both old format (5 items) and new format (6 items) if len(item) >= 6: table.add_data( item[0], item[1], item[2], item[3], item[4], item[5] ) else: table.add_data(item[0], item[1], item[2], item[3], item[4], 0.0) wandb_metrics["train/rollouts"] = table self.rollouts_for_wandb = [] return wandb_metrics async def wandb_log(self, wandb_metrics: Optional[Dict] = None): if wandb_metrics is None: wandb_metrics = {} # Try to calculate percent_correct, pass if there's a division by zero try: wandb_metrics["train/percent_correct"] = sum( self.percent_correct_buffer ) / len(self.percent_correct_buffer) except ZeroDivisionError: # Skip if buffer is empty pass self.percent_correct_buffer = list() # Add eval metrics for item in self.eval_metrics: wandb_metrics[item[0]] = item[1] self.eval_metrics = list() # Add comprehensive letter counting specific metrics if self.letter_distribution_stats: # Log letter distribution statistics total_letters_asked = sum(self.letter_distribution_stats.values()) wandb_metrics["stats/total_letters_asked"] = total_letters_asked # Log most and least common letters asked most_common_letter = max( self.letter_distribution_stats, key=self.letter_distribution_stats.get ) least_common_letter = min( self.letter_distribution_stats, key=self.letter_distribution_stats.get ) wandb_metrics["stats/most_common_letter_count"] = ( self.letter_distribution_stats[most_common_letter] ) wandb_metrics["stats/least_common_letter_count"] = ( self.letter_distribution_stats[least_common_letter] ) # Log distribution entropy (measure of how evenly distributed the letters are) import math entropy = -sum( (count / total_letters_asked) * math.log2(count / total_letters_asked) for count in self.letter_distribution_stats.values() if count > 0 ) wandb_metrics["stats/letter_distribution_entropy"] = entropy if self.word_length_stats: # Log word length statistics total_words_asked = sum(self.word_length_stats.values()) wandb_metrics["stats/total_words_asked"] = total_words_asked # Calculate average word length avg_word_length = ( sum(length * count for length, count in self.word_length_stats.items()) / total_words_asked ) wandb_metrics["stats/avg_word_length"] = avg_word_length # Log min and max word lengths seen wandb_metrics["stats/min_word_length"] = min(self.word_length_stats.keys()) wandb_metrics["stats/max_word_length"] = max(self.word_length_stats.keys()) # Log error rates if self.processed_item_count > 0: wandb_metrics["errors/answer_format_error_rate"] = ( self.answer_format_errors / (self.processed_item_count * self.config.group_size) ) wandb_metrics["errors/think_format_error_rate"] = ( self.think_format_errors / (self.processed_item_count * self.config.group_size) ) wandb_metrics["errors/total_format_errors"] = ( self.answer_format_errors + self.think_format_errors ) # Log data dumping progress if self.config.dump_rollouts: wandb_metrics["data_dumps/processed_item_count"] = self.processed_item_count wandb_metrics["data_dumps/rollouts_buffer_size"] = len( self.rollouts_to_save_buffer ) wandb_metrics["data_dumps/save_file_batch_num"] = self.save_file_batch_num wandb_metrics["data_dumps/batch_size"] = self.config.dump_batch_size # Add rollout table wandb_metrics = await self.create_rollout_table(wandb_metrics) await super().wandb_log(wandb_metrics) if __name__ == "__main__": LetterCountingEnv.cli()