"""
|
|
Letter Counting Environment with Adaptive Difficulty (Curriculum Learning)
|
|
|
|
This environment trains models to count letter occurrences in words and strings
|
|
using a 10-tier difficulty system that automatically adjusts based on performance.
|
|
|
|
Evaluation uses a static external dataset from HuggingFace: NousResearch/Letter-Counting-Eval
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import sys
|
|
import threading
|
|
import uuid
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import wandb
|
|
import yaml
|
|
from pydantic import Field
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
Item,
|
|
ScoredDataGroup,
|
|
)
|
|
|
|
# Default path to the YAML configuration file
|
|
DEFAULT_CONFIG_PATH = "config.yaml"
|
|
|
|
|
|
def get_config_path() -> str:
|
|
"""
|
|
Get the config path from CLI --config argument or use default.
|
|
|
|
Usage: python letter_counting_environment.py --config /path/to/config.yaml
|
|
|
|
Returns:
|
|
Config file path (from --config arg or DEFAULT_CONFIG_PATH)
|
|
"""
|
|
for i, arg in enumerate(sys.argv):
|
|
if arg == "--config" and i + 1 < len(sys.argv):
|
|
return sys.argv[i + 1]
|
|
if arg.startswith("--config="):
|
|
return arg.split("=", 1)[1]
|
|
return DEFAULT_CONFIG_PATH


# Import NLTK words corpus for training word list
try:
    import nltk
    from nltk.corpus import words

    try:
        words.words()
    except LookupError:
        nltk.download("words")
except ImportError:
    print("Warning: NLTK not available. Please install with: pip install nltk")
    words = None

# Import datasets for loading external eval dataset
try:
    from datasets import load_dataset
except ImportError:
    print("Warning: datasets library not available. Install with: pip install datasets")
    load_dataset = None


# =============================================================================
# DIFFICULTY TIERS CONFIGURATION
# =============================================================================
# This dict can be modified to adjust difficulty levels.
# Each tier defines text length ranges, letter count ranges, and whether to use random strings.
# Levels 1-5 use real English words, levels 6-10 use random letter strings.

DIFFICULTY_TIERS = {
    # Tier 1: Very Easy - Short words, always single letter
    1: {
        "min_word_length": 3,
        "max_word_length": 8,
        "multi_letter_probability": 0.0,
        "min_letters_to_count": 1,
        "max_letters_to_count": 1,
        "use_random_string": False,
    },
    # Tier 2: Easy - Short-medium words, 50% multi (1-2 letters)
    2: {
        "min_word_length": 5,
        "max_word_length": 12,
        "multi_letter_probability": 0.5,
        "min_letters_to_count": 1,
        "max_letters_to_count": 2,
        "use_random_string": False,
    },
    # Tier 3: Medium-Easy - Medium words, 60% multi (2-3 letters)
    3: {
        "min_word_length": 8,
        "max_word_length": 18,
        "multi_letter_probability": 0.6,
        "min_letters_to_count": 2,
        "max_letters_to_count": 3,
        "use_random_string": False,
    },
    # Tier 4: Medium - Longer words, 80% multi (3-5 letters)
    4: {
        "min_word_length": 10,
        "max_word_length": 25,
        "multi_letter_probability": 0.8,
        "min_letters_to_count": 3,
        "max_letters_to_count": 5,
        "use_random_string": False,
    },
    # Tier 5: Medium-Hard - Long words, 90% multi (5-8 letters)
    5: {
        "min_word_length": 15,
        "max_word_length": 35,
        "multi_letter_probability": 0.9,
        "min_letters_to_count": 5,
        "max_letters_to_count": 8,
        "use_random_string": False,
    },
    # Tier 6: Hard - Random strings, 100% multi (8-12 letters)
    6: {
        "min_word_length": 40,
        "max_word_length": 80,
        "multi_letter_probability": 1.0,
        "min_letters_to_count": 8,
        "max_letters_to_count": 12,
        "use_random_string": True,
    },
    # Tier 7: Very Hard - Long random strings, 100% multi (12-20 letters)
    7: {
        "min_word_length": 80,
        "max_word_length": 150,
        "multi_letter_probability": 1.0,
        "min_letters_to_count": 12,
        "max_letters_to_count": 20,
        "use_random_string": True,
    },
    # Tier 8: Expert - Very long strings, 100% multi (18-30 letters)
    8: {
        "min_word_length": 150,
        "max_word_length": 250,
        "multi_letter_probability": 1.0,
        "min_letters_to_count": 18,
        "max_letters_to_count": 30,
        "use_random_string": True,
    },
    # Tier 9: Master - Extreme strings, 100% multi (25-40 letters)
    9: {
        "min_word_length": 250,
        "max_word_length": 400,
        "multi_letter_probability": 1.0,
        "min_letters_to_count": 25,
        "max_letters_to_count": 40,
        "use_random_string": True,
    },
    # Tier 10: Maximum - 500 chars, 100% multi (35-50 letters)
    10: {
        "min_word_length": 400,
        "max_word_length": 500,
        "multi_letter_probability": 1.0,
        "min_letters_to_count": 35,
        "max_letters_to_count": 50,
        "use_random_string": True,
    },
}
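
# Illustrative example of how a tier shapes an item (derived from the code
# below, not from the original source): at tier 1 the text is a real 3-8
# letter word and multi_letter_probability is 0.0, so every question targets
# a single letter; at tier 10 the text is a 400-500 character random string
# and 35-50 target letters must all be counted correctly to earn a reward.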


class LetterCountingConfig(BaseEnvConfig):
    """
    Configuration class for Letter Counting Environment.

    This environment uses adaptive difficulty (curriculum learning) with 10 difficulty tiers.
    Difficulty automatically adjusts based on model performance.
    """

    # Generation configuration
    generation_temperature: float = Field(
        1.0, description="Temperature for training generation"
    )
    eval_temperature: float = Field(
        0.6, description="Temperature for evaluation generation"
    )
    max_generation_tokens: int = Field(
        1024 * 15, description="Maximum tokens for model generation"
    )

    # Adaptive difficulty / curriculum learning configuration (10 tiers)
    difficulty_window_size: int = Field(
        150, description="Number of recent groups to track for difficulty adjustment"
    )
    difficulty_increase_threshold: float = Field(
        0.8, description="If success rate > this, increase difficulty (skip group)"
    )
    difficulty_decrease_threshold: float = Field(
        0.2, description="If success rate < this, decrease difficulty (skip group)"
    )
    min_difficulty_level: int = Field(
        1, description="Minimum difficulty level (1 = easiest)"
    )
    max_difficulty_level: int = Field(
        10, description="Maximum difficulty level (10 = hardest)"
    )
    starting_difficulty_level: int = Field(
        3, description="Starting difficulty level (1-10)"
    )

    # Logging configuration
    debug_logging: bool = Field(
        True, description="Enable debug-level logging for more verbose output"
    )
    suppress_base_env_logs: bool = Field(
        True, description="Suppress verbose base environment logs"
    )

    # Data dumping configuration (for creating offline training datasets)
    dump_rollouts: bool = Field(
        False, description="Whether to dump successful rollouts to JSONL files"
    )
    dump_batch_size: int = Field(
        100, description="Number of groups to accumulate before saving to disk"
    )


class LetterCountingEnv(BaseEnv):
    """
    Letter Counting Environment for training models to count letters in words and strings.

    This environment uses adaptive difficulty (curriculum learning) with 10 difficulty tiers:
    - Levels 1-5: Real English words with increasing length and multi-letter counting
    - Levels 6-10: Random letter strings with extreme length and many letters

    The model is presented with questions like:
    - "How many 'a's are in the string banana?" (single letter)
    - "Count the occurrences of the letters 'e', 'o', 't' in the string..." (multi-letter)

    Expected response format:
    - <think>reasoning</think><answer>3</answer> (single letter)
    - <think>reasoning</think><answer>{"e": 4, "o": 4, "t": 2}</answer> (multi-letter)

    Training filters:
    - Groups with >80% success rate are skipped (too easy)
    - Groups with <20% success rate are skipped (too hard)
    - Difficulty automatically adjusts based on rolling window performance

    Evaluation uses static dataset from HuggingFace: NousResearch/Letter-Counting-Eval
    """

    name = "letter_counting"
    env_config_cls = LetterCountingConfig

    # Fixed parameters (not configurable)
    TRAIN_TEST_SPLIT = 0.95
    RANDOM_SEED = 42
    MIN_WORD_LENGTH = 3
    MAX_WORD_LENGTH = 30
    EVAL_DATASET_HF = "NousResearch/Letter-Counting-Eval"

    def __init__(
        self,
        config: LetterCountingConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """
        Initialize the Letter Counting environment.

        Args:
            config: Configuration for the environment
            server_configs: List of server configurations for OpenAI API
            slurm: Whether to use Slurm for distributed training
            testing: Whether in testing mode
        """
        super().__init__(config, server_configs, slurm, testing)

        # Run UUID for tracking
        self.run_uuid = str(uuid.uuid4())

        # Data dumping infrastructure
        self.rollouts_to_save_buffer: List[Dict] = []
        self.processed_item_count = 0
        self.datadumps_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "data_dumps"
        )
        self.save_file_batch_num = 0

        # Metrics tracking
        self.letter_distribution_stats: Dict[str, int] = {}
        self.word_length_stats: Dict[int, int] = {}
        self.answer_format_errors = 0
        self.think_format_errors = 0

        # Adaptive difficulty / curriculum learning state
        # Thread-safe lock for concurrent workers
        self._difficulty_lock = threading.Lock()
        self.current_difficulty_level = self.config.starting_difficulty_level
        self.recent_scores: List[float] = []  # Rolling window of recent group scores
        self.difficulty_history: List[Tuple[int, int, float]] = []

        # Fixed evaluation dataset (loaded from HuggingFace during setup)
        self.eval_dataset: List[Dict] = []

        # Initialize the logger
        self.logger = logging.getLogger(self.__class__.__name__)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        log_level = logging.DEBUG if self.config.debug_logging else logging.INFO
        self.logger.setLevel(log_level)
        self.logger.disabled = False

        # Suppress base environment logs if requested
        if self.config.suppress_base_env_logs:
            base_logger = logging.getLogger("atroposlib.envs.base")
            base_logger.setLevel(logging.WARNING)

        # Log initialization
        self.logger.info(
            f"LetterCountingEnv initialized with run UUID: {self.run_uuid}"
        )
        self.logger.info(
            f"Adaptive difficulty: starting at L{self.current_difficulty_level}/10 "
            f"(range: {self.config.min_difficulty_level}-{self.config.max_difficulty_level}, "
            f"window: {self.config.difficulty_window_size}, "
            f"thresholds: >{self.config.difficulty_increase_threshold:.0%} to increase, "
            f"<{self.config.difficulty_decrease_threshold:.0%} to decrease)"
        )
        if self.config.dump_rollouts:
            self.logger.info(
                f"Data dumping ENABLED: saving to {self.datadumps_dir}, "
                f"batch size: {self.config.dump_batch_size}"
            )

        self.percent_correct_buffer = []
        self.eval_metrics = []
        self.rollouts_for_wandb: List[List[Tuple]] = []

    @classmethod
    def config_init(cls) -> Tuple[LetterCountingConfig, List[APIServerConfig]]:
        """
        Load configuration from YAML file, with fallback to defaults if not found.

        Config path priority:
        1. CLI argument: --config /path/to/config.yaml
        2. Default: config.yaml (DEFAULT_CONFIG_PATH)

        If config file doesn't exist or can't be read, uses sensible defaults.
        """
        config_path = get_config_path()

        # Try to load config from file, fall back to empty dicts if not found
        try:
            with open(config_path, "r") as f:
                config = yaml.safe_load(f) or {}
            print(f"Loaded config from {config_path}")
        except FileNotFoundError:
            print(f"Config file not found at {config_path}, using defaults")
            config = {}
        except Exception as e:
            print(f"Error loading config from {config_path}: {e}, using defaults")
            config = {}

        env = config.get("env", {})
        openai_configs = config.get("openai", [])

        # Build LetterCountingConfig from YAML
        env_config = LetterCountingConfig(
            # Base environment config
            tokenizer_name=env.get(
                "tokenizer_name", "meta-llama/Llama-3.1-8B-Instruct"
            ),
            group_size=env.get("group_size", 16),
            batch_size=env.get("batch_size", 256),
            max_batches_offpolicy=env.get("max_batches_offpolicy", 3),
            use_wandb=env.get("use_wandb", True),
            rollout_server_url=env.get("rollout_server_url", "http://localhost:8000"),
            wandb_name=env.get("wandb_name", "tinker-letter-counting-env"),
            ensure_scores_are_not_same=env.get("ensure_scores_are_not_same", True),
            max_token_length=env.get("max_token_length", 16864),
            max_num_workers=env.get("max_num_workers", 24),
            total_steps=env.get("total_steps", 5000),
            steps_per_eval=env.get("steps_per_eval", 100),
            inference_weight=env.get("inference_weight", 1.0),
            data_path_to_save_groups=env.get("data_path_to_save_groups", None),
            eval_limit_ratio=env.get("eval_limit_ratio", 0.1),
            # Generation config
            generation_temperature=env.get("generation_temperature", 1.0),
            eval_temperature=env.get("eval_temperature", 0.6),
            max_generation_tokens=env.get("max_generation_tokens", 15360),
            # Adaptive difficulty config
            difficulty_window_size=env.get("difficulty_window_size", 150),
            difficulty_increase_threshold=env.get("difficulty_increase_threshold", 0.8),
            difficulty_decrease_threshold=env.get("difficulty_decrease_threshold", 0.2),
            min_difficulty_level=env.get("min_difficulty_level", 1),
            max_difficulty_level=env.get("max_difficulty_level", 10),
            starting_difficulty_level=env.get("starting_difficulty_level", 3),
            # Logging config
            debug_logging=env.get("debug_logging", True),
            suppress_base_env_logs=env.get("suppress_base_env_logs", True),
            # Data dumping config
            dump_rollouts=env.get("dump_rollouts", False),
            dump_batch_size=env.get("dump_batch_size", 100),
        )

        # Build server configs from YAML
        server_configs = []
        for openai_cfg in openai_configs:
            server_configs.append(
                APIServerConfig(
                    model_name=openai_cfg.get(
                        "model_name", "meta-llama/Llama-3.1-8B-Instruct"
                    ),
                    base_url=openai_cfg.get("base_url", "http://localhost:8001/v1"),
                    api_key=openai_cfg.get("api_key", "x"),
                    num_requests_for_eval=openai_cfg.get("num_requests_for_eval", 256),
                    server_type=openai_cfg.get("server_type"),
                )
            )

        # Default server config if none provided
        if not server_configs:
            server_configs = [
                APIServerConfig(
                    model_name="meta-llama/Llama-3.1-8B-Instruct",
                    base_url="http://localhost:8001/v1",
                    api_key="x",
                    num_requests_for_eval=256,
                )
            ]

        return env_config, server_configs

    async def setup(self):
        """Set up the environment by loading word dataset and eval dataset."""
        if words is None:
            raise ImportError(
                "NLTK is required for this environment. Please install with: pip install nltk"
            )

        # Set random seed for reproducibility
        random.seed(self.RANDOM_SEED)

        self.logger.info("Setting up word dataset from NLTK...")

        # Get all English words from NLTK and filter
        all_words = words.words()
        filtered_words = [
            word.lower()
            for word in all_words
            if word.isalpha()
            and self.MIN_WORD_LENGTH <= len(word) <= self.MAX_WORD_LENGTH
        ]

        # Shuffle for randomness
        random.shuffle(filtered_words)

        # Create train/test split (test words used as fallback, eval uses HF dataset)
        split_point = int(self.TRAIN_TEST_SPLIT * len(filtered_words))
        self.train_words = filtered_words[:split_point]
        self.test_words = filtered_words[split_point:]

        self.logger.info(f"Loaded {len(filtered_words)} words total")
        self.logger.info(f"Training words: {len(self.train_words)}")
        self.logger.info(f"Test words: {len(self.test_words)}")
        self.logger.info(f"Example words: {self.train_words[:10]}")

        # Initialize iteration counter
        self.iter = 0

        # Load evaluation dataset from HuggingFace
        self.eval_dataset = self._load_eval_dataset()

        self.logger.info("Letter counting environment setup complete")

    def _load_eval_dataset(self) -> List[Dict]:
        """
        Load the static evaluation dataset from HuggingFace.

        Dataset: NousResearch/Letter-Counting-Eval
        Schema: Prompt, Answer, DifficultyLevel, _text, _target_letters, _expected_counts
        """
        if load_dataset is None:
            self.logger.warning(
                "datasets library not available. Evaluation will be skipped. "
                "Install with: pip install datasets"
            )
            return []

        self.logger.info(
            f"Loading eval dataset from HuggingFace: {self.EVAL_DATASET_HF}"
        )

        try:
            dataset = load_dataset(self.EVAL_DATASET_HF, split="train")
        except Exception as e:
            self.logger.error(f"Failed to load eval dataset: {e}")
            return []

        eval_items = []
        for item in dataset:
            # Handle both formats from the schema
            eval_item = {
                "prompt": item.get("Prompt"),
                "answer": item.get("Answer"),
                "difficulty_level": item.get("DifficultyLevel", 0),
            }

            # Include raw data if available for detailed analysis
            if "_text" in item:
                eval_item["text"] = item["_text"]
            if "_target_letters" in item:
                eval_item["target_letters"] = item["_target_letters"]
            if "_expected_counts" in item:
                eval_item["expected_counts"] = item["_expected_counts"]

            eval_items.append(eval_item)

        # Log distribution by difficulty level
        level_counts = {}
        for item in eval_items:
            level = item.get("difficulty_level", 0)
            level_counts[level] = level_counts.get(level, 0) + 1

        level_str = ", ".join(f"L{k}={v}" for k, v in sorted(level_counts.items()))
        self.logger.info(f"Loaded eval dataset: {len(eval_items)} items ({level_str})")

        return eval_items

    def _get_difficulty_params(self, level: Optional[int] = None) -> Dict:
        """
        Get difficulty parameters for the specified or current difficulty level.

        Args:
            level: Optional difficulty level. If None, uses current_difficulty_level.

        Returns:
            Dict with difficulty parameters from DIFFICULTY_TIERS
        """
        if level is None:
            level = self.current_difficulty_level

        # Clamp to valid range
        level = max(
            self.config.min_difficulty_level,
            min(self.config.max_difficulty_level, level),
        )
        return DIFFICULTY_TIERS.get(level, DIFFICULTY_TIERS[3])

    def _generate_random_letter_string(self, min_length: int, max_length: int) -> str:
        """
        Generate a random string of lowercase letters for high-difficulty tiers.

        Args:
            min_length: Minimum string length
            max_length: Maximum string length

        Returns:
            Random string of lowercase letters
        """
        length = random.randint(min_length, max_length)
        return "".join(
            random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(length)
        )

    def _select_target_letters(self, text: str, num_letters: int) -> List[str]:
        """
        Select target letters with randomized bias toward letters present in the text.

        The bias is randomized each time (between 30% and 90%) to prevent the model
        from learning that a fixed percentage of letters are always absent.
        For long strings with many letters requested, the bias automatically adjusts
        to what's actually available.

        Args:
            text: The text to analyze for present letters
            num_letters: Number of letters to select

        Returns:
            List of selected target letters
        """
        all_letters = list("abcdefghijklmnopqrstuvwxyz")

        # Find letters present/absent in text
        text_lower = text.lower()
        present_letters = [ch for ch in all_letters if ch in text_lower]
        absent_letters = [ch for ch in all_letters if ch not in text_lower]

        # Randomize the present letter bias between 30% and 90%
        # This prevents the model from learning a fixed pattern
        present_bias = random.uniform(0.3, 0.9)

        selected_letters = []
        present_pool = present_letters.copy()
        absent_pool = absent_letters.copy()

        for _ in range(num_letters):
            # Calculate effective bias based on what's available
            # If we've exhausted one pool, use the other
            if not present_pool and not absent_pool:
                break
            elif not present_pool:
                effective_bias = 0.0  # Must use absent
            elif not absent_pool:
                effective_bias = 1.0  # Must use present
            else:
                effective_bias = present_bias

            # Select based on bias
            if random.random() < effective_bias and present_pool:
                chosen = random.choice(present_pool)
                present_pool.remove(chosen)
            elif absent_pool:
                chosen = random.choice(absent_pool)
                absent_pool.remove(chosen)
            elif present_pool:
                # Fallback if absent is empty
                chosen = random.choice(present_pool)
                present_pool.remove(chosen)
            else:
                break

            selected_letters.append(chosen)

        return selected_letters

    def _update_difficulty(self, group_score: float, item_difficulty_level: int):
        """
        Update difficulty level based on recent performance.

        Key behaviors:
        1. Only counts scores from items at the CURRENT difficulty level
        2. Resets the rolling window when difficulty changes
        3. Requires min_samples before making any adjustment decision

        Thread-safe: uses a lock to prevent race conditions.

        Args:
            group_score: Average score for the most recent group (0.0 to 1.0)
            item_difficulty_level: The difficulty level when the item was generated
        """
        with self._difficulty_lock:
            # Only count scores from items at the CURRENT difficulty level
            if item_difficulty_level != self.current_difficulty_level:
                self.logger.debug(
                    f"Ignoring score from L{item_difficulty_level} item "
                    f"(current: L{self.current_difficulty_level})"
                )
                return

            # Add to rolling window
            self.recent_scores.append(group_score)

            # Keep only the most recent scores
            if len(self.recent_scores) > self.config.difficulty_window_size:
                self.recent_scores = self.recent_scores[
                    -self.config.difficulty_window_size :
                ]

            # Need at least half the window size before adjusting
            min_samples = max(10, self.config.difficulty_window_size // 2)
            if len(self.recent_scores) < min_samples:
                self.logger.debug(
                    f"Not enough samples at L{self.current_difficulty_level}: "
                    f"{len(self.recent_scores)}/{min_samples} required"
                )
                return

            # Calculate recent success rate
            recent_success_rate = sum(self.recent_scores) / len(self.recent_scores)

            old_level = self.current_difficulty_level
            level_changed = False

            # Adjust difficulty
            if recent_success_rate > self.config.difficulty_increase_threshold:
                if self.current_difficulty_level < self.config.max_difficulty_level:
                    self.current_difficulty_level += 1
                    level_changed = True
                    self.logger.info(
                        f"DIFFICULTY INCREASED: L{old_level} -> L{self.current_difficulty_level} "
                        f"(success: {recent_success_rate:.1%} > {self.config.difficulty_increase_threshold:.1%})"
                    )
            elif recent_success_rate < self.config.difficulty_decrease_threshold:
                if self.current_difficulty_level > self.config.min_difficulty_level:
                    self.current_difficulty_level -= 1
                    level_changed = True
                    self.logger.info(
                        f"DIFFICULTY DECREASED: L{old_level} -> L{self.current_difficulty_level} "
                        f"(success: {recent_success_rate:.1%} < {self.config.difficulty_decrease_threshold:.1%})"
                    )

            # Reset window when difficulty changes
            if level_changed:
                self.recent_scores = []
                self.logger.info(
                    f"Window reset - collecting fresh samples at L{self.current_difficulty_level}"
                )

            # Track difficulty history
            self.difficulty_history.append(
                (self.iter, self.current_difficulty_level, recent_success_rate)
            )

    def _reconstruct_message_with_thinking(self, completion_choice) -> str:
        """
        Reconstruct a message by combining reasoning_content and content.

        Modern reasoning models return reasoning in a separate `reasoning_content` field.
        This method wraps that in <think> tags for consistent processing.
        """
        # Handle text completion API
        if hasattr(completion_choice, "text"):
            raw_content = completion_choice.text or ""
            reasoning_content = None

            if hasattr(completion_choice, "reasoning_content"):
                reasoning_content = completion_choice.reasoning_content

            if reasoning_content is None and hasattr(completion_choice, "message"):
                message = completion_choice.message
                if message and hasattr(message, "reasoning_content"):
                    reasoning_content = message.reasoning_content

            if reasoning_content:
                return f"<think>\n{reasoning_content}\n</think>{raw_content}"
            return raw_content

        # Handle chat completion API
        elif hasattr(completion_choice, "message"):
            message = completion_choice.message
            content = getattr(message, "content", "") or ""
            reasoning_content = getattr(message, "reasoning_content", None)

            if reasoning_content:
                return f"<think>\n{reasoning_content}\n</think>{content}"
            return content

        return ""

    async def _save_rollouts_to_jsonl(self):
        """Save buffered rollouts to a JSONL file in the data_dumps directory."""
        if not self.rollouts_to_save_buffer:
            self.logger.warning("_save_rollouts_to_jsonl called but buffer is empty!")
            return

        buffer_size = len(self.rollouts_to_save_buffer)
        self.logger.info(f"Starting save of {buffer_size} groups to JSONL file...")

        try:
            if not os.path.exists(self.datadumps_dir):
                os.makedirs(self.datadumps_dir)
                self.logger.info(f"Created directory: {self.datadumps_dir}")
        except OSError as e:
            self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}")
            return

        file_path = os.path.join(
            self.datadumps_dir,
            f"letter_counting_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
        )

        try:
            with open(file_path, "w") as f:
                for rollout_dict in self.rollouts_to_save_buffer:
                    json.dump(rollout_dict, f)
                    f.write("\n")
            self.logger.info(
                f"Successfully saved {buffer_size} groups to {file_path} "
                f"(batch #{self.save_file_batch_num})"
            )
            self.rollouts_to_save_buffer.clear()
            self.save_file_batch_num += 1
        except IOError as e:
            self.logger.error(f"Error writing rollouts to {file_path}: {e}")
        except Exception as e:
            self.logger.error(f"Unexpected error saving rollouts: {e}")

    def save_checkpoint(self, step, data=None):
        """Save checkpoint including iteration and difficulty state."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        data["current_difficulty_level"] = self.current_difficulty_level
        data["recent_scores"] = self.recent_scores
        data["processed_item_count"] = self.processed_item_count
        data["save_file_batch_num"] = self.save_file_batch_num
        data["letter_distribution_stats"] = self.letter_distribution_stats
        data["word_length_stats"] = self.word_length_stats
        data["answer_format_errors"] = self.answer_format_errors
        data["think_format_errors"] = self.think_format_errors
        super().save_checkpoint(step, data)

    def load_checkpoint(self):
        """Load checkpoint including iteration and difficulty state."""
        super().load_checkpoint()

        if hasattr(self, "current_difficulty_level"):
            self.logger.info(
                f"Restored difficulty level: L{self.current_difficulty_level}"
            )
        if hasattr(self, "recent_scores") and self.recent_scores:
            self.logger.info(f"Restored {len(self.recent_scores)} recent scores")
        if hasattr(self, "save_file_batch_num"):
            self.logger.info(
                f"Restored save_file_batch_num: {self.save_file_batch_num}"
            )

    async def close(self):
        """Clean up and save any remaining rollouts before exiting."""
        self.logger.info("Closing LetterCountingEnv...")

        if self.config.dump_rollouts and self.rollouts_to_save_buffer:
            self.logger.info(
                f"FINAL SAVE: {len(self.rollouts_to_save_buffer)} groups in buffer. "
                f"Saving to disk..."
            )
            await self._save_rollouts_to_jsonl()

        if hasattr(super(), "close"):
            await super().close()

        self.logger.info("LetterCountingEnv closed.")

    async def get_next_item(self):
        """
        Get the next training item from the dataset.

        Uses adaptive difficulty to determine text length and letter count.
        """
        # Get difficulty parameters for current level
        diff_params = self._get_difficulty_params()
        min_len = diff_params["min_word_length"]
        max_len = diff_params["max_word_length"]
        multi_prob = diff_params["multi_letter_probability"]
        min_letters = diff_params.get("min_letters_to_count", 1)
        max_letters = diff_params["max_letters_to_count"]
        use_random_string = diff_params.get("use_random_string", False)

        # Get text based on difficulty level
        if use_random_string:
            # Higher difficulty tiers use generated random strings
            text = self._generate_random_letter_string(min_len, max_len)
        else:
            # Find a word that matches the difficulty-appropriate length
            attempts = 0
            max_attempts = 100
            while attempts < max_attempts:
                candidate = self.train_words[
                    (self.iter + attempts) % len(self.train_words)
                ]
                if min_len <= len(candidate) <= max_len:
                    text = candidate
                    break
                attempts += 1
            else:
                # Fallback: generate random string if no matching word found
                text = self._generate_random_letter_string(min_len, max_len)

        # Decide number of letters to count
        if random.random() < multi_prob:
            num_letters = random.randint(max(2, min_letters), max_letters)
        else:
            num_letters = 1

        target_letters = self._select_target_letters(text, num_letters)

        # Count occurrences for each target letter
        correct_counts = {
            letter: text.lower().count(letter) for letter in target_letters
        }

        # Log item selection
        text_preview = text[:50] + "..." if len(text) > 50 else text
        letters_str = ", ".join(target_letters)
        counts_str = ", ".join(
            f"{letter}:{correct_counts[letter]}" for letter in target_letters
        )
        present_count = sum(
            1 for letter in target_letters if correct_counts[letter] > 0
        )

        self.logger.info(
            f"[L{self.current_difficulty_level}/10] '{text_preview}' | "
            f"Letters: [{letters_str}] | Counts: [{counts_str}] | "
            f"Present: {present_count}/{len(target_letters)} (iter {self.iter})"
        )

        self.iter += 1

        # Create the question based on single/multiple letters
        if len(target_letters) == 1:
            target_letter = target_letters[0]
            question_text = f"How many {target_letter}s are in the string {text}?"
            question_with_instruction = (
                f"{question_text}\n\n"
                f"Provide your answer in the format: <answer>{{number}}</answer>"
            )
        else:
            letters_str = (
                ", ".join(f"'{letter}'" for letter in target_letters[:-1])
                + f", and '{target_letters[-1]}'"
            )
            question_text = f"Count the occurrences of the letters {letters_str} in the string {text}"
            example_json = (
                "{" + ", ".join(f'"{letter}": 0' for letter in target_letters) + "}"
            )
            question_with_instruction = (
                f"{question_text}\n\n"
                f"Provide your answer as JSON in the format: <answer>{example_json}</answer>"
            )

        # Create prompt (user message only, no system prompt)
        prompt = [
            frozenset({"role": "user", "content": question_with_instruction}.items())
        ]

        # Return prompt, correct counts, text, target letters, and difficulty level
        return (
            tuple(prompt),
            correct_counts,
            text,
            target_letters,
            self.current_difficulty_level,
        )

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
        """Generate and collect model responses for scoring."""
        # Extract messages from the item
        messages = [dict(role_dict) for role_dict in item[0]]

        try:
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                completions = await managed.chat_completion(
                    messages=messages,
                    n=self.config.group_size,
                    max_tokens=self.config.max_generation_tokens,
                    temperature=self.config.generation_temperature,
                    # Stop sequences are strings, so pass the EOS token text
                    # rather than its integer id
                    stop=[self.tokenizer.eos_token],
                )

                state = managed.get_state()
                nodes = state["nodes"]
        except Exception as e:
            import traceback

            text_preview = item[2][:50] if len(item[2]) > 50 else item[2]
            self.logger.error(
                f"API call failed for '{text_preview}' (iter {self.iter}): {e}"
            )
            traceback.print_exc()
            raise

        to_score = []

        for i, completion_choice in enumerate(completions.choices):
            trajectory_messages = [dict(role_dict) for role_dict in item[0]]
            reconstructed_response = self._reconstruct_message_with_thinking(
                completion_choice
            )
            trajectory_messages.append(
                {"role": "assistant", "content": reconstructed_response}
            )

            to_score.append(
                {
                    "messages": tuple(trajectory_messages),
                    "correct_counts": item[1],
                    "text": item[2],
                    "target_letters": item[3],
                    "difficulty_level": item[4],
                    "finish_reason": completion_choice.finish_reason,
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )

        scored_data = await self.score(to_score)

        # Data dumping logic
        if self.config.dump_rollouts and scored_data is not None:
            await self._handle_data_dumping(to_score, scored_data, item)

        return scored_data, []

    async def _handle_data_dumping(
        self, to_score: List, scored_data: ScoredDataGroup, item
    ):
        """Handle data dumping for creating offline training datasets."""
        if not scored_data.get("scores"):
            return

        scores = scored_data["scores"]
        group_average = sum(scores) / len(scores)

        # Only dump groups that have learning signal (not all same scores)
        if all(scores[0] == s for s in scores):
            return

        # Extract successful rollouts for dumping
        rollouts_to_dump = []
        for i, score_val in enumerate(scores):
            if score_val > 0:  # Only save successful rollouts
                rollouts_to_dump.append(
                    {
                        "conversation": to_score[i]["messages"],
                        "score": score_val,
                        "expected_counts": to_score[i]["correct_counts"],
                        "text": to_score[i]["text"],
                        "target_letters": to_score[i]["target_letters"],
                        "difficulty_level": to_score[i]["difficulty_level"],
                        "group_average_score": group_average,
                    }
                )

        if rollouts_to_dump:
            text = item[2]
            target_letters = item[3]
            text_preview = (
                text[:30].replace(" ", "_")
                if len(text) > 30
                else text.replace(" ", "_")
            )
            letters_str = "_".join(target_letters)
            item_id = f"{text_preview}_{letters_str}"

            self.rollouts_to_save_buffer.append(
                {
                    "item_id": item_id,
                    "rollouts": rollouts_to_dump,
                }
            )
            self.processed_item_count += 1

            self.logger.debug(
                f"BUFFER ADD: Added '{item_id}' to buffer. "
                f"Buffer: {len(self.rollouts_to_save_buffer)}/{self.config.dump_batch_size}"
            )

            # Save when buffer reaches batch size
            if len(self.rollouts_to_save_buffer) >= self.config.dump_batch_size:
                await self._save_rollouts_to_jsonl()

    def _extract_answer(self, text: str, expected_format: str = "single"):
        """
        Extract the answer from model response.

        Only allows one valid answer format - multiple formats result in score of 0.

        Args:
            text: The model's response text
            expected_format: "single" for number, "multi" for JSON

        Returns:
            Extracted answer (int for single, dict for multi) or None if invalid
        """
        # Check for exactly one <think> opening and closing tag
        think_tags = re.findall(r"<think>", text, re.IGNORECASE)
        if len(think_tags) != 1:
            return None

        think_close_tags = re.findall(r"</think>", text, re.IGNORECASE)
        if len(think_close_tags) != 1:
            return None

        # Split into thinking and answer sections
        parts = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
        if len(parts) != 2:
            return None

        thinking_section, answer_section = parts

        # Validate thinking section contains opening tag
        if "<think>" not in thinking_section.lower():
            return None

        # No additional think tags in answer section
        if "<think>" in answer_section.lower():
            return None

        # Extract answer based on format
        if expected_format == "single":
            answer_pattern = r"<answer>\s*(\d+)\s*</answer>"
            matches = re.findall(answer_pattern, answer_section, re.IGNORECASE)
            if len(matches) != 1:
                return None
            try:
                return int(matches[0])
            except ValueError:
                return None
        else:
            answer_pattern = r"<answer>\s*(\{[^}]+\})\s*</answer>"
            matches = re.findall(answer_pattern, answer_section, re.IGNORECASE)
            if len(matches) != 1:
                return None
            try:
                answer_dict = json.loads(matches[0])
                if not isinstance(answer_dict, dict):
                    return None
                for key, value in answer_dict.items():
                    if not isinstance(key, str) or not isinstance(value, int):
                        return None
                return answer_dict
            except (json.JSONDecodeError, ValueError):
                return None

    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        """Score the generated model responses against expected letter counts."""
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["inference_logprobs"] = []

        if not rollout_group_data:
            return None

        expected_counts = rollout_group_data[0]["correct_counts"]
        text = rollout_group_data[0]["text"]
        target_letters = rollout_group_data[0]["target_letters"]
        item_difficulty_level = rollout_group_data[0].get("difficulty_level", 0)

        # Track statistics
        for target_letter in target_letters:
            if target_letter not in self.letter_distribution_stats:
                self.letter_distribution_stats[target_letter] = 0
            self.letter_distribution_stats[target_letter] += 1

        text_len = len(text)
        if text_len not in self.word_length_stats:
            self.word_length_stats[text_len] = 0
        self.word_length_stats[text_len] += 1

        # Shuffle to avoid selection bias
        random.shuffle(rollout_group_data)

        format_errors_in_group = 0
        think_errors_in_group = 0

        for item in rollout_group_data:
            model_response = item["messages"][-1]["content"]
            stop_reason = item["finish_reason"]

            if stop_reason == "length":
                reward = 0
            else:
                expected_format = "single" if len(target_letters) == 1 else "multi"
                model_answer = self._extract_answer(model_response, expected_format)

                if model_answer is None:
                    reward = 0
                    format_errors_in_group += 1
                    if (
                        "<think>" not in model_response.lower()
                        or "</think>" not in model_response.lower()
                    ):
                        think_errors_in_group += 1
                else:
                    if expected_format == "single":
                        expected_single_count = expected_counts[target_letters[0]]
                        reward = 1 if model_answer == expected_single_count else 0
                    else:
                        if set(model_answer.keys()) == set(target_letters) and all(
                            model_answer.get(letter, -1) == expected_counts[letter]
                            for letter in target_letters
                        ):
                            reward = 1
                        else:
                            reward = 0

            tokens = item["tokens"]
            masks = item["masks"]
            logprobs = item["logprobs"]

            # Remove examples with insufficient context
            if len([1 for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["inference_logprobs"].append(logprobs)
            scores["scores"].append(1.0 if reward else 0.0)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        if not scores["tokens"]:
            self.logger.warning(f"No valid items scored for '{text[:50]}...'")
            return None

        # Update global error counters
        self.answer_format_errors += format_errors_in_group
        self.think_format_errors += think_errors_in_group

        # Record success rate
        for score in scores["scores"]:
            self.percent_correct_buffer.append(score)

        # Calculate group average and check training thresholds
        current_scores = scores.get("scores", [])
        if not current_scores:
            return None

        average_score = sum(current_scores) / len(current_scores)

        # Update adaptive difficulty
        self._update_difficulty(average_score, item_difficulty_level)

        # Log group results
        text_preview = text[:50] + "..." if len(text) > 50 else text
        letters_str = ", ".join(target_letters)
        expected_str = (
            str(expected_counts)
            if len(target_letters) > 1
            else str(expected_counts[target_letters[0]])
        )

        self.logger.info(
            f"'{text_preview}' | Letters: '{letters_str}' | Expected: {expected_str} | "
            f"Score: {average_score:.2f} | L{item_difficulty_level} (current: L{self.current_difficulty_level})"
        )

        # CRITICAL: Training difficulty filtering
        # Skip groups that are too easy (>80%) or too hard (<20%)
        if average_score > self.config.difficulty_increase_threshold:
            self.logger.debug(
                f"Skipping group - too easy (avg: {average_score:.2f} > "
                f"{self.config.difficulty_increase_threshold})"
            )
            return None

        if average_score < self.config.difficulty_decrease_threshold:
            self.logger.debug(
                f"Skipping group - too hard (avg: {average_score:.2f} < "
                f"{self.config.difficulty_decrease_threshold})"
            )
            return None

        return scores

    async def rollout_and_score_eval(self, eval_item: Dict) -> Tuple[int, int]:
        """
        Generate and score model response for a single eval item.

        Uses the pre-formatted prompt from the HuggingFace dataset.

        Args:
            eval_item: Item from the eval dataset with 'prompt' and 'answer' fields

        Returns:
            Tuple of (score, difficulty_level)
        """
        prompt = eval_item["prompt"]
        expected_answer = eval_item["answer"]
        difficulty_level = eval_item.get("difficulty_level", 0)

        # Determine expected format from answer
        try:
            # Try to parse as JSON (multi-letter)
            parsed_answer = json.loads(expected_answer)
            if isinstance(parsed_answer, dict):
                expected_format = "multi"
                expected_counts = parsed_answer
            else:
                expected_format = "single"
                expected_counts = int(expected_answer)
        except (json.JSONDecodeError, ValueError, TypeError):
            # Single number
            expected_format = "single"
            try:
                expected_counts = int(expected_answer)
            except (ValueError, TypeError):
                # Can't parse answer, skip this item
                self.logger.warning(f"Could not parse eval answer: {expected_answer}")
                return (0, difficulty_level)

        # Create messages and get completion
        messages = [{"role": "user", "content": prompt}]

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            completion = await managed.chat_completion(
                messages=messages,
                n=1,
                max_tokens=self.config.max_generation_tokens,
                temperature=self.config.eval_temperature,
                split="eval",
                # Stop sequences are strings, so pass the EOS token text
                stop=[self.tokenizer.eos_token],
            )

        model_response = self._reconstruct_message_with_thinking(completion.choices[0])
        model_answer = self._extract_answer(model_response, expected_format)

        # Score based on format
        if model_answer is None:
            score = 0
        elif expected_format == "single":
            score = 1 if model_answer == expected_counts else 0
        else:
            # Multi-letter: compare dicts
            score = 1 if model_answer == expected_counts else 0

        return (score, difficulty_level)

    async def evaluate(self, *args, **kwargs):
        """
        Evaluate the model on the static HuggingFace eval dataset.

        Tracks and logs accuracy for each difficulty level separately.
        """
        self.logger.info("Starting evaluation...")

        if not self.eval_dataset:
            self.logger.warning("No eval dataset available. Skipping evaluation.")
            self.eval_metrics.append(("eval/percent_correct", 0.0))
            return

        self.logger.info(
            f"Evaluating on {len(self.eval_dataset)} items from {self.EVAL_DATASET_HF}..."
        )

        eval_tasks = [
            self.rollout_and_score_eval(eval_item) for eval_item in self.eval_dataset
        ]
        results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating")

        if not results:
            self.eval_metrics.append(("eval/percent_correct", 0.0))
            self.logger.warning("No evaluation results returned.")
            return

        # Calculate overall and per-difficulty scores
        all_scores = [r[0] for r in results]
        level_scores = {}

        for score, difficulty_level in results:
            if difficulty_level not in level_scores:
                level_scores[difficulty_level] = []
            level_scores[difficulty_level].append(score)

        # Overall accuracy
        percent_correct = sum(all_scores) / len(all_scores) if all_scores else 0.0
        self.eval_metrics.append(("eval/percent_correct", percent_correct))

        # Per-difficulty accuracy
        level_accuracies = {}
        for level in sorted(level_scores.keys()):
            if level_scores[level]:
                level_acc = sum(level_scores[level]) / len(level_scores[level])
                level_accuracies[level] = level_acc
                self.eval_metrics.append((f"eval/percent_correct_L{level}", level_acc))

        level_summary = " | ".join(
            f"L{lvl}: {level_accuracies.get(lvl, 0):.0%}"
            for lvl in sorted(level_accuracies.keys())
        )
        self.logger.info(
            f"Evaluation finished. Overall: {percent_correct:.1%} | {level_summary}"
        )

    async def add_rollouts_for_wandb(
        self,
        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
        item: Item = None,
    ):
        """Add rollouts for WandB logging."""
        if item is None or scored_data is None or not scored_data.get("tokens"):
            return

        expected_counts = item[1]
        text = item[2]
        target_letters = item[3]

        num_keep = self.config.num_rollouts_per_group_for_logging
        if num_keep == -1:
            num_keep = self.config.group_size

        num_keep = min(num_keep, len(scored_data["tokens"]))
        if num_keep == 0:
            return

        group_scores = scored_data.get("scores", [])
        group_average_score = (
            sum(group_scores) / len(group_scores) if group_scores else 0.0
        )

        current_rollouts = []
        for i in range(num_keep):
            if i < len(scored_data["tokens"]) and i < len(scored_data["scores"]):
                full_text = self.tokenizer.decode(
                    scored_data["tokens"][i], skip_special_tokens=True
                )
                score_val = scored_data["scores"][i]
                expected_str = (
                    str(expected_counts)
                    if len(target_letters) > 1
                    else str(expected_counts[target_letters[0]])
                )
                letters_str = ", ".join(target_letters)
                current_rollouts.append(
                    (
                        full_text,
                        score_val,
                        expected_str,
                        text[:100],
                        letters_str,
                        group_average_score,
                    )
                )

        self.rollouts_for_wandb.append(current_rollouts)

        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)

    async def create_rollout_table(self, wandb_metrics):
        """Create a WandB table with rollout examples."""
        if len(self.rollouts_for_wandb) > 0:
            table = wandb.Table(
                columns=[
                    "full_text",
                    "score",
                    "expected_counts",
                    "text",
                    "target_letters",
                    "group_average_score",
                ]
            )
            for group in self.rollouts_for_wandb:
                for item in group:
                    if len(item) >= 6:
                        table.add_data(
                            item[0], item[1], item[2], item[3], item[4], item[5]
                        )
                    else:
                        table.add_data(item[0], item[1], item[2], item[3], item[4], 0.0)
            wandb_metrics["train/rollouts"] = table
        self.rollouts_for_wandb = []
        return wandb_metrics

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics to WandB."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Training accuracy
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            pass

        self.percent_correct_buffer = []

        # Eval metrics
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = []

        # Letter distribution stats
        if self.letter_distribution_stats:
            total_letters_asked = sum(self.letter_distribution_stats.values())
            wandb_metrics["stats/total_letters_asked"] = total_letters_asked

            most_common = max(
                self.letter_distribution_stats, key=self.letter_distribution_stats.get
            )
            least_common = min(
                self.letter_distribution_stats, key=self.letter_distribution_stats.get
            )
            wandb_metrics["stats/most_common_letter_count"] = (
                self.letter_distribution_stats[most_common]
            )
            wandb_metrics["stats/least_common_letter_count"] = (
                self.letter_distribution_stats[least_common]
            )

            entropy = -sum(
                (count / total_letters_asked) * math.log2(count / total_letters_asked)
                for count in self.letter_distribution_stats.values()
                if count > 0
            )
            wandb_metrics["stats/letter_distribution_entropy"] = entropy

        # Word length stats
        if self.word_length_stats:
            total_words = sum(self.word_length_stats.values())
            wandb_metrics["stats/total_words_asked"] = total_words

            avg_length = (
                sum(length * count for length, count in self.word_length_stats.items())
                / total_words
            )
            wandb_metrics["stats/avg_word_length"] = avg_length
            wandb_metrics["stats/min_word_length"] = min(self.word_length_stats.keys())
            wandb_metrics["stats/max_word_length"] = max(self.word_length_stats.keys())

        # Error rates
        if self.processed_item_count > 0:
            wandb_metrics["errors/answer_format_error_rate"] = (
                self.answer_format_errors
                / (self.processed_item_count * self.config.group_size)
            )
            wandb_metrics["errors/think_format_error_rate"] = (
                self.think_format_errors
                / (self.processed_item_count * self.config.group_size)
            )
            wandb_metrics["errors/total_format_errors"] = (
                self.answer_format_errors + self.think_format_errors
            )

        # Adaptive difficulty metrics
        wandb_metrics["curriculum/difficulty_level"] = self.current_difficulty_level
        if self.recent_scores:
            wandb_metrics["curriculum/recent_success_rate"] = sum(
                self.recent_scores
            ) / len(self.recent_scores)
            wandb_metrics["curriculum/samples_in_window"] = len(self.recent_scores)

        # Data dumping metrics
        if self.config.dump_rollouts:
            wandb_metrics["data_dumps/buffer_size"] = len(self.rollouts_to_save_buffer)
            wandb_metrics["data_dumps/batches_saved"] = self.save_file_batch_num

        # Rollout table
        wandb_metrics = await self.create_rollout_table(wandb_metrics)

        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    LetterCountingEnv.cli()