import json
import math
import random
import re
import time
from typing import Dict, List, Optional, Tuple, Union

import wandb
from datasets import load_dataset
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    EvalHandlingEnum,
    Item,
    ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Tags a thinking-mode response must use (exactly once each) to wrap its reasoning.
THINK_OPEN_TAG = "<think>"
THINK_CLOSE_TAG = "</think>"


class PairwiseJudgementConfig(BaseEnvConfig):
    """Configuration for PairwiseJudgementEnv with thinking mode and configurable options."""

    thinking_mode: bool = Field(
        default=False,
        description="Whether to enable thinking mode with <think></think> tags.",
    )
    num_choices: int = Field(
        default=4,
        ge=2,
        le=26,
        description="Number of choices for pairwise judgment (2-26, corresponding to A-Z).",
    )
    custom_thinking_prompt: Optional[str] = Field(
        default=None,
        description="Custom thinking prompt. If None, uses the default thinking prompt.",
    )
    custom_judgment_prompt: Optional[str] = Field(
        default=None,
        description="Custom judgment prompt. If None, uses the default judgment prompt.",
    )
    eval_temperature: float = Field(
        default=0.6,
        description="Temperature for evaluation completions.",
    )
    rollout_temperature: float = Field(
        default=0.8,
        description="Temperature for training rollout completions.",
    )
    eval_max_tokens: int = Field(
        default=1024 * 16,
        description="Maximum tokens for evaluation completions.",
    )
    train_max_tokens: int = Field(
        default=1024 * 16,
        description="Maximum tokens for training completions.",
    )


