mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
2177 lines
88 KiB
Python
2177 lines
88 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import uuid
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import wandb
|
|
from pydantic import Field
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
EvalHandlingEnum,
|
|
Item,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
# Import NLTK (words corpus + sentence tokenizer) and fetch any missing data.
# When NLTK itself is missing, both names are set to None so the setup
# methods can raise a clear ImportError later.
try:
    import nltk
    from nltk.corpus import words
    from nltk.tokenize import sent_tokenize

    # Touching the corpus raises LookupError when the "words" package has not
    # been downloaded yet.
    try:
        words.words()
    except LookupError:
        nltk.download("words")

    # Probe the sentence tokenizer the same way.
    try:
        sent_tokenize("Test sentence.")
    except LookupError:
        # Older NLTK releases load the tokenizer model from "punkt"; newer
        # releases look for "punkt_tab" instead. Request both so the recovery
        # works regardless of the installed NLTK version (an unknown package
        # id is reported by the downloader, it does not raise).
        nltk.download("punkt")
        nltk.download("punkt_tab")
except ImportError:
    print("Warning: NLTK not available. Please install with: pip install nltk")
    words = None
    sent_tokenize = None
|
|
|
|
# Import datasets for OpenWebText
|
|
try:
|
|
from datasets import load_dataset
|
|
except ImportError:
|
|
print(
|
|
"Warning: datasets library not available. Please install with: pip install datasets"
|
|
)
|
|
load_dataset = None
|
|
|
|
# Default system prompt: asks the model to reason inside <think> </think> tags
# before giving its final response. The fragments are joined with single
# spaces.
system_prompt = " ".join(
    (
        "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the",
        "problem and deliberate with yourself via systematic reasoning processes to help come to a correct",
        "solution prior to answering. You should enclose your thoughts and internal monologue inside <think>",
        "</think> tags, and then provide your solution or response to the problem.",
    )
)
|
|
|
|
|
|
class LetterCountingConfig(BaseEnvConfig):
    """Configuration class for Letter Counting Environment with custom parameters."""

    # --- Word pool -------------------------------------------------------
    min_word_length: int = Field(
        default=3, description="Minimum word length to include"
    )
    max_word_length: int = Field(
        default=30, description="Maximum word length to include"
    )
    train_test_split: float = Field(
        default=0.95, description="Ratio for train/test split"
    )

    # --- Which letters may be asked about ---------------------------------
    use_all_letters: bool = Field(
        default=True, description="Whether to use all 26 letters or custom set"
    )
    custom_letters: str = Field(
        default="aeiou", description="Custom letter set if use_all_letters=False"
    )
    present_letter_bias: float = Field(
        default=0.5,
        description="Probability of choosing letters that are present in the text (0.0-1.0)",
    )

    # --- Sampling parameters for the model ---------------------------------
    generation_temperature: float = Field(
        default=1.0, description="Temperature for training generation"
    )
    eval_temperature: float = Field(
        default=0.2, description="Temperature for evaluation generation"
    )
    max_generation_tokens: int = Field(
        default=15 * 1024, description="Maximum tokens for model generation"
    )

    # --- Evaluation ---------------------------------------------------------
    eval_sample_size: int = Field(
        default=1000, description="Number of test words to evaluate on"
    )

    # --- Reproducibility ----------------------------------------------------
    random_seed: Optional[int] = Field(
        default=42, description="Seed for reproducibility, None for random"
    )

    # --- Synthetic (random) strings mixed into the word pool ---------------
    random_string_percentage: float = Field(
        default=0.03,
        description="Percentage of dataset to be random strings (0.0-1.0)",
    )
    random_string_min_length: int = Field(
        default=3, description="Minimum length for random strings"
    )
    random_string_max_length: int = Field(
        default=25, description="Maximum length for random strings"
    )

    # --- Capitalization of real words ---------------------------------------
    uppercase_word_percentage: float = Field(
        default=0.01,
        description="Percentage of real words to make uppercase (0.0-1.0)",
    )
    capitalized_word_percentage: float = Field(
        default=0.01,
        description="Percentage of real words to capitalize first letter (0.0-1.0)",
    )

    # --- Text passages --------------------------------------------------------
    use_text_passages: bool = Field(
        default=False,
        description="Include text passages from OpenWebText in addition to words",
    )
    text_passage_percentage: float = Field(
        default=0.5,
        description="Percentage of dataset to be text passages when use_text_passages=True (0.0-1.0)",
    )
    min_text_length: int = Field(
        default=50, description="Minimum character length for text passages"
    )
    max_text_length: int = Field(
        default=500, description="Maximum character length for text passages"
    )
    include_punctuation_in_count: bool = Field(
        default=True, description="Include punctuation in letter counting"
    )
    include_spaces_in_count: bool = Field(
        default=False, description="Include spaces in letter counting"
    )

    # --- Counting several letters in one question ----------------------------
    max_letters_to_count: int = Field(
        default=1,
        description="Maximum number of different letters to count simultaneously (1 for single letter)",
    )
    multi_letter_probability: float = Field(
        default=0.2,
        description="Probability of asking for multiple letters (0.0-1.0)",
    )

    # --- Difficulty threshold for training groups ----------------------------
    max_group_average_for_training: float = Field(
        default=1.0,
        description="Maximum group average to use for training (skip groups that are too easy)",
    )

    # --- Logging and rollout dumps --------------------------------------------
    debug_logging: bool = Field(
        default=True,
        description="Enable debug-level logging for more verbose output",
    )
    suppress_base_env_logs: bool = Field(
        default=True, description="Suppress verbose base environment logs"
    )
    dump_rollouts: bool = Field(
        default=False,
        description="Whether to dump successful rollouts to JSONL files",
    )
    dump_batch_size: int = Field(
        default=100,
        description="Number of groups to accumulate before saving to disk (1 = save immediately)",
    )
|
|
|
|
|
|
class LetterCountingEnv(BaseEnv):
|
|
"""
|
|
Letter Counting Environment for training models to count letters in words and sentences.
|
|
|
|
This environment presents the model with questions like "How many 'a's are in the word 'banana'?"
|
|
or "Count the occurrences of the letters 'e', 'o', and 't' in the following text: 'The quick brown fox jumps over the lazy dog'"
|
|
and expects responses in the format <answer>3</answer> for single letters or <answer>{"e": 4, "o": 4, "t": 2}</answer>
|
|
for multiple letters. The model should use <think></think> tags for reasoning before providing the final answer.
|
|
|
|
Features:
|
|
- **Word Mode**: Uses NLTK's words corpus (236k+ English words)
|
|
- **Mixed Mode**: Combines words and text passages from OpenWebText-10k dataset
|
|
- **Text Passage Mode**: Uses OpenWebText-10k dataset with character-based text extraction
|
|
- Optional random string generation (80% alphabetical) mixed with real words
|
|
- Configurable word/string/passage length ranges and letter sets
|
|
- Optional word capitalization (uppercase, title case)
|
|
- **Multi-letter counting**: Configurable simultaneous counting of multiple letters with JSON responses
|
|
- **Letter selection bias**: Configurable bias toward letters present in the text (reduces zero-count questions)
|
|
- Training thresholds based on group average scores
|
|
- Configurable punctuation and space handling for letter counting
|
|
- Comprehensive logging and data dumping capabilities
|
|
- Detailed metrics tracking (letter distribution, text lengths, error rates, group average scores)
|
|
- Support for saving successful and failed rollouts for analysis
|
|
|
|
Data Dumping:
|
|
- Set dump_rollouts=True to save rollouts from groups with appropriate difficulty
|
|
- Only saves groups where group average score ≤ max_group_average_for_training
|
|
- Files saved to data_dumps/ directory with unique UUIDs
|
|
- Rollouts include full conversations, scores, metadata
|
|
- Configurable batch size (dump_batch_size) for efficient disk I/O
|
|
|
|
Mixed Mode Configuration:
|
|
- Set use_text_passages=True to enable mixed mode with both words and text passages
|
|
- Configure text_passage_percentage to control the ratio (e.g., 0.3 = 30% passages, 70% words)
|
|
- Configure min/max text passage character lengths (more reliable than word counts)
|
|
- Set max_group_average_for_training to skip groups that are too easy
|
|
|
|
Logging:
|
|
- Set debug_logging=True for verbose per-item scoring details
|
|
- Comprehensive WandB metrics including letter distribution entropy and group average scores
|
|
- Progress tracking for data dumps and evaluation
|
|
""" # noqa
|
|
|
|
name = "letter_counting"
|
|
env_config_cls = LetterCountingConfig
|
|
|
|
def __init__(
    self,
    config: LetterCountingConfig,
    server_configs: List[APIServerConfig],
    slurm=True,
    testing=False,
):
    """
    Initialize the Letter Counting environment.

    Args:
        config: Environment configuration (read back as ``self.config``).
        server_configs: List of server configurations for OpenAI API
        slurm: Forwarded unchanged to the base environment initializer.
        testing: Forwarded unchanged to the base environment initializer.
    """
    super().__init__(config, server_configs, slurm, testing)

    # Data dumping state is created first because the log lines below
    # report the run UUID.
    self.run_uuid = str(uuid.uuid4())
    self.rollouts_to_save_buffer: List[
        Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]]
    ] = []
    self.processed_item_count = 0
    # Dumps are written next to this source file.
    self.datadumps_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "data_dumps"
    )
    self.save_file_batch_num = 0

    # Additional metrics tracking
    self.letter_distribution_stats: Dict[str, int] = {}
    self.word_length_stats: Dict[int, int] = {}
    self.answer_format_errors = 0
    self.think_format_errors = 0

    # Logger named after the concrete class; only attach a stream handler
    # when none is configured yet so handlers are not duplicated.
    self.logger = logging.getLogger(self.__class__.__name__)
    if not self.logger.handlers:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
        )
        self.logger.addHandler(stream_handler)
    # The level is applied unconditionally so that ``debug_logging`` is
    # honoured even when the logger already had handlers attached.
    self.logger.setLevel(
        logging.DEBUG if self.config.debug_logging else logging.INFO
    )
    # Ensure the logger itself is enabled
    self.logger.disabled = False

    # Optionally raise the base environment logger to WARNING to hide its
    # INFO output.
    if self.config.suppress_base_env_logs:
        logging.getLogger("atroposlib.envs.base").setLevel(logging.WARNING)

    # Report the effective settings.
    debug_state = "enabled" if self.config.debug_logging else "disabled"
    dump_state = "enabled" if self.config.dump_rollouts else "disabled"
    self.logger.info(
        f"LetterCountingEnv initialized with run UUID: {self.run_uuid}"
    )
    self.logger.info(f"Debug logging: {debug_state}")
    self.logger.info(f"Data dumping: rollouts={dump_state}")

    # Buffers filled elsewhere in the class during training / evaluation.
    self.percent_correct_buffer = []
    self.eval_metrics = []
    self.rollouts_for_wandb: List[List[Tuple[str, float, str, str, str]]] = []
|
|
|
|
@classmethod
def config_init(cls) -> Tuple[LetterCountingConfig, List[APIServerConfig]]:
    """
    Build the default environment and inference-server configuration.

    Returns:
        A ``(env_config, server_configs)`` tuple: the ``LetterCountingConfig``
        defaults used by this environment and a single-entry list with the
        ``APIServerConfig`` of a local inference endpoint.
    """
    # The tokenizer and the served model are the same checkpoint.
    model_id = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
    # Shared by the trainer token limit and the generation limit.
    token_budget = 1024 * 15

    env_config = LetterCountingConfig(
        tokenizer_name=model_id,
        group_size=32,
        use_wandb=True,
        max_num_workers=128,
        rollout_server_url="http://localhost:8000",
        total_steps=250,
        batch_size=1024,
        steps_per_eval=20,
        max_token_length=token_budget,
        inference_weight=1.0,
        wandb_name="letter_counting_deep_thinking",
        data_path_to_save_groups=None,
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        eval_limit_ratio=0.1,
        # Letter counting specific configs
        min_word_length=3,
        max_word_length=30,
        train_test_split=0.95,
        eval_sample_size=1000,
        generation_temperature=1.0,
        eval_temperature=0.5,
        random_seed=42,
        use_all_letters=True,
        custom_letters="aeiou",
        present_letter_bias=0.5,
        max_generation_tokens=token_budget,
        # Random string generation
        random_string_percentage=0.0,
        random_string_min_length=3,
        random_string_max_length=15,
        # Word capitalization
        uppercase_word_percentage=0.01,
        capitalized_word_percentage=0.005,
        # Text passage configuration
        use_text_passages=True,
        text_passage_percentage=0.3,
        min_text_length=3,
        max_text_length=2000,
        include_punctuation_in_count=True,
        include_spaces_in_count=True,
        max_group_average_for_training=0.7,
        # Multi-letter counting
        max_letters_to_count=4,
        multi_letter_probability=0.2,
        debug_logging=True,
        dump_rollouts=True,
        dump_batch_size=100,
    )
    server_configs = [
        APIServerConfig(
            model_name=model_id,
            base_url="http://localhost:9004/v1",
            api_key="x",
            num_max_requests_at_once=32,
            num_requests_for_eval=256,
        )
    ]

    return env_config, server_configs