class PairwiseJudgementEnv(BaseEnv):
    """Environment that trains/evaluates a model as a multi-choice answer judge.

    Training items are synthetic (one correct answer mixed with wrong ones);
    evaluation uses allenai/reward-bench-2 (one chosen + N-1 rejected answers).
    The model must output its verdict as ``[[X]]`` where X is a choice letter.
    """

    name = "pairwise_judgement"
    env_config_cls = PairwiseJudgementConfig

    # Filler answer used when an example has too few incorrect answers.
    _PADDING_ANSWER = "I don't have enough information to answer this question."

    # Synthetic training examples; built once instead of on every get_next_item call.
    _SYNTHETIC_EXAMPLES = [
        {
            "question": "What is the capital of France?",
            "correct": "The capital of France is Paris, which has been the capital since 987 AD and serves as the political, economic, and cultural center of the country.",
            "incorrect": [
                "The capital of France is London.",
                "France's capital is Berlin, located in central Europe.",
                "I don't know the answer to this question.",
                "France doesn't have a capital city.",
                "The capital changes every year in France.",
                "Paris is just a city, not a capital.",
            ],
        },
        {
            "question": "How do you fix a memory leak in Python?",
            "correct": "To fix memory leaks in Python: 1) Use memory profilers like tracemalloc or memory_profiler to identify leaks, 2) Ensure proper cleanup of resources with context managers, 3) Break circular references, 4) Close files and database connections explicitly, and 5) Use weak references when appropriate.",
            "incorrect": [
                "Just restart your computer and the memory leak will be fixed.",
                "Python automatically handles all memory management, so memory leaks are impossible.",
                "You need to reinstall Python to fix memory leaks.",
                "Memory leaks don't exist in Python because it's interpreted.",
                "Use more RAM to solve memory leaks.",
                "Delete the Python installation and use a different language.",
            ],
        },
        {
            "question": "Explain the difference between machine learning and artificial intelligence.",
            "correct": "Artificial Intelligence (AI) is the broader field focused on creating systems that can perform tasks typically requiring human intelligence. Machine Learning (ML) is a subset of AI that uses algorithms to learn patterns from data without being explicitly programmed for each task. So ML is one approach to achieving AI.",
            "incorrect": [
                "Machine learning and artificial intelligence are exactly the same thing with different names.",
                "Machine learning is much broader than AI and includes all computer science.",
                "AI is only about robots, while machine learning is only about statistics.",
                "Machine learning came before AI historically.",
                "AI is a subset of machine learning, not the other way around.",
                "There is no difference; they are marketing terms for the same technology.",
            ],
        },
    ]

    def __init__(
        self,
        config: PairwiseJudgementConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config: PairwiseJudgementConfig = config
        self.percent_correct_buffer = []
        self.eval_metrics = []

        # Choice letters for the configured number of answers (A, B, C, ... up to Z).
        self.choice_letters = [chr(65 + i) for i in range(self.config.num_choices)]

        # Per-letter training metrics; cleared by _reset_metrics() after each wandb_log.
        self.judgment_letter_counts = {letter: 0 for letter in self.choice_letters}
        self.judgment_letter_correct = {letter: 0 for letter in self.choice_letters}
        self.error_count = 0  # Responses that failed to follow the required format
        self.total_judgments = 0
        self.rollouts_for_wandb = []

        # Pre-compiled think-tag patterns.
        self._think_pattern = re.compile(re.escape(THINK_OPEN_TAG))
        self._think_close_pattern = re.compile(re.escape(THINK_CLOSE_TAG))
        # Captures everything after the closing think tag (the visible answer).
        self._think_content_pattern = re.compile(
            re.escape(THINK_CLOSE_TAG) + r"\s*(.*)", re.DOTALL
        )
        self._thinking_extract_pattern = re.compile(
            re.escape(THINK_OPEN_TAG) + r"(.*?)" + re.escape(THINK_CLOSE_TAG),
            re.DOTALL,
        )
        self._question_pattern = re.compile(
            r"\[User Question\]\s*(.*?)\s*\[The Start of Assistant A", re.DOTALL
        )

        # Per-letter verdict patterns (kept for back-compat with subclasses) and a
        # single pattern that finds every valid verdict marker in one pass.
        self._choice_patterns = {
            letter: re.compile(rf"\[\[{letter}\]\]") for letter in self.choice_letters
        }
        self._verdict_pattern = re.compile(
            r"\[\[([" + "".join(self.choice_letters) + r"])\]\]"
        )
        self._answer_choice_patterns = {
            letter: re.compile(
                rf"\[The Start of Assistant {letter}\'s Answer\]\s*(.*?)\s*\[The End of Assistant {letter}\'s Answer\]",
                re.DOTALL,
            )
            for letter in self.choice_letters
        }

        # System prompts (custom ones if provided, otherwise defaults).
        self.thinking_system_prompt = self._get_thinking_prompt()
        self.judgment_system_prompt = self._get_judgment_prompt()

    def _get_thinking_prompt(self) -> str:
        """Return the thinking system prompt (custom if configured, else the default)."""
        if self.config.custom_thinking_prompt:
            return self.config.custom_thinking_prompt
        return (
            "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
            "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
            "solution prior to answering. You should enclose your thoughts and internal monologue inside "
            "<think> </think> tags, and then provide your solution or response to the problem."
        )

    def _get_judgment_prompt(self) -> str:
        """Return the judgment system prompt (custom if configured, else the default)."""
        if self.config.custom_judgment_prompt:
            return self.config.custom_judgment_prompt

        choice_format_examples = ", ".join(
            [
                f'"[[{letter}]]" if assistant {letter} is best'
                for letter in self.choice_letters
            ]
        )

        return (
            f"Please act as an impartial judge and evaluate the quality of the responses provided by {self.config.num_choices} AI assistants to the user question displayed below. "
            f"You should choose the assistant that follows the user's instructions and answers the user's question best. "
            f"Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. "
            f"Begin your evaluation by comparing the {self.config.num_choices} responses and provide a short explanation. "
            f"Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. "
            f"Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. "
            f"Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: "
            f"{choice_format_examples}."
        )

    def _reset_metrics(self) -> None:
        """Reset training metrics."""
        self.percent_correct_buffer = []
        self.judgment_letter_counts = {letter: 0 for letter in self.choice_letters}
        self.judgment_letter_correct = {letter: 0 for letter in self.choice_letters}
        self.error_count = 0
        self.total_judgments = 0

    def _convert_messages_to_list(self, prompt_tuple: Tuple) -> List[Dict]:
        """Convert the frozenset-of-items message format back to a list of dicts."""
        return [dict(role_dict) for role_dict in prompt_tuple]

    def _create_system_content(self) -> str:
        """Create system message content based on thinking mode."""
        if self.config.thinking_mode:
            return f"{self.thinking_system_prompt}\n\n{self.judgment_system_prompt}"
        return self.judgment_system_prompt

    def _prepare_completion_input(self, prompt_tuple: Tuple) -> Tuple[List[Dict], str]:
        """Convert a prompt tuple to (messages, chat-templated prompt text)."""
        messages = self._convert_messages_to_list(prompt_tuple)
        prompt_text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        return messages, prompt_text

    def _get_train_completion_params(self) -> Dict:
        """Get completion parameters for training rollouts."""
        return {
            "n": self.config.group_size,
            "max_tokens": self.config.train_max_tokens,
            "temperature": self.config.rollout_temperature,
        }

    def _get_eval_completion_params(self) -> Dict:
        """Get completion parameters for evaluation."""
        return {
            "n": 1,
            "max_tokens": self.config.eval_max_tokens,
            "temperature": self.config.eval_temperature,
            "split": "eval",
        }

    @classmethod
    def config_init(cls) -> Tuple[PairwiseJudgementConfig, List[APIServerConfig]]:
        env_config = PairwiseJudgementConfig(
            tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
            group_size=16,
            use_wandb=True,
            max_num_workers_per_node=16,
            rollout_server_url="http://localhost:8000",
            total_steps=2000,
            batch_size=1024,
            steps_per_eval=25,
            max_token_length=1024 * 16,
            inference_weight=1.0,
            wandb_name="pairwise_judgment",
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            eval_limit_ratio=0.1,
            min_batch_allocation=0.1,
            thinking_mode=True,
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_max_requests_at_once=32,
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def setup(self) -> None:
        """Set up the environment by loading datasets."""
        # Placeholder train dataset; get_next_item() generates synthetic examples
        # instead, so this is never actually sampled from.
        try:
            self.train = load_dataset("example/train", split="train")
            print(f"Loaded placeholder train dataset with {len(self.train)} examples")
        except Exception as e:
            self.train = [{"question": "What is 2+2?", "answer": "4"}] * 100
            print(f"Using synthetic placeholder training data due to error: {e}")

        # Evaluation dataset - reward-bench-2 (must load, otherwise crash).
        self.test = load_dataset(
            "allenai/reward-bench-2", split="test", trust_remote_code=True
        )
        print(f"Loaded reward-bench-2 eval dataset with {len(self.test)} examples")

        if len(self.test) > 0:
            try:
                self._print_eval_sample_structure(self.test[0])
            except Exception as e:
                print(f"Warning: Could not display sample item structure: {e}")

        self.iter = 0

    def _print_eval_sample_structure(self, sample_item: dict) -> None:
        """Print a debug summary of one evaluation item's keys and contents."""
        print("\nSample eval item structure:")
        print(f"- Available keys: {list(sample_item.keys())}")

        # Handle different dataset structures.
        if "prompt" in sample_item:
            print(f"- Prompt: {sample_item['prompt'][:100]}...")
        elif "chosen" in sample_item and isinstance(sample_item["chosen"], str):
            print(f"- Chosen (string): {sample_item['chosen'][:100]}...")
        elif "rejected" in sample_item and isinstance(sample_item["rejected"], str):
            print(f"- Rejected (string): {sample_item['rejected'][:100]}...")

        for key, label in (("chosen", "Chosen"), ("rejected", "Rejected")):
            if key not in sample_item:
                continue
            value = sample_item[key]
            if isinstance(value, list):
                print(f"- {label} responses: {len(value)}")
                if value:
                    print(f"- First {key} (truncated): {value[0][:200]}...")
            else:
                print(f"- {label} (string): {value[:200]}...")

    def save_checkpoint(self, step: int, data: Optional[Dict] = None) -> None:
        """Save checkpoint including iteration state."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    def process_judgement(self, judgment: str, track_metrics: bool = True) -> str:
        """Extract the verdict letter from a model response.

        In thinking mode the response must contain exactly one opening and one
        closing think tag, and only the text after the closing tag is searched.
        The verdict is the LAST ``[[X]]`` marker naming a valid choice letter:
        the judgment prompt asks for the final verdict after the explanation,
        and scanning letters alphabetically would favour earlier letters
        whenever several markers appear.

        Args:
            judgment: Raw model response text.
            track_metrics: When True, update the training counters
                (``total_judgments``, ``error_count``, ``judgment_letter_counts``).

        Returns:
            The chosen letter (e.g. ``"A"``), or ``"format_error"`` when the
            response does not follow the required format.
        """
        if self.config.thinking_mode:
            think_open_count = len(self._think_pattern.findall(judgment))
            think_close_count = len(self._think_close_pattern.findall(judgment))
            match = self._think_content_pattern.search(judgment)
            if think_open_count != 1 or think_close_count != 1 or match is None:
                if track_metrics:
                    self.error_count += 1
                    self.total_judgments += 1
                return "format_error"
            # Parse only the content after the closing think tag.
            judgment = match.group(1)

        if track_metrics:
            self.total_judgments += 1

        verdicts = self._verdict_pattern.findall(judgment)
        if not verdicts:
            if track_metrics:
                self.error_count += 1
            return "format_error"

        letter = verdicts[-1]
        if track_metrics:
            self.judgment_letter_counts[letter] += 1
        return letter

    def create_judgment_prompt(self, question: str, answers: List[str]) -> str:
        """Create the user prompt for the judgment task.

        Raises:
            ValueError: If ``answers`` does not contain exactly ``num_choices`` entries.
        """
        if len(answers) != self.config.num_choices:
            raise ValueError(
                f"Need exactly {self.config.num_choices} answers for judgment, got {len(answers)}"
            )

        sections = [f"[User Question]\n{question}"]
        for letter, answer in zip(self.choice_letters, answers):
            sections.append(
                f"[The Start of Assistant {letter}'s Answer]\n{answer}\n[The End of Assistant {letter}'s Answer]"
            )
        return "\n\n".join(sections)

    async def get_next_item(self) -> Item:
        """Generate the next training item from the synthetic examples.

        Returns:
            ``(prompt, correct_letter)`` where ``prompt`` is a tuple of
            frozenset-encoded system and user messages.
        """
        self.iter += 1

        system_content = self._create_system_content()
        example = random.choice(self._SYNTHETIC_EXAMPLES)

        # One correct answer plus up to num_choices-1 incorrect ones, padded if short.
        incorrect_answers = example["incorrect"][: self.config.num_choices - 1]
        all_answers = [example["correct"]] + incorrect_answers
        while len(all_answers) < self.config.num_choices:
            all_answers.append(self._PADDING_ANSWER)

        random.shuffle(all_answers)

        # Locate the correct answer after shuffling.
        correct_index = all_answers.index(example["correct"])
        correct_answer = self.choice_letters[correct_index]

        user_content = self.create_judgment_prompt(example["question"], all_answers)

        prompt = tuple(
            [
                frozenset({"role": "system", "content": system_content}.items()),
                frozenset({"role": "user", "content": user_content}.items()),
            ]
        )

        return (prompt, correct_answer)

    def prepare_eval_item(self, item: dict) -> Tuple[Optional[Tuple], Optional[str]]:
        """
        Prepare an evaluation item from the reward-bench-2 dataset.

        Dataset structure:
        - chosen: list with 1 element (the best response)
        - rejected: list with 3+ elements (worse responses)
        - We take chosen[0] + rejected[:num_choices-1] to create judgment with configured number of choices

        Returns:
            ``(prompt, correct_letter)``, or ``(None, None)`` when the item is
            missing data or cannot be prepared.
        """
        try:
            question = item.get("prompt", "")
            chosen_responses = item.get("chosen", [])
            rejected_responses = item.get("rejected", [])

            if not question:
                return None, None

            # One chosen response and (num_choices-1) rejected responses.
            required_rejected = self.config.num_choices - 1
            if (
                len(chosen_responses) == 0
                or len(rejected_responses) < required_rejected
            ):
                return None, None

            chosen = chosen_responses[0]
            rejected = rejected_responses[:required_rejected]

            if not chosen or not all(rejected):
                return None, None

            # Pair each answer with whether it is the correct one, then shuffle.
            data = [(chosen, True)] + [(r, False) for r in rejected]
            random.shuffle(data)

            shuffled_answers = [answer for answer, _ in data]
            correct_index = next(
                i for i, (_, is_correct) in enumerate(data) if is_correct
            )
            correct_answer = self.choice_letters[correct_index]

            system_content = self._create_system_content()
            user_content = self.create_judgment_prompt(question, shuffled_answers)

            prompt = tuple(
                [
                    frozenset({"role": "system", "content": system_content}.items()),
                    frozenset({"role": "user", "content": user_content}.items()),
                ]
            )

            return prompt, correct_answer

        except Exception as e:
            print(f"Error preparing evaluation item: {e}")
            return None, None

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
        """Sample ``group_size`` completions for one item and score them."""
        messages, prompt_text = self._prepare_completion_input(item[0])
        completion_params = self._get_train_completion_params()

        completions = await self.server.completion(
            prompt=prompt_text, **completion_params
        )

        to_score = [
            (
                tuple(messages + [{"role": "assistant", "content": choice.text}]),
                item[1],
            )
            for choice in completions.choices
        ]

        scored_data = await self.score(to_score)

        if scored_data is not None:
            await self.add_rollouts_for_wandb(scored_data, item)

        return scored_data, []

    async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]:
        """Score a group of rollouts against their ground-truth letter.

        Each rollout is ``(messages, ground_truth_letter)``; the reward is 1.0
        when the parsed verdict matches the ground truth and 0.0 otherwise.
        Rollouts with fewer than 10 trainable (unmasked) tokens are dropped
        before judging, so the per-letter training metrics only count rollouts
        that are actually returned. ``rollout_group_data`` is shuffled in place.

        Returns:
            The scored group, or None when nothing usable remains, when all
            scores are identical (no learning signal), or on an internal error.
        """
        if not rollout_group_data:
            return None

        try:
            scores = ScoredDataGroup()
            scores["tokens"] = []
            scores["masks"] = []
            scores["scores"] = []

            random.shuffle(rollout_group_data)

            for item in rollout_group_data:
                if not item or len(item) < 2 or not item[0]:
                    continue

                out_dict = tokenize_for_trainer(self.tokenizer, item[0])
                tokens = out_dict["tokens"]
                masks = out_dict["masks"]

                # Skip obviously bad examples (almost nothing to train on).
                if sum(1 for mask in masks if mask != -100) < 10:
                    continue

                model_response = item[0][-1]["content"]
                ground_truth = item[1]
                predicted_answer = self.process_judgement(
                    model_response, track_metrics=True
                )
                reward = 1.0 if predicted_answer == ground_truth else 0.0

                # Track correct judgments per letter (format errors are never correct).
                if (
                    predicted_answer == ground_truth
                    and predicted_answer in self.judgment_letter_correct
                ):
                    self.judgment_letter_correct[predicted_answer] += 1

                scores["tokens"].append(tokens)
                scores["masks"].append(masks)
                scores["scores"].append(reward)

                if len(scores["tokens"]) >= self.config.group_size:
                    break

            if not scores["tokens"]:
                return None

            self.percent_correct_buffer.extend(
                max(score, 0) for score in scores["scores"]
            )

            # No learning signal when every rollout got the same reward.
            if len(set(scores["scores"])) == 1:
                return None

            return scores

        except Exception as e:
            print(f"Error in score method: {e}")
            return None

    async def rollout_and_score_eval(self, test_item: dict) -> dict:
        """Roll out one eval item and return its score plus a detailed sample record.

        Returns:
            ``{"score": float, "sample": dict | None}``; ``sample`` is None when
            the item could not be prepared or the rollout failed.
        """
        try:
            prompt, ground_truth = self.prepare_eval_item(test_item)
            if prompt is None:
                return {"score": 0.0, "sample": None}

            messages, prompt_text = self._prepare_completion_input(prompt)
            completion_params = self._get_eval_completion_params()

            completion = await self.server.completion(
                prompt=prompt_text, **completion_params
            )

            if not completion.choices:
                return {"score": 0.0, "sample": None}

            model_response = completion.choices[0].text
            predicted_answer = self.process_judgement(
                model_response, track_metrics=False
            )
            score = 1.0 if predicted_answer == ground_truth else 0.0

            # Recover the question and answer choices from the user message.
            user_content = messages[1]["content"]
            question_match = self._question_pattern.search(user_content)
            question = (
                question_match.group(1).strip() if question_match else "Unknown question"
            )

            answer_choices = {}
            for letter in self.choice_letters:
                match = self._answer_choice_patterns[letter].search(user_content)
                if match:
                    answer_choices[letter] = match.group(1).strip()

            full_messages = messages + [
                {"role": "assistant", "content": model_response}
            ]

            sample = {
                "messages": full_messages,
                "question": question,
                "answer_choices": answer_choices,
                "ground_truth": ground_truth,
                "predicted_judgment": predicted_answer,
                "score": int(score),
                "correct": bool(score),
                "finish_reason": completion.choices[0].finish_reason,
                "thinking_mode": self.config.thinking_mode,
                "format_compliant": predicted_answer != "format_error",
                "dataset_item_id": test_item.get("id", "unknown"),
                "dataset_subset": test_item.get("subset", "unknown"),
                "num_choices": self.config.num_choices,
            }

            if self.config.thinking_mode:
                if THINK_CLOSE_TAG in model_response:
                    sample["response_after_think"] = model_response.split(
                        THINK_CLOSE_TAG
                    )[-1].strip()
                    thinking_match = self._thinking_extract_pattern.search(
                        model_response
                    )
                    sample["thinking_content"] = (
                        thinking_match.group(1).strip() if thinking_match else None
                    )
                else:
                    sample["response_after_think"] = model_response
                    sample["thinking_content"] = None

            return {"score": score, "sample": sample}

        except Exception as e:
            print(f"Error in rollout_and_score_eval: {e}")
            return {"score": 0.0, "sample": None}

    def _calculate_response_metrics(
        self, samples: List[dict], thinking_mode_used: bool
    ) -> Tuple[List[int], int, Dict[str, int]]:
        """Compute response lengths, thinking usage count and verdict distribution."""
        response_lengths = []
        thinking_utilization = 0
        judgment_counts = {letter: 0 for letter in self.choice_letters}
        judgment_counts["format_error"] = 0

        for sample in samples:
            if not sample:
                continue

            messages = sample.get("messages", [])
            if messages:
                assistant_msg = messages[-1].get("content", "")
                response_lengths.append(len(assistant_msg))

            if thinking_mode_used and sample.get("thinking_content"):
                thinking_utilization += 1

            predicted_judgment = sample.get("predicted_judgment", "format_error")
            if predicted_judgment in judgment_counts:
                judgment_counts[predicted_judgment] += 1

        return response_lengths, thinking_utilization, judgment_counts

    async def evaluate(self, *args, **kwargs) -> None:
        """Evaluate the model on the test dataset and log metrics.

        Metrics are queued in ``self.eval_metrics`` (flushed by ``wandb_log``)
        and also passed to ``evaluate_log`` together with the per-sample records.
        """
        start_time = time.time()

        try:
            eval_tasks = [
                self.rollout_and_score_eval(test_item) for test_item in self.test
            ]
            results = await tqdm_asyncio.gather(*eval_tasks)
        except Exception as e:
            print(f"Error during evaluation: {e}")
            return

        valid_results = [
            result
            for result in results
            if not isinstance(result, Exception)
            and result
            and result.get("sample") is not None
        ]
        if not valid_results:
            print("Warning: No valid evaluation results obtained")
            return

        scores = [result["score"] for result in valid_results]
        samples = [result["sample"] for result in valid_results]
        valid_scores = [s for s in scores if s is not None]

        if not valid_scores:
            print("Warning: No valid scores found during evaluation")
            return

        percent_correct = sum(valid_scores) / len(valid_scores)

        # Group scores by dataset subset.
        subset_scores: Dict[str, List[float]] = {}
        for sample, score in zip(samples, scores):
            subset = sample.get("dataset_subset", "unknown")
            subset_scores.setdefault(subset, []).append(score)

        format_compliant = sum(
            1 for sample in samples if sample.get("format_compliant", False)
        )
        # samples is non-empty here (valid_results was non-empty).
        format_compliance_rate = format_compliant / len(samples)
        thinking_mode_used = self.config.thinking_mode

        response_lengths, thinking_utilization, judgment_counts = (
            self._calculate_response_metrics(samples, thinking_mode_used)
        )

        # Metrics reported both via self.eval_metrics and evaluate_log.
        shared_metrics: Dict[str, float] = {"eval/percent_correct": percent_correct}

        for subset, subset_score_list in subset_scores.items():
            valid_subset_scores = [s for s in subset_score_list if s is not None]
            if valid_subset_scores:
                shared_metrics[f"eval/percent_correct_{subset}"] = sum(
                    valid_subset_scores
                ) / len(valid_subset_scores)

        if response_lengths:
            avg_response_length = sum(response_lengths) / len(response_lengths)
            response_length_std = (
                sum((x - avg_response_length) ** 2 for x in response_lengths)
                / len(response_lengths)
            ) ** 0.5
            shared_metrics["eval/avg_response_length"] = avg_response_length
            shared_metrics["eval/response_length_std"] = response_length_std

        if thinking_mode_used:
            shared_metrics["eval/thinking_utilization_rate"] = (
                thinking_utilization / len(samples)
            )

        total_judgments = sum(judgment_counts.values())
        if total_judgments > 0:
            # Entropy of the verdict distribution (higher = more balanced).
            entropy = 0.0
            for count in judgment_counts.values():
                if count > 0:
                    freq = count / total_judgments
                    entropy -= freq * math.log(freq)
            shared_metrics["eval/judgment_entropy"] = entropy
            # Frequency of the most common verdict (position-bias detection).
            shared_metrics["eval/most_common_judgment_freq"] = (
                max(judgment_counts.values()) / total_judgments
            )
            shared_metrics["eval/format_error_rate"] = (
                judgment_counts["format_error"] / total_judgments
            )

        shared_metrics["eval/format_compliance_rate"] = format_compliance_rate

        self.eval_metrics.extend(shared_metrics.items())
        self.eval_metrics.extend(
            [
                ("eval/total_items", len(self.test)),
                ("eval/valid_scores", len(valid_scores)),
                ("eval/subset_count", len(subset_scores)),
            ]
        )

        end_time = time.time()

        eval_metrics = {
            "eval/total_samples": len(samples),
            "eval/correct_samples": sum(valid_scores),
            **shared_metrics,
        }

        try:
            await self.evaluate_log(
                metrics=eval_metrics,
                samples=samples,
                start_time=start_time,
                end_time=end_time,
                generation_parameters={
                    "temperature": self.config.eval_temperature,
                    "max_tokens": self.config.eval_max_tokens,
                    "thinking_mode": thinking_mode_used,
                },
            )
        except Exception as e:
            print(f"Error logging evaluation results: {e}")

    def _response_from_decoded_text(self, full_text: str) -> str:
        """Best-effort recovery of the assistant reply from a decoded trajectory.

        The decoded text also contains the system prompt, which quotes the
        ``[[X]]`` verdict format (and, in thinking mode, the think tags), so
        parsing it whole would mis-report the model's verdict. Everything after
        the final answer block's end marker is treated as the reply; if the
        marker is absent the whole text is returned.
        """
        end_marker = f"[The End of Assistant {self.choice_letters[-1]}'s Answer]"
        idx = full_text.find(end_marker)
        if idx == -1:
            return full_text
        return full_text[idx + len(end_marker) :]

    async def add_rollouts_for_wandb(
        self,
        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
        item: Item = None,
    ) -> None:
        """Queue a group's rollouts for the wandb rollout table."""
        if item is None or scored_data is None or not scored_data.get("tokens"):
            return

        ground_truth = item[1]

        # Extract the question text from the item's user message.
        question_info = "unknown_question"
        try:
            for role_dict in item[0]:
                message = dict(role_dict)
                if message.get("role") == "user":
                    user_content = message.get("content", "")
                    question_match = self._question_pattern.search(user_content)
                    if question_match:
                        question_info = question_match.group(1).strip()
                    break
        except Exception:
            question_info = "extraction_failed"

        num_keep = self.config.num_rollouts_per_group_for_logging
        if num_keep == -1:
            num_keep = self.config.group_size
        num_keep = min(num_keep, len(scored_data["tokens"]))

        mode = "thinking" if self.config.thinking_mode else "direct"
        current_rollouts = []

        for i in range(num_keep):
            full_text = self.tokenizer.decode(
                scored_data["tokens"][i], skip_special_tokens=True
            )
            score_val = scored_data["scores"][i]

            try:
                messages = scored_data.get("messages") or []
                if i < len(messages) and isinstance(messages[i], list) and messages[i]:
                    model_response = messages[i][-1].get("content", "")
                else:
                    # Parse only the model's reply, not the prompt, from the decoded text.
                    model_response = self._response_from_decoded_text(full_text)
                predicted_judgment = self.process_judgement(
                    model_response, track_metrics=False
                )
            except Exception:
                predicted_judgment = "parse_error"

            current_rollouts.append(
                (
                    full_text,
                    score_val,
                    ground_truth,
                    predicted_judgment,
                    question_info,
                    mode,
                )
            )

        self.rollouts_for_wandb.append(current_rollouts)

        # Keep only the most recent groups.
        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)

    async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
        """Add a wandb table of queued rollouts to ``wandb_metrics`` and clear the queue."""
        if not self.rollouts_for_wandb:
            return wandb_metrics

        table = wandb.Table(
            columns=[
                "full_text",
                "score",
                "ground_truth",
                "predicted_judgment",
                "question_info",
                "mode",
            ]
        )

        for group_rollouts in self.rollouts_for_wandb:
            for rollout_tuple in group_rollouts:
                if len(rollout_tuple) == 6:
                    table.add_data(*rollout_tuple)

        wandb_metrics["train/rollouts"] = table
        self.rollouts_for_wandb = []
        return wandb_metrics

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log training, evaluation and config metrics to wandb, then reset counters."""
        if wandb_metrics is None:
            wandb_metrics = {}

        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)

        # Verdict letter distribution and per-letter accuracy.
        total_letters = sum(self.judgment_letter_counts.values())
        if total_letters > 0:
            entropy = 0.0
            for letter in self.choice_letters:
                letter_count = self.judgment_letter_counts[letter]
                letter_correct = self.judgment_letter_correct[letter]

                freq = letter_count / total_letters
                wandb_metrics[f"train/judgment_freq_{letter}"] = freq
                wandb_metrics[f"train/judgment_acc_{letter}"] = (
                    letter_correct / letter_count if letter_count > 0 else 0.0
                )

                if freq > 0:
                    entropy -= freq * math.log(freq)

            wandb_metrics["train/judgment_entropy"] = entropy
            # Normalized entropy in [0, 1]; num_choices >= 2 so log() > 0.
            wandb_metrics["train/judgment_balance"] = entropy / math.log(
                self.config.num_choices
            )

        if self.total_judgments > 0:
            wandb_metrics["train/error_rate"] = self.error_count / self.total_judgments
            wandb_metrics["train/format_compliance_rate"] = 1.0 - (
                self.error_count / self.total_judgments
            )

        wandb_metrics.update(
            {
                "train/thinking_mode_enabled": (
                    1.0 if self.config.thinking_mode else 0.0
                ),
                "train/total_judgments": self.total_judgments,
                "config/group_size": self.config.group_size,
                "config/max_token_length": self.config.max_token_length,
                "config/num_choices": self.config.num_choices,
            }
        )

        self._reset_metrics()

        for metric_name, metric_value in self.eval_metrics:
            wandb_metrics[metric_name] = metric_value
        self.eval_metrics = []

        wandb_metrics = await self.create_rollout_table(wandb_metrics)

        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    PairwiseJudgementEnv.cli()