|
|
|
|
async def setup(self):
|
|
"""
|
|
Set up the environment by loading and preparing the word/text dataset.
|
|
"""
|
|
if self.config.use_text_passages:
|
|
await self._setup_mixed_dataset()
|
|
else:
|
|
await self._setup_word_dataset()
|
|
|
|
# Initialize iteration counter
|
|
self.iter = 0
|
|
|
|
async def _setup_mixed_dataset(self):
|
|
"""
|
|
Set up the environment using both words and text passages from OpenWebText dataset.
|
|
"""
|
|
if load_dataset is None:
|
|
raise ImportError(
|
|
"datasets library is required for text passage mode. Please install with: pip install datasets"
|
|
)
|
|
if words is None:
|
|
raise ImportError(
|
|
"NLTK is required for this environment. Please install with: pip install nltk"
|
|
)
|
|
|
|
# Set random seed for reproducibility if configured
|
|
if self.config.random_seed is not None:
|
|
random.seed(self.config.random_seed)
|
|
|
|
# Validate configuration
|
|
await self._validate_config()
|
|
|
|
self.logger.info("Setting up mixed dataset with words and text passages...")
|
|
|
|
# First, set up words (same as word-only mode)
|
|
all_words = words.words()
|
|
filtered_words = [
|
|
word.lower()
|
|
for word in all_words
|
|
if word.isalpha()
|
|
and self.config.min_word_length <= len(word) <= self.config.max_word_length
|
|
]
|
|
|
|
# Apply capitalization to real words if configured
|
|
filtered_words = self._apply_word_capitalization(filtered_words)
|
|
|
|
# Generate random strings if configured
|
|
if self.config.random_string_percentage > 0.0:
|
|
total_filtered_words = len(filtered_words)
|
|
if self.config.random_string_percentage >= 1.0:
|
|
num_random_strings = (
|
|
total_filtered_words if total_filtered_words > 0 else 1000
|
|
)
|
|
filtered_words = []
|
|
else:
|
|
num_random_strings = int(
|
|
(self.config.random_string_percentage * total_filtered_words)
|
|
/ (1.0 - self.config.random_string_percentage)
|
|
)
|
|
|
|
random_strings = self._generate_random_strings(num_random_strings)
|
|
all_word_strings = filtered_words + random_strings
|
|
self.logger.info(
|
|
f"Generated {num_random_strings} random strings ({self.config.random_string_percentage:.1%} of word dataset)" # noqa
|
|
)
|
|
else:
|
|
all_word_strings = filtered_words
|
|
random_strings = []
|
|
|
|
self.logger.info(
|
|
f"Prepared {len(all_word_strings)} word/string items ({len(filtered_words)} real words + {len(random_strings)} random strings)" # noqa
|
|
)
|
|
|
|
# Now load and process text passages
|
|
self.logger.info("Loading OpenWebText-10k dataset for text passages...")
|
|
dataset = load_dataset("stas/openwebtext-10k", split="train")
|
|
self.logger.info(f"Loaded {len(dataset)} text samples from OpenWebText")
|
|
|
|
# Extract and filter text passages
|
|
all_passages = []
|
|
processed_texts = 0
|
|
|
|
for item in dataset:
|
|
text = item["text"]
|
|
|
|
# Skip texts that are too long initially
|
|
if (
|
|
len(text) > self.config.max_text_length * 3
|
|
): # Allow some overhead for chunking
|
|
continue
|
|
|
|
# Extract passages from this text
|
|
passages = self._extract_text_passages(text)
|
|
all_passages.extend(passages)
|
|
processed_texts += 1
|
|
|
|
# Log progress periodically
|
|
if processed_texts % 1000 == 0:
|
|
self.logger.info(
|
|
f"Processed {processed_texts} texts, extracted {len(all_passages)} passages so far..."
|
|
)
|
|
|
|
self.logger.info(
|
|
f"Extracted {len(all_passages)} text passages from {processed_texts} texts"
|
|
)
|
|
|
|
# Now mix words and passages according to text_passage_percentage
|
|
if self.config.text_passage_percentage >= 1.0:
|
|
# Special case: 100% text passages, no words
|
|
all_mixed_items = all_passages
|
|
final_word_count = 0
|
|
final_passage_count = len(all_passages)
|
|
elif self.config.text_passage_percentage <= 0.0:
|
|
# Special case: 0% text passages, only words
|
|
all_mixed_items = all_word_strings
|
|
final_word_count = len(all_word_strings)
|
|
final_passage_count = 0
|
|
else:
|
|
# Calculate how many passages to include
|
|
# If we want X% passages, then passages / (words + passages) = X
|
|
# Solving: passages_to_use = X * words / (1 - X)
|
|
total_words = len(all_word_strings)
|
|
num_passages_to_use = int(
|
|
(self.config.text_passage_percentage * total_words)
|
|
/ (1.0 - self.config.text_passage_percentage)
|
|
)
|
|
|
|
# Upsample passages if we don't have enough to reach the target percentage
|
|
if num_passages_to_use > len(all_passages):
|
|
self.logger.info(
|
|
f"Upsampling passages: need {num_passages_to_use} but only have {len(all_passages)}. "
|
|
f"Will repeat passages to reach target percentage."
|
|
)
|
|
# Calculate how many times we need to repeat the passage list
|
|
repeat_factor = (num_passages_to_use // len(all_passages)) + 1
|
|
upsampled_passages = all_passages * repeat_factor
|
|
|
|
# Shuffle the upsampled passages to avoid patterns
|
|
random.shuffle(upsampled_passages)
|
|
|
|
# Take exactly the number we need
|
|
passages_to_use = upsampled_passages[:num_passages_to_use]
|
|
|
|
self.logger.info(
|
|
f"Upsampled {len(all_passages)} unique passages {repeat_factor}x times, "
|
|
f"then sampled {len(passages_to_use)} passages for the dataset."
|
|
)
|
|
else:
|
|
# We have enough passages, just sample what we need
|
|
passages_to_use = random.sample(all_passages, num_passages_to_use)
|
|
|
|
# Combine words and passages
|
|
all_mixed_items = all_word_strings + passages_to_use
|
|
final_word_count = len(all_word_strings)
|
|
final_passage_count = len(passages_to_use)
|
|
|
|
# Shuffle the mixed dataset
|
|
random.shuffle(all_mixed_items)
|
|
|
|
# Create train/test split
|
|
split_point = int(self.config.train_test_split * len(all_mixed_items))
|
|
self.train_words = all_mixed_items[
|
|
:split_point
|
|
] # Reusing train_words for mixed items
|
|
self.test_words = all_mixed_items[split_point:]
|
|
|
|
# Calculate actual percentages
|
|
actual_passage_percentage = (
|
|
final_passage_count / len(all_mixed_items)
|
|
if len(all_mixed_items) > 0
|
|
else 0.0
|
|
)
|
|
|
|
# Calculate unique passages for logging
|
|
unique_passages_used = (
|
|
len(set(passages_to_use)) if "passages_to_use" in locals() else 0
|
|
)
|
|
|
|
# Log dataset statistics
|
|
self.logger.info("Mixed dataset created:")
|
|
if unique_passages_used < final_passage_count:
|
|
# Upsampling occurred
|
|
self.logger.info(
|
|
f" Total items: {len(all_mixed_items)} ({final_word_count} words/strings + {final_passage_count} passages)" # noqa
|
|
)
|
|
self.logger.info(
|
|
f" Unique passages used: {unique_passages_used} (repeated to create {final_passage_count} total passage instances)" # noqa
|
|
)
|
|
else:
|
|
# No upsampling
|
|
self.logger.info(
|
|
f" Total items: {len(all_mixed_items)} ({final_word_count} words/strings + {final_passage_count} passages)" # noqa
|
|
)
|
|
self.logger.info(
|
|
f" Actual passage percentage: {actual_passage_percentage:.1%} (target: {self.config.text_passage_percentage:.1%})" # noqa
|
|
)
|
|
self.logger.info(f" Training items: {len(self.train_words)}")
|
|
self.logger.info(f" Test items: {len(self.test_words)}")
|
|
|
|
# Show examples of both types
|
|
word_examples = [item for item in self.train_words[:10] if len(item) <= 50][:3]
|
|
passage_examples = [
|
|
item[:100] + "..." for item in self.train_words[:50] if len(item) > 50
|
|
][:3]
|
|
|
|
if word_examples:
|
|
self.logger.info(f" Example words: {word_examples}")
|
|
if passage_examples:
|
|
self.logger.info(f" Example passages: {passage_examples}")
|
|
|
|
# If we have upsampled passages, show some statistics
|
|
if unique_passages_used < final_passage_count and unique_passages_used > 0:
|
|
avg_repetitions = final_passage_count / unique_passages_used
|
|
self.logger.info(f" Average passage repetitions: {avg_repetitions:.1f}x")
|
|
|
|
# Log configuration details
|
|
self.logger.info(
|
|
f"Word length range: {self.config.min_word_length}-{self.config.max_word_length}"
|
|
)
|
|
self.logger.info(
|
|
f"Passage length range: {self.config.min_text_length}-{self.config.max_text_length} characters"
|
|
)
|
|
self.logger.info(
|
|
f"Include punctuation: {self.config.include_punctuation_in_count}"
|
|
)
|
|
self.logger.info(f"Include spaces: {self.config.include_spaces_in_count}")
|
|
self.logger.info(
|
|
f"Training threshold: {self.config.max_group_average_for_training}"
|
|
)
|
|
|
|
self.logger.info("Mixed dataset setup complete")
|
|
|
|
async def _setup_text_passage_dataset(self):
|
|
"""
|
|
Set up the environment using text passages from OpenWebText dataset.
|
|
"""
|
|
if load_dataset is None:
|
|
raise ImportError(
|
|
"datasets library is required for text passage mode. Please install with: pip install datasets"
|
|
)
|
|
|
|
# Set random seed for reproducibility if configured
|
|
if self.config.random_seed is not None:
|
|
random.seed(self.config.random_seed)
|
|
|
|
# Validate configuration
|
|
await self._validate_config()
|
|
|
|
self.logger.info("Loading OpenWebText-10k dataset...")
|
|
|
|
# Load the dataset
|
|
dataset = load_dataset("stas/openwebtext-10k", split="train")
|
|
|
|
self.logger.info(f"Loaded {len(dataset)} text samples from OpenWebText")
|
|
|
|
# Extract and filter text passages
|
|
all_passages = []
|
|
processed_texts = 0
|
|
|
|
for item in dataset:
|
|
text = item["text"]
|
|
|
|
# Skip texts that are too long initially
|
|
if (
|
|
len(text) > self.config.max_text_length * 3
|
|
): # Allow some overhead for chunking
|
|
continue
|
|
|
|
# Extract passages from this text
|
|
passages = self._extract_text_passages(text)
|
|
all_passages.extend(passages)
|
|
processed_texts += 1
|
|
|
|
# Log progress periodically
|
|
if processed_texts % 1000 == 0:
|
|
self.logger.info(
|
|
f"Processed {processed_texts} texts, extracted {len(all_passages)} passages so far..."
|
|
)
|
|
|
|
self.logger.info(
|
|
f"Extracted {len(all_passages)} text passages from {processed_texts} texts"
|
|
)
|
|
|
|
# Shuffle passages for randomness
|
|
random.shuffle(all_passages)
|
|
|
|
# Create train/test split
|
|
split_point = int(self.config.train_test_split * len(all_passages))
|
|
self.train_words = all_passages[
|
|
:split_point
|
|
] # Reusing train_words for passages
|
|
self.test_words = all_passages[split_point:]
|
|
|
|
# Log dataset statistics
|
|
self.logger.info(f"Training passages: {len(self.train_words)}")
|
|
self.logger.info(f"Test passages: {len(self.test_words)}")
|
|
self.logger.info(
|
|
f"Example passages: {[p[:100] + '...' for p in self.train_words[:3]]}"
|
|
)
|
|
|
|
# Log configuration details
|
|
self.logger.info(
|
|
f"Passage length range: {self.config.min_text_length}-{self.config.max_text_length} characters"
|
|
)
|
|
self.logger.info(
|
|
f"Include punctuation: {self.config.include_punctuation_in_count}"
|
|
)
|
|
self.logger.info(f"Include spaces: {self.config.include_spaces_in_count}")
|
|
self.logger.info(
|
|
f"Training threshold: {self.config.max_group_average_for_training}"
|
|
)
|
|
|
|
self.logger.info("Text passage dataset setup complete")
|
|
|
|
async def _setup_word_dataset(self):
|
|
"""
|
|
Set up the environment using single words (original functionality).
|
|
"""
|
|
# Load the NLTK words corpus (contains 236,736 English words)
|
|
if words is None:
|
|
raise ImportError(
|
|
"NLTK is required for this environment. Please install with: pip install nltk"
|
|
)
|
|
|
|
# Set random seed for reproducibility if configured
|
|
if self.config.random_seed is not None:
|
|
random.seed(self.config.random_seed)
|
|
|
|
# Validate configuration
|
|
await self._validate_config()
|
|
|
|
# Get all English words from NLTK
|
|
all_words = words.words()
|
|
|
|
# Filter words to ensure they contain only alphabetic characters
|
|
# and are within the configured length range for reasonable difficulty
|
|
filtered_words = [
|
|
word.lower()
|
|
for word in all_words
|
|
if word.isalpha()
|
|
and self.config.min_word_length <= len(word) <= self.config.max_word_length
|
|
]
|
|
|
|
# Apply capitalization to real words if configured
|
|
filtered_words = self._apply_word_capitalization(filtered_words)
|
|
|
|
# Generate random strings if configured
|
|
if self.config.random_string_percentage > 0.0:
|
|
# Calculate how many random strings to generate
|
|
total_filtered_words = len(filtered_words)
|
|
# If we want X% random strings, then random_strings / (words + random_strings) = X
|
|
# Solving: random_strings = X * words / (1 - X)
|
|
if self.config.random_string_percentage >= 1.0:
|
|
# Special case: 100% random strings, no real words
|
|
num_random_strings = (
|
|
total_filtered_words if total_filtered_words > 0 else 1000
|
|
)
|
|
filtered_words = []
|
|
else:
|
|
num_random_strings = int(
|
|
(self.config.random_string_percentage * total_filtered_words)
|
|
/ (1.0 - self.config.random_string_percentage)
|
|
)
|
|
|
|
random_strings = self._generate_random_strings(num_random_strings)
|
|
|
|
# Combine real words and random strings
|
|
all_strings = filtered_words + random_strings
|
|
self.logger.info(
|
|
f"Generated {num_random_strings} random strings ({self.config.random_string_percentage:.1%} of dataset)"
|
|
)
|
|
else:
|
|
# No random strings, use only real words
|
|
all_strings = filtered_words
|
|
random_strings = []
|
|
|
|
# Shuffle all strings for randomness
|
|
random.shuffle(all_strings)
|
|
|
|
# Create train/test split using configured ratio
|
|
split_point = int(self.config.train_test_split * len(all_strings))
|
|
self.train_words = all_strings[:split_point]
|
|
self.test_words = all_strings[split_point:]
|
|
|
|
# Log dataset statistics
|
|
if self.config.random_string_percentage > 0.0:
|
|
self.logger.info(
|
|
f"Total dataset: {len(all_strings)} strings ({len(filtered_words)} real words + {len(random_strings)} random strings)" # noqa
|
|
)
|
|
else:
|
|
self.logger.info(f"Loaded {len(all_strings)} words total")
|
|
self.logger.info(f"Training strings: {len(self.train_words)}")
|
|
self.logger.info(f"Test strings: {len(self.test_words)}")
|
|
self.logger.info(f"Example strings: {self.train_words[:10]}")
|
|
|
|
# Log configuration details
|
|
self.logger.info(
|
|
f"Word length range: {self.config.min_word_length}-{self.config.max_word_length}"
|
|
)
|
|
if self.config.random_string_percentage > 0.0:
|
|
self.logger.info(
|
|
f"Random string length range: {self.config.random_string_min_length}-{self.config.random_string_max_length}" # noqa
|
|
)
|
|
self.logger.info(f"Train/test split: {self.config.train_test_split:.2%}")
|
|
self.logger.info(
|
|
f"Letter set: {'all 26 letters' if self.config.use_all_letters else f'custom ({self.config.custom_letters})'}" # noqa
|
|
)
|
|
self.logger.info(
|
|
f"Random strings: {self.config.random_string_percentage:.1%} of dataset"
|
|
)
|
|
|
|
# Log capitalization settings
|
|
if (
|
|
self.config.uppercase_word_percentage > 0.0
|
|
or self.config.capitalized_word_percentage > 0.0
|
|
):
|
|
self.logger.info(
|
|
f"Word capitalization: {self.config.uppercase_word_percentage:.1%} uppercase, {self.config.capitalized_word_percentage:.1%} title case" # noqa
|
|
)
|
|
|
|
self.logger.info(f"Random seed: {self.config.random_seed}")
|
|
|
|
# Log letter selection bias configuration
|
|
self.logger.info(
|
|
f"Letter selection bias: {self.config.present_letter_bias:.1%} toward letters present in text"
|
|
)
|
|
|
|
# Log data dumping configuration
|
|
if self.config.dump_rollouts:
|
|
self.logger.info(
|
|
"Data dumping enabled - saving groups with appropriate difficulty for training"
|
|
)
|
|
self.logger.info(
|
|
f"Group difficulty threshold: group average ≤ {self.config.max_group_average_for_training}"
|
|
)
|
|
self.logger.info(f"Data dumps directory: {self.datadumps_dir}")
|
|
self.logger.info(
|
|
f"Batch size: {self.config.dump_batch_size} groups per file"
|
|
)
|
|
|
|
self.logger.info("Letter counting environment setup complete")
|
|
|
|
async def _save_rollouts_to_jsonl(self):
|
|
"""Saves the buffered rollouts to a JSONL file in the datadumps directory."""
|
|
if not self.rollouts_to_save_buffer:
|
|
self.logger.warning("_save_rollouts_to_jsonl called but buffer is empty!")
|
|
return
|
|
|
|
buffer_size = len(self.rollouts_to_save_buffer)
|
|
self.logger.info(f"Starting save of {buffer_size} groups to JSONL file...")
|
|
|
|
try:
|
|
if not os.path.exists(self.datadumps_dir):
|
|
os.makedirs(self.datadumps_dir)
|
|
self.logger.info(f"Created directory: {self.datadumps_dir}")
|
|
except OSError as e:
|
|
self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}")
|
|
return
|
|
|
|
file_path = os.path.join(
|
|
self.datadumps_dir,
|
|
f"letter_counting_environment_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
|
|
)
|
|
|
|
try:
|
|
with open(file_path, "w") as f:
|
|
for rollout_dict in self.rollouts_to_save_buffer:
|
|
json.dump(rollout_dict, f)
|
|
f.write("\n")
|
|
self.logger.info(
|
|
f"Successfully saved {buffer_size} groups to {file_path} "
|
|
f"(batch #{self.save_file_batch_num})"
|
|
)
|
|
self.rollouts_to_save_buffer.clear()
|
|
self.save_file_batch_num += 1
|
|
self.logger.info(
|
|
f"Buffer cleared. Next batch will be #{self.save_file_batch_num}"
|
|
)
|
|
except IOError as e:
|
|
self.logger.error(f"Error writing rollouts to {file_path}: {e}")
|
|
except Exception as e:
|
|
self.logger.error(
|
|
f"An unexpected error occurred while saving rollouts to {file_path}: {e}"
|
|
)
|
|
|
|
def _get_letter_set(self):
|
|
"""
|
|
Get the set of letters to choose from based on configuration.
|
|
|
|
Returns:
|
|
String containing letters to choose from (lowercase only for consistency)
|
|
"""
|
|
if not self.config.use_all_letters:
|
|
return self.config.custom_letters.lower()
|
|
else:
|
|
return "abcdefghijklmnopqrstuvwxyz"
|
|
|
|
def _select_target_letters(self, text: str, num_letters: int) -> List[str]:
|
|
"""
|
|
Select target letters with bias toward letters present in the text.
|
|
|
|
Args:
|
|
text: The text to analyze for present letters
|
|
num_letters: Number of letters to select
|
|
|
|
Returns:
|
|
List of selected target letters
|
|
"""
|
|
available_letters = list(self._get_letter_set())
|
|
|
|
# Prepare text for counting to match the actual counting logic
|
|
text_for_counting = self._prepare_text_for_counting(text)
|
|
|
|
# Find letters that are present in the text (case-insensitive)
|
|
present_letters = []
|
|
absent_letters = []
|
|
|
|
for letter in available_letters:
|
|
if letter.lower() in text_for_counting.lower():
|
|
present_letters.append(letter)
|
|
else:
|
|
absent_letters.append(letter)
|
|
|
|
# If we need more letters than are present, we'll need some absent ones too
|
|
if num_letters > len(present_letters):
|
|
# Select all present letters and fill the rest randomly from absent letters
|
|
selected_letters = present_letters.copy()
|
|
remaining_needed = num_letters - len(present_letters)
|
|
if remaining_needed > 0 and absent_letters:
|
|
selected_letters.extend(
|
|
random.sample(
|
|
absent_letters, min(remaining_needed, len(absent_letters))
|
|
)
|
|
)
|
|
return selected_letters[:num_letters]
|
|
|
|
# We have enough present letters, so apply the bias
|
|
selected_letters = []
|
|
|
|
for _ in range(num_letters):
|
|
# Decide whether to pick from present or absent letters based on bias
|
|
if present_letters and (
|
|
not absent_letters or random.random() < self.config.present_letter_bias
|
|
):
|
|
# Choose from present letters
|
|
chosen_letter = random.choice(present_letters)
|
|
present_letters.remove(chosen_letter)
|
|
elif absent_letters:
|
|
# Choose from absent letters
|
|
chosen_letter = random.choice(absent_letters)
|
|
absent_letters.remove(chosen_letter)
|
|
else:
|
|
# Fallback: choose from any remaining available letters
|
|
remaining_letters = [
|
|
l for l in available_letters if l not in selected_letters # noqa
|
|
]
|
|
if remaining_letters:
|
|
chosen_letter = random.choice(remaining_letters)
|
|
else:
|
|
break # No more letters available
|
|
|
|
selected_letters.append(chosen_letter)
|
|
|
|
return selected_letters
|
|
|
|
def _generate_random_string(self, length: int) -> str:
|
|
"""
|
|
Generate a random string of specified length with at least 80% alphabetical characters.
|
|
|
|
Args:
|
|
length: Length of the string to generate
|
|
|
|
Returns:
|
|
Random string with mix of uppercase, lowercase, and some non-alphabetical chars
|
|
"""
|
|
# Ensure at least 80% alphabetical characters
|
|
min_alpha_chars = max(1, int(length * 0.8))
|
|
|
|
# Character sets
|
|
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
|
non_alpha = "0123456789!@#$%^&*()-_=+[]{}|;:,.<>?"
|
|
|
|
result = []
|
|
|
|
# First, add the required alphabetical characters
|
|
for _ in range(min_alpha_chars):
|
|
result.append(random.choice(letters))
|
|
|
|
# Fill the rest with either alphabetical or non-alphabetical characters
|
|
remaining_length = length - min_alpha_chars
|
|
for _ in range(remaining_length):
|
|
# 90% chance of alphabetical even for remaining chars (to exceed 80% minimum)
|
|
if random.random() < 0.9:
|
|
result.append(random.choice(letters))
|
|
else:
|
|
result.append(random.choice(non_alpha))
|
|
|
|
# Shuffle to avoid alphabetical chars being clustered at the beginning
|
|
random.shuffle(result)
|
|
return "".join(result)
|
|
|
|
def _generate_random_strings(self, count: int) -> List[str]:
|
|
"""
|
|
Generate a list of random strings with lengths within the configured random string range.
|
|
|
|
Args:
|
|
count: Number of random strings to generate
|
|
|
|
Returns:
|
|
List of random strings
|
|
"""
|
|
random_strings = []
|
|
for _ in range(count):
|
|
# Generate random length within configured random string range
|
|
length = random.randint(
|
|
self.config.random_string_min_length,
|
|
self.config.random_string_max_length,
|
|
)
|
|
random_string = self._generate_random_string(length)
|
|
random_strings.append(random_string)
|
|
return random_strings
|
|
|
|
def _apply_word_capitalization(self, words: List[str]) -> List[str]:
|
|
"""
|
|
Apply capitalization transformations to real words based on configuration.
|
|
|
|
Args:
|
|
words: List of lowercase words
|
|
|
|
Returns:
|
|
List of words with applied capitalization
|
|
"""
|
|
if (
|
|
self.config.uppercase_word_percentage == 0.0
|
|
and self.config.capitalized_word_percentage == 0.0
|
|
):
|
|
return words
|
|
|
|
result = []
|
|
for word in words:
|
|
rand_val = random.random()
|
|
|
|
if rand_val < self.config.uppercase_word_percentage:
|
|
# Make uppercase
|
|
result.append(word.upper())
|
|
elif (
|
|
rand_val
|
|
< self.config.uppercase_word_percentage
|
|
+ self.config.capitalized_word_percentage
|
|
):
|
|
# Capitalize first letter
|
|
result.append(word.capitalize())
|
|
else:
|
|
# Keep lowercase
|
|
result.append(word)
|
|
|
|
return result
|
|
|
|
def _extract_text_passages(self, text: str) -> List[str]:
|
|
"""
|
|
Extract text passages from raw text based on character length.
|
|
|
|
Args:
|
|
text: Raw text to extract passages from
|
|
|
|
Returns:
|
|
List of filtered text passages
|
|
"""
|
|
# Clean the text - remove excessive whitespace and normalize
|
|
text = re.sub(r"\s+", " ", text.strip())
|
|
|
|
# If text is shorter than min length, skip it
|
|
if len(text) < self.config.min_text_length:
|
|
return []
|
|
|
|
# If text is within range, use it as-is
|
|
if len(text) <= self.config.max_text_length:
|
|
return [text]
|
|
|
|
# For longer texts, create overlapping chunks
|
|
passages = []
|
|
chunk_size = self.config.max_text_length
|
|
overlap = min(50, chunk_size // 4) # 25% overlap, max 50 chars
|
|
|
|
start = 0
|
|
while start < len(text):
|
|
end = start + chunk_size
|
|
|
|
# If this would be the last chunk and it's too small, extend the previous chunk
|
|
if end >= len(text):
|
|
if len(text) - start >= self.config.min_text_length:
|
|
passages.append(text[start:])
|
|
break
|
|
|
|
# Try to break at a natural boundary (space, punctuation)
|
|
chunk = text[start:end]
|
|
|
|
# Look for a good break point in the last 20% of the chunk
|
|
break_start = int(len(chunk) * 0.8)
|
|
break_candidates = []
|
|
|
|
# Find sentence endings first
|
|
for i in range(len(chunk) - 1, break_start - 1, -1):
|
|
if chunk[i] in ".!?":
|
|
break_candidates.append(i + 1)
|
|
break
|
|
|
|
# If no sentence ending, look for other punctuation
|
|
if not break_candidates:
|
|
for i in range(len(chunk) - 1, break_start - 1, -1):
|
|
if chunk[i] in ",;:":
|
|
break_candidates.append(i + 1)
|
|
break
|
|
|
|
# If no punctuation, look for spaces
|
|
if not break_candidates:
|
|
for i in range(len(chunk) - 1, break_start - 1, -1):
|
|
if chunk[i] == " ":
|
|
break_candidates.append(i)
|
|
break
|
|
|
|
# Use the break point if found, otherwise use the full chunk
|
|
if break_candidates:
|
|
actual_end = start + break_candidates[0]
|
|
passage = text[start:actual_end].strip()
|
|
else:
|
|
passage = chunk.strip()
|
|
actual_end = end
|
|
|
|
# Only add if it meets minimum length
|
|
if len(passage) >= self.config.min_text_length:
|
|
passages.append(passage)
|
|
|
|
# Move start position with overlap
|
|
start = actual_end - overlap
|
|
|
|
# Avoid infinite loops
|
|
if start >= end - overlap:
|
|
start = end
|
|
|
|
return passages
|
|
|
|
def _prepare_text_for_counting(self, text: str) -> str:
|
|
"""
|
|
Prepare text for letter counting based on configuration.
|
|
|
|
Args:
|
|
text: Original text
|
|
|
|
Returns:
|
|
Processed text for counting
|
|
"""
|
|
if not self.config.include_punctuation_in_count:
|
|
# Remove punctuation but keep spaces and alphanumeric
|
|
text = "".join(c for c in text if c.isalnum() or c.isspace())
|
|
|
|
if not self.config.include_spaces_in_count:
|
|
# Remove spaces
|
|
text = text.replace(" ", "")
|
|
|
|
return text
|
|
|
|
async def _validate_config(self):
|
|
"""Validate configuration parameters."""
|
|
if not (0.0 <= self.config.random_string_percentage <= 1.0):
|
|
raise ValueError(
|
|
f"random_string_percentage must be between 0.0 and 1.0, got {self.config.random_string_percentage}"
|
|
) # noqa
|
|
if not (0.0 <= self.config.uppercase_word_percentage <= 1.0):
|
|
raise ValueError(
|
|
f"uppercase_word_percentage must be between 0.0 and 1.0, got {self.config.uppercase_word_percentage}"
|
|
) # noqa
|
|
if not (0.0 <= self.config.capitalized_word_percentage <= 1.0):
|
|
raise ValueError(
|
|
f"capitalized_word_percentage must be between 0.0 and 1.0, got {self.config.capitalized_word_percentage}" # noqa
|
|
) # noqa
|
|
if (
|
|
self.config.uppercase_word_percentage
|
|
+ self.config.capitalized_word_percentage
|
|
> 1.0
|
|
):
|
|
raise ValueError(
|
|
f"Sum of uppercase_word_percentage ({self.config.uppercase_word_percentage}) and capitalized_word_percentage ({self.config.capitalized_word_percentage}) cannot exceed 1.0" # noqa
|
|
) # noqa
|
|
if self.config.random_string_min_length < 1:
|
|
raise ValueError(
|
|
f"random_string_min_length must be at least 1, got {self.config.random_string_min_length}"
|
|
) # noqa
|
|
if self.config.random_string_max_length < self.config.random_string_min_length:
|
|
raise ValueError(
|
|
f"random_string_max_length ({self.config.random_string_max_length}) must be >= random_string_min_length ({self.config.random_string_min_length})" # noqa
|
|
) # noqa
|
|
if self.config.use_text_passages:
|
|
if not (0.0 <= self.config.text_passage_percentage <= 1.0):
|
|
raise ValueError(
|
|
f"text_passage_percentage must be between 0.0 and 1.0, got {self.config.text_passage_percentage}" # noqa
|
|
)
|
|
if self.config.min_text_length < 1:
|
|
raise ValueError(
|
|
f"min_text_length must be at least 1, got {self.config.min_text_length}" # noqa
|
|
)
|
|
if self.config.max_text_length < self.config.min_text_length:
|
|
raise ValueError(
|
|
f"max_text_length ({self.config.max_text_length}) must be >= min_text_length ({self.config.min_text_length})" # noqa
|
|
)
|
|
if self.config.max_text_length < 10:
|
|
raise ValueError(
|
|
f"max_text_length must be at least 10 characters, got {self.config.max_text_length}" # noqa
|
|
)
|
|
if self.config.max_letters_to_count < 1:
|
|
raise ValueError(
|
|
f"max_letters_to_count must be at least 1, got {self.config.max_letters_to_count}"
|
|
)
|
|
if not (0.0 <= self.config.multi_letter_probability <= 1.0):
|
|
raise ValueError(
|
|
f"multi_letter_probability must be between 0.0 and 1.0, got {self.config.multi_letter_probability}"
|
|
)
|
|
if self.config.max_letters_to_count > 26:
|
|
raise ValueError(
|
|
f"max_letters_to_count cannot exceed 26 (total letters), got {self.config.max_letters_to_count}"
|
|
)
|
|
if not (0.0 <= self.config.present_letter_bias <= 1.0):
|
|
raise ValueError(
|
|
f"present_letter_bias must be between 0.0 and 1.0, got {self.config.present_letter_bias}"
|
|
)
|
|
if self.config.dump_batch_size < 1:
|
|
raise ValueError(
|
|
f"dump_batch_size must be at least 1, got {self.config.dump_batch_size}"
|
|
)
|
|
|
|
def save_checkpoint(self, step, data=None):
    """
    Add this environment's counters and statistics to the checkpoint payload
    and delegate the actual saving to the base class.

    Args:
        step: Checkpoint step, passed through to the base class
        data: Optional dict to extend; it is updated in place. A new dict is
            created when omitted.
    """
    payload = {} if data is None else data
    payload.update(
        {
            "iter": self.iter,
            "processed_item_count": self.processed_item_count,
            "save_file_batch_num": self.save_file_batch_num,
            "letter_distribution_stats": self.letter_distribution_stats,
            "word_length_stats": self.word_length_stats,
            "answer_format_errors": self.answer_format_errors,
            "think_format_errors": self.think_format_errors,
        }
    )
    super().save_checkpoint(step, payload)
|
|
|
|
def load_checkpoint(self):
    """
    Load checkpoint state through the base class and log what was restored.

    This method does not restore anything itself; it only reports on the
    attributes present after the base-class load. Presumably the base class
    sets the saved values as attributes on self - verify against BaseEnv.
    """
    super().load_checkpoint()

    # Without an iteration counter there is nothing to report
    if not hasattr(self, "iter"):
        return

    letter_stats = getattr(self, "letter_distribution_stats", None)
    if letter_stats:
        total_letters = sum(letter_stats.values())
        self.logger.info(
            f"Restored letter distribution stats with {total_letters} total letters"
        )

    length_stats = getattr(self, "word_length_stats", None)
    if length_stats:
        total_words = sum(length_stats.values())
        self.logger.info(
            f"Restored word length stats with {total_words} total words"
        )

    # The message reads both counters, so both must exist; checking only
    # answer_format_errors could raise AttributeError on think_format_errors.
    if hasattr(self, "answer_format_errors") and hasattr(
        self, "think_format_errors"
    ):
        self.logger.info(
            f"Restored error counts: {self.answer_format_errors} answer format errors, {self.think_format_errors} think format errors"  # noqa
        )
|
|
|
|
async def close(self):
    """
    Flush any buffered rollouts to disk, then close the base environment.

    When rollout dumping is enabled and the buffer is non-empty, the remaining
    groups are written via _save_rollouts_to_jsonl. The base class close() is
    awaited afterwards if the base class provides one.
    """
    log = self.logger
    log.info(
        "Closing LetterCountingEnv. Attempting to save any remaining rollouts..."
    )

    if not self.config.dump_rollouts:
        log.info("Data dumping is disabled - no rollouts to save.")
    elif not self.rollouts_to_save_buffer:
        log.info("No rollouts in buffer to save upon closing.")
    else:
        log.info(
            f"FINAL SAVE: Found {len(self.rollouts_to_save_buffer)} groups in buffer. "
            f"Saving final batch to disk (batch #{self.save_file_batch_num})..."
        )
        await self._save_rollouts_to_jsonl()
        log.info("Final save completed successfully.")

    # Call the superclass's close method if it exists
    if hasattr(super(), "close"):
        await super().close()

    log.info("LetterCountingEnv closed.")
|
|
|
|
async def get_next_item(self):
    """
    Build the next training item from the training set.

    Picks the text at position self.iter (wrapping around), selects one or
    several target letters, counts them case-insensitively in the prepared
    text, and formats the question together with the answer-format instruction.

    Returns:
        A tuple (prompt, correct_counts, text, target_letters) where prompt is
        a tuple of frozensets holding the system and user messages,
        correct_counts maps each target letter to its count, text is the
        selected item, and target_letters is the list of letters asked about
    """
    # Could be a word, sentence, or random string
    text = self.train_words[self.iter % len(self.train_words)]

    # One letter by default; sometimes 2..max_letters_to_count letters
    wants_multiple = (
        self.config.max_letters_to_count > 1
        and random.random() < self.config.multi_letter_probability
    )
    num_letters = (
        random.randint(2, self.config.max_letters_to_count) if wants_multiple else 1
    )
    target_letters = self._select_target_letters(text, num_letters)

    # Case-insensitive counts on the text with punctuation/spaces handled per config
    lowered_text = self._prepare_text_for_counting(text).lower()
    correct_counts = {
        letter: lowered_text.count(letter.lower()) for letter in target_letters
    }

    # Heuristic: long, spaced, or punctuated items are phrased as "text"
    is_text_passage = (
        len(text) > 50 or " " in text or any(c in text for c in ".,!?;:")
    )

    # Log item selection details, including present/absent split for bias analysis
    text_type = "passage" if is_text_passage else "word/string"
    text_preview = text[:50] + "..." if len(text) > 50 else text
    letters_str = ", ".join(target_letters)
    counts_str = ", ".join(
        f"{letter}:{correct_counts[letter]}" for letter in target_letters
    )
    present_count = sum(1 for letter in target_letters if correct_counts[letter] > 0)
    absent_count = len(target_letters) - present_count
    bias_info = f"Present: {present_count}, Absent: {absent_count}"
    self.logger.info(
        f"Selected {text_type}: '{text_preview}' | Letters: [{letters_str}] | Counts: [{counts_str}] | {bias_info} (iteration {self.iter})"  # noqa
    )

    self.iter += 1

    if len(target_letters) == 1:
        # Single letter: the answer is a bare number
        target_letter = target_letters[0]
        if is_text_passage:
            question_text = (
                f'How many {target_letter}s are in the following text: "{text}"?'
            )
        else:
            question_text = f"How many {target_letter}s are in the string {text}?"
        question_with_instruction = f"{question_text}\n\nProvide your answer in the format: <answer>{{number}}</answer>"  # noqa
    else:
        # Multiple letters: the answer is a JSON object keyed by letter
        quoted_letters = (
            ", ".join(f"'{letter}'" for letter in target_letters[:-1])
            + f", and '{target_letters[-1]}'"
        )
        if is_text_passage:
            question_text = f'Count the occurrences of the letters {quoted_letters} in the following text: "{text}"'  # noqa
        else:
            question_text = f"Count the occurrences of the letters {quoted_letters} in the string {text}"  # noqa
        example_json = (
            "{" + ", ".join(f'"{letter}": 0' for letter in target_letters) + "}"
        )
        question_with_instruction = f"{question_text}\n\nProvide your answer as JSON in the format: <answer>{example_json}</answer>"  # noqa

    # Messages are stored as frozensets of (key, value) pairs
    prompt = tuple(
        frozenset(message.items())
        for message in (
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question_with_instruction},
        )
    )

    return (prompt, correct_counts, text, target_letters)
|
|
|
|
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
    """
    Sample a group of completions for one item, score them, and optionally
    buffer the group for dumping to disk.

    Args:
        item: Tuple of (prompt messages as frozensets, correct counts, text,
            target letters)

    Returns:
        Tuple of (whatever self.score returns for the group, empty backlog list)
    """
    # Render the frozenset-encoded messages into a single prompt string
    prompt_messages = [dict(role_dict) for role_dict in item[0]]
    prompt = self.tokenizer.apply_chat_template(
        prompt_messages, add_generation_prompt=True, tokenize=False
    )

    completions = await self.server.completion(
        prompt=prompt,
        n=self.config.group_size,
        max_tokens=self.config.max_generation_tokens,
        temperature=self.config.generation_temperature,
    )

    # One entry per completion:
    # (conversation, correct_counts, text, target_letters, stop reason)
    to_score = []
    for completion_choice in completions.choices:
        conversation = [dict(role_dict) for role_dict in item[0]]
        conversation.append({"role": "assistant", "content": completion_choice.text})
        to_score.append(
            (
                tuple(conversation),
                item[1],
                item[2],
                item[3],
                completion_choice.finish_reason,
            )
        )

    scored_data = await self.score(to_score)

    if not self.config.dump_rollouts:
        return scored_data, []

    # Data dumping scores every rollout again on its own, independent of the
    # filtering self.score applies for training.
    def grade_rollout(rollout):
        model_response = rollout[0][-1]["content"]
        expected_counts = rollout[1]
        letters = rollout[3]

        # Handle legacy format (single letter as str / single count as int)
        if isinstance(letters, str):
            letters = [letters]
            expected_counts = {letters[0]: expected_counts}
        elif isinstance(expected_counts, int):
            expected_counts = {letters[0]: expected_counts}

        if rollout[4] == "length":
            return 0.0

        expected_format = "single" if len(letters) == 1 else "multi"
        model_answer = self._extract_answer(model_response, expected_format)
        if model_answer is None:
            return 0.0

        if expected_format == "single":
            return 1.0 if model_answer == expected_counts[letters[0]] else 0.0

        is_exact = set(model_answer.keys()) == set(letters) and all(
            model_answer.get(letter, -1) == expected_counts[letter]
            for letter in letters
        )
        return 1.0 if is_exact else 0.0

    temp_scores = [grade_rollout(rollout) for rollout in to_score]
    group_average_score = sum(temp_scores) / len(temp_scores) if temp_scores else 0.0
    threshold = self.config.max_group_average_for_training

    if temp_scores and all(temp_scores[0] == score for score in temp_scores):
        # Identical scores (including all 0.0) carry no learning signal
        self.logger.debug(
            f"Skipping group save - all scores identical ({temp_scores[0]:.3f}) - no learning signal"
        )
    elif group_average_score <= threshold:
        self.logger.debug(
            f"Saving group with scores: {[f'{s:.3f}' for s in temp_scores]} (group_avg: {group_average_score:.3f} <= {threshold})"  # noqa
        )
        rollouts_with_scores_to_save = [
            {
                "conversation": rollout[0],
                "score": rollout_score,
                "expected_counts": rollout[1],
                "text": rollout[2],
                "target_letters": rollout[3],
                "stop_reason": rollout[4],
                "group_average_score": group_average_score,
            }
            for rollout, rollout_score in zip(to_score, temp_scores)
        ]

        if rollouts_with_scores_to_save:
            # Item id: text preview, letters and counts joined by underscores
            correct_counts = item[1]
            text = item[2]
            target_letters = item[3]
            text_preview = (text[:30] if len(text) > 30 else text).replace(" ", "_")
            letters_str = "_".join(target_letters)
            counts_str = "_".join(
                str(correct_counts.get(letter, 0)) for letter in target_letters
            )
            item_id = f"{text_preview}_{letters_str}_{counts_str}"

            self.rollouts_to_save_buffer.append(
                {"item_id": item_id, "rollouts": rollouts_with_scores_to_save}
            )
            self.processed_item_count += 1

            # Log every single sample added to buffer
            self.logger.info(
                f"BUFFER ADD: Added item '{item_id}' to buffer. "
                f"Buffer now contains {len(self.rollouts_to_save_buffer)} groups "
                f"(target batch size: {self.config.dump_batch_size})"
            )

            # Log progress every 10 items
            if self.processed_item_count % 10 == 0:
                self.logger.info(
                    f"Data dump progress: {self.processed_item_count} items processed "
                    f"(Buffer size: {len(self.rollouts_to_save_buffer)})"
                )

            buffered_groups = len(self.rollouts_to_save_buffer)
            if buffered_groups >= self.config.dump_batch_size:
                # Regular batch save
                self.logger.info(
                    f"Buffer reached batch size ({buffered_groups}/{self.config.dump_batch_size}). "  # noqa
                    f"Saving batch to disk..."
                )
                await self._save_rollouts_to_jsonl()
            elif self.processed_item_count % 50 == 0 and buffered_groups > 0:
                # Safety mechanism: save every 50 items processed to prevent data loss
                self.logger.info(
                    f"Safety save: {self.processed_item_count} items processed. "
                    f"Saving {buffered_groups} groups to prevent data loss..."
                )
                await self._save_rollouts_to_jsonl()
    else:
        self.logger.debug(
            f"Skipping group save - group too easy for training (group_avg: {group_average_score:.3f} > {threshold})"  # noqa
        )

    return scored_data, []
|
|
|
|
def _extract_answer(self, text, expected_format="single"):
|
|
"""
|
|
Extract the answer from model response (single number or JSON).
|
|
Only allows one valid answer format - multiple answer formats result in a score of 0.
|
|
|
|
Args:
|
|
text: Text containing the model's response
|
|
expected_format: "single" for number, "multi" for JSON
|
|
|
|
Returns:
|
|
Extracted answer (int for single, dict for multi) or None if invalid
|
|
"""
|
|
# Check for multiple <think> tags - score as 0 if found
|
|
think_tags = re.findall(r"<think>", text, re.IGNORECASE)
|
|
if len(think_tags) > 1:
|
|
return None
|
|
|
|
# Check if the think tag is properly opened - we need exactly one opening tag
|
|
if len(think_tags) != 1:
|
|
return None
|
|
|
|
# Check for </think> closing tags
|
|
think_close_tags = re.findall(r"</think>", text, re.IGNORECASE)
|
|
if len(think_close_tags) != 1:
|
|
return None # Must have exactly one closing tag
|
|
|
|
# Split the text into thinking and answer sections
|
|
parts = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
|
|
|
|
# If there's no </think> tag or multiple sections, return None
|
|
if len(parts) != 2:
|
|
return None
|
|
|
|
thinking_section, answer_section = parts
|
|
|
|
# Validate thinking section
|
|
# Make sure thinking section actually contains the opening <think> tag
|
|
if "<think>" not in thinking_section.lower():
|
|
return None # Malformed thinking section
|
|
|
|
# Check if there are any <think> tags in the answer section (after the first </think>)
|
|
if "<think>" in answer_section.lower():
|
|
return None
|
|
|
|
# Look for answer tags in the answer section
|
|
if expected_format == "single":
|
|
# Single number format
|
|
answer_pattern = r"<answer>\s*(\d+)\s*</answer>"
|
|
answer_matches = re.findall(answer_pattern, answer_section, re.IGNORECASE)
|
|
|
|
# If no answers found or multiple answers found, return None
|
|
if len(answer_matches) != 1:
|
|
return None
|
|
|
|
# Return the single found answer as an integer
|
|
try:
|
|
return int(answer_matches[0])
|
|
except ValueError:
|
|
return None
|
|
else:
|
|
# Multi-letter JSON format
|
|
answer_pattern = r"<answer>\s*(\{[^}]+\})\s*</answer>"
|
|
answer_matches = re.findall(answer_pattern, answer_section, re.IGNORECASE)
|
|
|
|
# If no answers found or multiple answers found, return None
|
|
if len(answer_matches) != 1:
|
|
return None
|
|
|
|
# Try to parse the JSON
|
|
try:
|
|
import json
|
|
|
|
answer_dict = json.loads(answer_matches[0])
|
|
|
|
# Validate that all values are integers
|
|
if not isinstance(answer_dict, dict):
|
|
return None
|
|
|
|
for key, value in answer_dict.items():
|
|
if not isinstance(key, str) or not isinstance(value, int):
|
|
return None
|
|
|
|
return answer_dict
|
|
except (json.JSONDecodeError, ValueError):
|
|
return None
|
|
|
|
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
    """
    Score a group of model responses against the expected letter counts.

    Each rollout gets a binary score (1.0 for an exactly correct,
    well-formatted answer, 0.0 otherwise). The input list is shuffled in
    place before scoring.

    Args:
        rollout_group_data: List of (conversation, correct counts, text,
            target letters, stop reason) tuples for one item

    Returns:
        ScoredDataGroup with tokens, masks and scores, or None when the input
        is empty, no rollout could be kept, the group average exceeds
        config.max_group_average_for_training, or all scores are identical
    """
    scores = ScoredDataGroup()
    scores["tokens"] = []
    scores["masks"] = []
    scores["scores"] = []

    if not rollout_group_data:
        return None

    # Every rollout in the group shares the first entry's expected answer
    reference = rollout_group_data[0]
    expected_counts = reference[1]  # dict for multi, int for single (legacy)
    text = reference[2]
    target_letters = reference[3]

    # Handle legacy format (single letter as string, single count as int)
    if isinstance(target_letters, str):
        target_letters = [target_letters]
        expected_counts = {target_letters[0]: expected_counts}
    elif isinstance(expected_counts, int):
        expected_counts = {target_letters[0]: expected_counts}

    # Track how often each letter and each text length is seen
    for target_letter in target_letters:
        if target_letter not in self.letter_distribution_stats:
            self.letter_distribution_stats[target_letter] = 0
        self.letter_distribution_stats[target_letter] += 1

    text_len = len(text)
    if text_len not in self.word_length_stats:
        self.word_length_stats[text_len] = 0
    self.word_length_stats[text_len] += 1

    # Shuffle to avoid bias in selection
    random.shuffle(rollout_group_data)

    expected_format = "single" if len(target_letters) == 1 else "multi"
    verbose = self.config.debug_logging
    format_errors_in_group = 0
    think_errors_in_group = 0

    for rollout in rollout_group_data:
        model_response = rollout[0][-1]["content"]
        stop_reason = rollout[4]
        reward = 0

        if stop_reason == "length":
            # Truncated generations always score 0
            if verbose:
                letters_str = ", ".join(target_letters)
                self.logger.debug(
                    f"Text '{text[:50]}...' letters '{letters_str}': Length cutoff, score=0"
                )
        else:
            model_answer = self._extract_answer(model_response, expected_format)

            if model_answer is None:
                # Invalid format gets 0 reward
                format_errors_in_group += 1
                lowered_response = model_response.lower()
                if (
                    "<think>" not in lowered_response
                    or "</think>" not in lowered_response
                ):
                    think_errors_in_group += 1
                if verbose:
                    letters_str = ", ".join(target_letters)
                    self.logger.debug(
                        f"Text '{text[:50]}...' letters '{letters_str}': Format error, score=0"
                    )
            elif expected_format == "single":
                # Single letter: compare integer
                expected_single_count = expected_counts[target_letters[0]]
                if model_answer == expected_single_count:
                    reward = 1
                    if verbose:
                        self.logger.debug(
                            f"Text '{text[:50]}...' letter '{target_letters[0]}': Correct answer {model_answer}, score=1"  # noqa
                        )
                elif verbose:
                    self.logger.debug(
                        f"Text '{text[:50]}...' letter '{target_letters[0]}': Wrong answer {model_answer} (expected {expected_single_count}), score=0"  # noqa
                    )
            else:
                # Multiple letters: keys must match exactly and every count must agree
                is_exact = set(model_answer.keys()) == set(target_letters) and all(
                    model_answer.get(letter, -1) == expected_counts[letter]
                    for letter in target_letters
                )
                if is_exact:
                    reward = 1
                    if verbose:
                        self.logger.debug(
                            f"Text '{text[:50]}...' letters {target_letters}: Correct answer {model_answer}, score=1"  # noqa
                        )
                elif verbose:
                    self.logger.debug(
                        f"Text '{text[:50]}...' letters {target_letters}: Wrong answer {model_answer} (expected {expected_counts}), score=0"  # noqa
                    )

        # Tokenize the conversation for learning
        out_dict = tokenize_for_trainer(self.tokenizer, rollout[0])
        tokens = out_dict["tokens"]
        masks = out_dict["masks"]

        # Drop examples with fewer than 10 unmasked positions
        if sum(1 for mask_value in masks if mask_value != -100) < 10:
            continue

        scores["tokens"].append(tokens)
        scores["masks"].append(masks)
        scores["scores"].append(1.0 if reward else 0.0)

        # Stop once we have enough examples
        if len(scores["tokens"]) >= self.config.group_size:
            break

    if not scores["tokens"]:
        letters_str = ", ".join(target_letters)
        self.logger.warning(
            f"No valid items were scored for text '{text[:50]}...' letters '{letters_str}' - all items had insufficient context"  # noqa
        )
        return None

    # Update global error counters
    self.answer_format_errors += format_errors_in_group
    self.think_format_errors += think_errors_in_group

    # Record success rate metrics for wandb logging
    for group_score in scores["scores"]:
        self.percent_correct_buffer.append(group_score)

    current_scores = scores.get("scores", [])
    if current_scores:
        average_score = sum(current_scores) / len(current_scores)

        text_preview = text[:50] + "..." if len(text) > 50 else text
        letters_str = ", ".join(target_letters)
        if len(target_letters) > 1:
            expected_str = str(expected_counts)
        else:
            expected_str = str(expected_counts[target_letters[0]])
        log_message_main = (
            f"Text: '{text_preview}' | Letters: '{letters_str}' | Expected: {expected_str} | "
            f"Group average score: {average_score:.4f}"
        )

        if all(s == 1.0 for s in current_scores):
            self.logger.info(f"{log_message_main} (All correct in this group!)")
        elif all(s == 0.0 for s in current_scores):
            self.logger.info(
                f"{log_message_main} (All failed - format/answer errors!)"
            )
        else:
            self.logger.info(log_message_main)

        # Groups that are too easy are not used for training
        if average_score > self.config.max_group_average_for_training:
            self.logger.debug(
                f"Skipping group for training - too easy (avg: {average_score:.3f} > threshold: {self.config.max_group_average_for_training})"  # noqa
            )
            return None

    # Identical scores across the group give no learning signal
    first_score = scores["scores"][0]
    if all(first_score == group_score for group_score in scores["scores"]):
        self.logger.debug(
            f"All scores in group are identical ({scores['scores'][0]:.4f}) - no learning signal, skipping group"
        )
        return None

    return scores
|
|
|
|
async def rollout_and_score_eval(self, test_text):
    """
    Run one evaluation rollout for ``test_text`` and grade the model's answer.

    One or more target letters are chosen, their case-insensitive occurrence
    counts in the prepared text become the reference answer, and a single
    completion is requested from the server on the "eval" split.

    Args:
        test_text: Test text from dataset (could be word, sentence, or random string)

    Returns:
        1 when the extracted answer matches the reference counts, otherwise 0.
    """
    cfg = self.config

    # random.random() is only drawn when more than one letter is allowed, and
    # random.randint() only when a multi-letter question was chosen.
    ask_multiple = (
        cfg.max_letters_to_count > 1
        and random.random() < cfg.multi_letter_probability
    )
    # Multi-letter questions use 2..max_letters_to_count letters, otherwise 1.
    letter_quota = random.randint(2, cfg.max_letters_to_count) if ask_multiple else 1
    target_letters = self._select_target_letters(test_text, letter_quota)

    # Reference counts are case-insensitive and taken on the text returned by
    # _prepare_text_for_counting, lowercased once for all letters.
    lowered_text = self._prepare_text_for_counting(test_text).lower()
    expected_counts = {
        letter: lowered_text.count(letter.lower()) for letter in target_letters
    }

    # Long inputs, or inputs containing a space or common punctuation, are
    # presented as a quoted "text"; everything else as a bare "string".
    is_text_passage = (
        len(test_text) > 50
        or " " in test_text
        or any(mark in test_text for mark in ".,!?;:")
    )

    expected_format = "single" if len(target_letters) == 1 else "multi"

    if expected_format == "single":
        # Single-letter question with a plain numeric answer format
        target_letter = target_letters[0]
        if is_text_passage:
            question_text = f'How many {target_letter}s are in the following text: "{test_text}"?'
        else:
            question_text = (
                f"How many {target_letter}s are in the string {test_text}?"
            )
        question_with_instruction = f"{question_text}\n\nProvide your answer in the format: <answer>{{number}}</answer>"  # noqa
    else:
        # Multi-letter question with a JSON answer format
        letters_str = (
            ", ".join(f"'{letter}'" for letter in target_letters[:-1])
            + f", and '{target_letters[-1]}'"  # noqa
        )
        if is_text_passage:
            question_text = f'Count the occurrences of the letters {letters_str} in the following text: "{test_text}"'  # noqa
        else:
            question_text = f"Count the occurrences of the letters {letters_str} in the string {test_text}"  # noqa

        example_json = (
            "{" + ", ".join(f'"{letter}": 0' for letter in target_letters) + "}"
        )
        question_with_instruction = f"{question_text}\n\nProvide your answer as JSON in the format: <answer>{example_json}</answer>"  # noqa

    # Render the system + user turns into one prompt string for the
    # completion endpoint.
    prompt = self.tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question_with_instruction},
        ],
        add_generation_prompt=True,
        tokenize=False,
    )

    completion = await self.server.completion(
        prompt=prompt,
        n=1,
        max_tokens=cfg.max_generation_tokens,
        temperature=cfg.eval_temperature,
        split="eval",
    )
    model_response = completion.choices[0].text
    model_answer = self._extract_answer(model_response, expected_format)

    # No extractable answer scores 0.
    if model_answer is None:
        return 0

    if expected_format == "single":
        # Single letter: compare against the one expected count
        return 1 if model_answer == expected_counts[target_letters[0]] else 0

    # Multiple letters: the answer must cover exactly the requested letters
    # and every count must match.
    if set(model_answer.keys()) != set(target_letters):
        return 0
    every_count_matches = all(
        model_answer.get(letter, -1) == expected_counts[letter]
        for letter in target_letters
    )
    return 1 if every_count_matches else 0
|
|
|
|
async def evaluate(self, *args, **kwargs):
    """
    Evaluate the model on a random sample of the test texts.

    The fraction of correct rollouts is appended to ``self.eval_metrics`` as
    ``("eval/percent_correct", value)``. When no test texts are loaded, 0.0
    is recorded and no rollouts are run.
    """
    self.logger.info("Starting evaluation...")
    if not self.test_words:
        self.logger.warning("No test texts available for evaluation. Skipping.")
        self.eval_metrics.append(("eval/percent_correct", 0.0))
        return

    # Cap the evaluation at eval_sample_size texts to keep it manageable
    sample_count = min(len(self.test_words), self.config.eval_sample_size)
    eval_texts = random.sample(self.test_words, sample_count)

    if self.config.use_text_passages:
        text_type = "mixed items (words and passages)"
    else:
        text_type = "strings"
    self.logger.info(
        f"Starting evaluation on {len(eval_texts)} test {text_type}..."
    )

    # One rollout per sampled text, gathered with a progress bar
    eval_tasks = [self.rollout_and_score_eval(sample) for sample in eval_texts]
    scores = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating")

    percent_correct = sum(scores) / len(scores) if scores else 0.0

    self.eval_metrics.append(("eval/percent_correct", percent_correct))
    self.logger.info(f"Evaluation finished. Percent correct: {percent_correct:.4f}")
|
|
|
|
async def add_rollouts_for_wandb(
    self,
    scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
    item: Item = None,
):
    """
    Buffer a sample of decoded rollouts from a scored group for wandb logging.

    Args:
        scored_data: A scored group (a mapping read through its "tokens" and
            "scores" entries), or a list of such groups. Each group in a list
            is buffered as its own entry.
        item: The item the group was generated from. Indices 1, 2 and 3 are
            read as the expected counts, the source text and the target
            letters. Nothing is buffered when it is None.

    Returns:
        None. Up to ``num_rollouts_per_group_for_logging`` rollouts (the whole
        group when that setting is -1) are appended to
        ``self.rollouts_for_wandb`` as one list, and the oldest buffered group
        is dropped once more than ``num_rollouts_to_keep`` are held.
    """
    # The signature allows a list of groups, but the logic below needs a
    # single mapping; handle each group on its own instead of failing on
    # ``list.get``.
    if isinstance(scored_data, list):
        for group in scored_data:
            await self.add_rollouts_for_wandb(group, item)
        return

    if item is None or scored_data is None or not scored_data.get("tokens"):
        return

    expected_counts = item[1]  # correct counts (dict)
    text = item[2]  # text (word or sentence)
    target_letters = item[3]  # target letters (list)

    # Handle legacy format: a bare letter string and/or a bare integer count
    # are wrapped into the list / dict shapes used below.
    if isinstance(target_letters, str):
        target_letters = [target_letters]
        expected_counts = {target_letters[0]: expected_counts}
    elif isinstance(expected_counts, int):
        expected_counts = {target_letters[0]: expected_counts}

    group_tokens = scored_data["tokens"]
    group_scores = scored_data.get("scores", [])

    # -1 means "keep every rollout in the group"
    num_keep = self.config.num_rollouts_per_group_for_logging
    if num_keep == -1:
        num_keep = self.config.group_size

    # Make sure there's data to log
    num_keep = min(num_keep, len(group_tokens))
    if num_keep == 0:
        return

    group_average_score = (
        sum(group_scores) / len(group_scores) if group_scores else 0.0
    )

    # Only indices that have both tokens and a score can be logged; any
    # remaining indices up to num_keep are reported as a length mismatch.
    loggable = max(0, min(num_keep, len(group_scores)))

    current_rollouts = []
    if loggable:
        # These values are identical for every rollout in the group.
        letters_str = ", ".join(target_letters)
        expected_str = (
            str(expected_counts)
            if len(target_letters) > 1
            else str(expected_counts[target_letters[0]])
        )
        text_preview = text[:100]
        current_rollouts = [
            (
                # Full trajectory: prompt plus model response
                self.tokenizer.decode(group_tokens[i], skip_special_tokens=True),
                group_scores[i],
                expected_str,
                text_preview,
                letters_str,
                group_average_score,
            )
            for i in range(loggable)
        ]

    for i in range(loggable, num_keep):
        self.logger.warning(
            f"Mismatch in lengths of tokens/scores for wandb logging at index {i}."
        )

    self.rollouts_for_wandb.append(current_rollouts)

    if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
        self.rollouts_for_wandb.pop(0)
|
|
|
|
async def create_rollout_table(self, wandb_metrics):
    """
    Move the buffered rollouts into a wandb table on ``wandb_metrics``.

    When rollouts are buffered, a ``wandb.Table`` with one row per rollout is
    stored under "train/rollouts". The buffer is emptied in every case.

    Args:
        wandb_metrics: Metrics dictionary to extend.

    Returns:
        The same ``wandb_metrics`` dictionary.
    """
    buffered_groups = self.rollouts_for_wandb
    if len(buffered_groups) > 0:
        table = wandb.Table(
            columns=[
                "full_text",
                "score",
                "expected_counts",
                "text",
                "target_letters",
                "group_average_score",
            ]
        )
        for row in (entry for group in buffered_groups for entry in group):
            # Old-format rows have 5 fields and no group average; log 0.0
            # for them. New-format rows carry it as the 6th field.
            group_average = row[5] if len(row) >= 6 else 0.0
            table.add_data(row[0], row[1], row[2], row[3], row[4], group_average)
        wandb_metrics["train/rollouts"] = table

    self.rollouts_for_wandb = []
    return wandb_metrics
|
|
|
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    """
    Collect this environment's metrics and pass them to the parent ``wandb_log``.

    Adds training accuracy from the buffered scores, any pending eval
    metrics, letter and word-length statistics, format-error rates, data-dump
    progress (when ``dump_rollouts`` is enabled) and the rollout table. The
    score buffer and the pending eval metrics are cleared along the way.

    Args:
        wandb_metrics: Optional metrics dictionary to extend; a new one is
            created when omitted.
    """
    import math

    if wandb_metrics is None:
        wandb_metrics = {}

    # Training accuracy over the buffered scores; skipped when the buffer is
    # empty so there is no division by zero.
    buffered_scores = self.percent_correct_buffer
    if buffered_scores:
        wandb_metrics["train/percent_correct"] = sum(buffered_scores) / len(
            buffered_scores
        )
    self.percent_correct_buffer = []

    # Add eval metrics (stored as (name, value) entries), then clear them
    wandb_metrics.update({entry[0]: entry[1] for entry in self.eval_metrics})
    self.eval_metrics = []

    # Letter distribution statistics
    letter_counts = self.letter_distribution_stats
    if letter_counts:
        total_letters_asked = sum(letter_counts.values())
        wandb_metrics["stats/total_letters_asked"] = total_letters_asked

        # Counts of the most and least frequently asked letters
        wandb_metrics["stats/most_common_letter_count"] = max(letter_counts.values())
        wandb_metrics["stats/least_common_letter_count"] = min(letter_counts.values())

        # Entropy of the asked-letter distribution (how evenly the letters
        # are spread); letters with a zero count are left out.
        shares = [
            count / total_letters_asked
            for count in letter_counts.values()
            if count > 0
        ]
        wandb_metrics["stats/letter_distribution_entropy"] = -sum(
            share * math.log2(share) for share in shares
        )

    # Word length statistics
    length_counts = self.word_length_stats
    if length_counts:
        total_words_asked = sum(length_counts.values())
        wandb_metrics["stats/total_words_asked"] = total_words_asked

        # Count-weighted average length
        weighted_length_sum = sum(
            length * count for length, count in length_counts.items()
        )
        wandb_metrics["stats/avg_word_length"] = (
            weighted_length_sum / total_words_asked
        )

        # Shortest and longest lengths seen
        wandb_metrics["stats/min_word_length"] = min(length_counts)
        wandb_metrics["stats/max_word_length"] = max(length_counts)

    # Error rates are per rollout: processed items times group size
    if self.processed_item_count > 0:
        rollouts_seen = self.processed_item_count * self.config.group_size
        wandb_metrics["errors/answer_format_error_rate"] = (
            self.answer_format_errors / rollouts_seen
        )
        wandb_metrics["errors/think_format_error_rate"] = (
            self.think_format_errors / rollouts_seen
        )
        wandb_metrics["errors/total_format_errors"] = (
            self.answer_format_errors + self.think_format_errors
        )

    # Data dumping progress
    if self.config.dump_rollouts:
        wandb_metrics.update(
            {
                "data_dumps/processed_item_count": self.processed_item_count,
                "data_dumps/rollouts_buffer_size": len(self.rollouts_to_save_buffer),
                "data_dumps/save_file_batch_num": self.save_file_batch_num,
                "data_dumps/batch_size": self.config.dump_batch_size,
            }
        )

    # Add rollout table
    wandb_metrics = await self.create_rollout_table(wandb_metrics)

    await super().wandb_log(wandb_metrics)
|
|
|
|
|
|
# Script entry point: when this file is executed directly (not imported),
# invoke LetterCountingEnv.cli().
if __name__ == "__main__":
    LetterCountingEnv.cli()
|