diff --git a/environments/pairwise_judgement_environment.py b/environments/pairwise_judgement_environment.py index 97600964..82c87311 100644 --- a/environments/pairwise_judgement_environment.py +++ b/environments/pairwise_judgement_environment.py @@ -23,44 +23,44 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer class PairwiseJudgementConfig(BaseEnvConfig): """Configuration for PairwiseJudgementEnv with thinking mode and configurable options.""" - + thinking_mode: bool = Field( default=False, description="Whether to enable thinking mode with <think> tags.", ) - + num_choices: int = Field( default=4, ge=2, le=26, description="Number of choices for pairwise judgment (2-26, corresponding to A-Z).", ) - + custom_thinking_prompt: Optional[str] = Field( default=None, description="Custom thinking prompt. If None, uses the default thinking prompt.", ) - + custom_judgment_prompt: Optional[str] = Field( default=None, description="Custom judgment prompt. If None, uses the default judgment prompt.", ) - + eval_temperature: float = Field( default=0.6, description="Temperature for evaluation completions.", ) - + rollout_temperature: float = Field( default=0.8, description="Temperature for training rollout completions.", ) - + eval_max_tokens: int = Field( default=1024 * 16, description="Maximum tokens for evaluation completions.", ) - + train_max_tokens: int = Field( default=1024 * 16, description="Maximum tokens for training completions.", @@ -70,7 +70,7 @@ class PairwiseJudgementConfig(BaseEnvConfig): class PairwiseJudgementEnv(BaseEnv): name = "pairwise_judgement" env_config_cls = PairwiseJudgementConfig - + def __init__( self, config: PairwiseJudgementConfig, @@ -82,31 +82,38 @@ class PairwiseJudgementEnv(BaseEnv): self.config: PairwiseJudgementConfig = config self.percent_correct_buffer = [] self.eval_metrics = [] - + # Generate choice letters based on num_choices (A, B, C, D... 
up to Z) self.choice_letters = [chr(65 + i) for i in range(self.config.num_choices)] - + # Initialize detailed metrics tracking for all choice letters self.judgment_letter_counts = {letter: 0 for letter in self.choice_letters} self.judgment_letter_correct = {letter: 0 for letter in self.choice_letters} self.error_count = 0 # Failed to follow format self.total_judgments = 0 self.rollouts_for_wandb = [] - + # Pre-compile regex patterns for performance - self._think_pattern = re.compile(r'<think>') - self._think_close_pattern = re.compile(r'</think>') - self._think_content_pattern = re.compile(r'</think>\s*(.*)', re.DOTALL) - self._question_pattern = re.compile(r'\[User Question\]\s*(.*?)\s*\[The Start of Assistant A', re.DOTALL) - self._thinking_extract_pattern = re.compile(r'<think>(.*?)</think>', re.DOTALL) - + self._think_pattern = re.compile(r"<think>") + self._think_close_pattern = re.compile(r"</think>") + self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL) + self._question_pattern = re.compile( + r"\[User Question\]\s*(.*?)\s*\[The Start of Assistant A", re.DOTALL + ) + self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL) + # Pre-compile choice patterns for each letter - self._choice_patterns = {letter: re.compile(rf'\[\[{letter}\]\]') for letter in self.choice_letters} + self._choice_patterns = { + letter: re.compile(rf"\[\[{letter}\]\]") for letter in self.choice_letters + } self._answer_choice_patterns = { - letter: re.compile(rf'\[The Start of Assistant {letter}\'s Answer\]\s*(.*?)\s*\[The End of Assistant {letter}\'s Answer\]', re.DOTALL) + letter: re.compile( + rf"\[The Start of Assistant {letter}\'s Answer\]\s*(.*?)\s*\[The End of Assistant {letter}\'s Answer\]", + re.DOTALL, + ) for letter in self.choice_letters } - + # System prompts (use custom ones if provided, otherwise defaults) self.thinking_system_prompt = self._get_thinking_prompt() self.judgment_system_prompt = self._get_judgment_prompt() @@ -114,20 +121,26 @@ class PairwiseJudgementEnv(BaseEnv): def _get_thinking_prompt(self) -> str: """Get thinking system prompt.""" return ( - self.config.custom_thinking_prompt if self.config.custom_thinking_prompt else - "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " + self.config.custom_thinking_prompt + if self.config.custom_thinking_prompt + else "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " "problem and deliberate with yourself via systematic reasoning processes to help come to a correct " "solution prior to answering. You should enclose your thoughts and internal monologue inside " "<think> </think> tags, and then provide your solution or response to the problem." ) - + def _get_judgment_prompt(self) -> str: """Get judgment system prompt.""" if self.config.custom_judgment_prompt: return self.config.custom_judgment_prompt - - choice_format_examples = ", ".join([f'"[[{letter}]]" if assistant {letter} is best' for letter in self.choice_letters]) - + + choice_format_examples = ", ".join( + [ + f'"[[{letter}]]" if assistant {letter} is best' + for letter in self.choice_letters + ] + ) + return ( f"Please act as an impartial judge and evaluate the quality of the responses provided by {self.config.num_choices} AI assistants to the user question displayed below. " f"You should choose the assistant that follows the user's instructions and answers the user's question best. 
" @@ -136,7 +149,7 @@ class PairwiseJudgementEnv(BaseEnv): f"Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. " f"Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. " f"Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: " - f'{choice_format_examples}.' + f"{choice_format_examples}." ) def _reset_metrics(self) -> None: @@ -226,45 +239,55 @@ class PairwiseJudgementEnv(BaseEnv): # Note: This isn't actually used since get_next_item() generates synthetic examples self.train = [{"question": "What is 2+2?", "answer": "4"}] * 100 print(f"Using synthetic placeholder training data due to error: {e}") - + # Load evaluation dataset - reward-bench-2 (MUST WORK OR CRASH) - self.test = load_dataset("allenai/reward-bench-2", split="test", trust_remote_code=True) + self.test = load_dataset( + "allenai/reward-bench-2", split="test", trust_remote_code=True + ) print(f"Loaded reward-bench-2 eval dataset with {len(self.test)} examples") - + # Debug: Show sample evaluation item structure if len(self.test) > 0: try: sample_item = self.test[0] print(f"\nSample eval item structure:") print(f"- Available keys: {list(sample_item.keys())}") - + # Handle different dataset structures - if 'prompt' in sample_item: + if "prompt" in sample_item: print(f"- Prompt: {sample_item['prompt'][:100]}...") - elif 'chosen' in sample_item and isinstance(sample_item['chosen'], str): + elif "chosen" in sample_item and isinstance(sample_item["chosen"], str): print(f"- Chosen (string): {sample_item['chosen'][:100]}...") - elif 'rejected' in sample_item and isinstance(sample_item['rejected'], str): + elif "rejected" in sample_item and isinstance( + sample_item["rejected"], str + ): print(f"- Rejected (string): {sample_item['rejected'][:100]}...") - - if 'chosen' in sample_item: - if isinstance(sample_item['chosen'], list): + + if "chosen" in sample_item: + if isinstance(sample_item["chosen"], list): print(f"- Chosen responses: {len(sample_item['chosen'])}") - if sample_item['chosen']: - print(f"- First chosen (truncated): {sample_item['chosen'][0][:200]}...") + if sample_item["chosen"]: + print( + f"- First chosen (truncated): {sample_item['chosen'][0][:200]}..." + ) else: print(f"- Chosen (string): {sample_item['chosen'][:200]}...") - - if 'rejected' in sample_item: - if isinstance(sample_item['rejected'], list): + + if "rejected" in sample_item: + if isinstance(sample_item["rejected"], list): print(f"- Rejected responses: {len(sample_item['rejected'])}") - if sample_item['rejected']: - print(f"- First rejected (truncated): {sample_item['rejected'][0][:200]}...") + if sample_item["rejected"]: + print( + f"- First rejected (truncated): {sample_item['rejected'][0][:200]}..." + ) else: - print(f"- Rejected (string): {sample_item['rejected'][:200]}...") - + print( + f"- Rejected (string): {sample_item['rejected'][:200]}..." 
+ ) + except Exception as e: print(f"Warning: Could not display sample item structure: {e}") - + self.iter = 0 def save_checkpoint(self, step: int, data: Optional[Dict] = None) -> None: @@ -280,13 +303,13 @@ class PairwiseJudgementEnv(BaseEnv): # Check for exactly one pair of think tags using pre-compiled patterns think_open_count = len(self._think_pattern.findall(judgment)) think_close_count = len(self._think_close_pattern.findall(judgment)) - + if think_open_count != 1 or think_close_count != 1: if track_metrics: self.error_count += 1 self.total_judgments += 1 return "format_error" - + # Parse only content after tags match = self._think_content_pattern.search(judgment) if match: @@ -296,17 +319,17 @@ class PairwiseJudgementEnv(BaseEnv): self.error_count += 1 self.total_judgments += 1 return "format_error" - + if track_metrics: self.total_judgments += 1 - + # Check for each possible choice letter using pre-compiled patterns for letter in self.choice_letters: if self._choice_patterns[letter].search(judgment): if track_metrics: self.judgment_letter_counts[letter] += 1 return letter - + # No valid judgment found if track_metrics: self.error_count += 1 @@ -315,23 +338,25 @@ class PairwiseJudgementEnv(BaseEnv): def create_judgment_prompt(self, question: str, answers: List[str]) -> str: """Create the user prompt for judgment task.""" if len(answers) != self.config.num_choices: - raise ValueError(f"Need exactly {self.config.num_choices} answers for judgment, got {len(answers)}") - + raise ValueError( + f"Need exactly {self.config.num_choices} answers for judgment, got {len(answers)}" + ) + prompt = f"[User Question]\n{question}\n\n" - + for i, answer in enumerate(answers): letter = self.choice_letters[i] prompt += f"[The Start of Assistant {letter}'s Answer]\n{answer}\n[The End of Assistant {letter}'s Answer]\n\n" - + return prompt.strip() async def get_next_item(self) -> Item: """Generate next training item with synthetic data.""" self.iter += 1 - + # Create system message system_content = self._create_system_content() - + # Create varied placeholder judgment tasks examples = [ { @@ -343,8 +368,8 @@ class PairwiseJudgementEnv(BaseEnv): "I don't know the answer to this question.", "France doesn't have a capital city.", "The capital changes every year in France.", - "Paris is just a city, not a capital." - ] + "Paris is just a city, not a capital.", + ], }, { "question": "How do you fix a memory leak in Python?", @@ -355,8 +380,8 @@ class PairwiseJudgementEnv(BaseEnv): "You need to reinstall Python to fix memory leaks.", "Memory leaks don't exist in Python because it's interpreted.", "Use more RAM to solve memory leaks.", - "Delete the Python installation and use a different language." - ] + "Delete the Python installation and use a different language.", + ], }, { "question": "Explain the difference between machine learning and artificial intelligence.", @@ -367,89 +392,102 @@ class PairwiseJudgementEnv(BaseEnv): "AI is only about robots, while machine learning is only about statistics.", "Machine learning came before AI historically.", "AI is a subset of machine learning, not the other way around.", - "There is no difference; they are marketing terms for the same technology." 
- ] - } + "There is no difference; they are marketing terms for the same technology.", + ], + }, ] - + # Select random example example = random.choice(examples) - + # Create list with correct and incorrect answers, ensuring we have enough - incorrect_answers = example["incorrect"][:self.config.num_choices - 1] # Take enough incorrect answers + incorrect_answers = example["incorrect"][ + : self.config.num_choices - 1 + ] # Take enough incorrect answers all_answers = [example["correct"]] + incorrect_answers - + # If we don't have enough incorrect answers, pad with generic ones while len(all_answers) < self.config.num_choices: - all_answers.append("I don't have enough information to answer this question.") - + all_answers.append( + "I don't have enough information to answer this question." + ) + random.shuffle(all_answers) - + # Find where correct answer ended up correct_index = all_answers.index(example["correct"]) correct_answer = self.choice_letters[correct_index] - + user_content = self.create_judgment_prompt(example["question"], all_answers) - - prompt = tuple([ - frozenset({"role": "system", "content": system_content}.items()), - frozenset({"role": "user", "content": user_content}.items()) - ]) - + + prompt = tuple( + [ + frozenset({"role": "system", "content": system_content}.items()), + frozenset({"role": "user", "content": user_content}.items()), + ] + ) + return (prompt, correct_answer) def prepare_eval_item(self, item: dict) -> Tuple[Optional[Tuple], Optional[str]]: """ Prepare an evaluation item from the reward-bench-2 dataset. - + Dataset structure: - chosen: list with 1 element (the best response) - - rejected: list with 3+ elements (worse responses) + - rejected: list with 3+ elements (worse responses) - We take chosen[0] + rejected[:num_choices-1] to create judgment with configured number of choices """ try: question = item.get("prompt", "") chosen_responses = item.get("chosen", []) rejected_responses = item.get("rejected", []) - + # Validate required fields if not question: return None, None - + # Take one chosen response and (num_choices-1) rejected responses required_rejected = self.config.num_choices - 1 - if len(chosen_responses) == 0 or len(rejected_responses) < required_rejected: + if ( + len(chosen_responses) == 0 + or len(rejected_responses) < required_rejected + ): return None, None - + chosen = chosen_responses[0] rejected = rejected_responses[:required_rejected] - + # Validate response content if not chosen or not all(rejected): return None, None - + # Create list with answer and whether it's correct data = [(chosen, True)] + [(r, False) for r in rejected] random.shuffle(data) - + # Extract shuffled answers and find correct position shuffled_answers = [item[0] for item in data] - correct_index = next(i for i, (_, is_correct) in enumerate(data) if is_correct) + correct_index = next( + i for i, (_, is_correct) in enumerate(data) if is_correct + ) correct_answer = self.choice_letters[correct_index] - + # Create system message system_content = self._create_system_content() - + # Create user prompt user_content = self.create_judgment_prompt(question, shuffled_answers) - - prompt = tuple([ - frozenset({"role": "system", "content": system_content}.items()), - frozenset({"role": "user", "content": user_content}.items()) - ]) - + + prompt = tuple( + [ + frozenset({"role": "system", "content": system_content}.items()), + frozenset({"role": "user", "content": user_content}.items()), + ] + ) + return prompt, correct_answer - + except Exception as e: print(f"Error preparing 
evaluation item: {e}") return None, None @@ -459,8 +497,10 @@ class PairwiseJudgementEnv(BaseEnv): messages, prompt_text = self._prepare_completion_input(item[0]) completion_params = self._get_train_completion_params() - completions = await self.server.completion(prompt=prompt_text, **completion_params) - + completions = await self.server.completion( + prompt=prompt_text, **completion_params + ) + # Build trajectories without duplicating message construction to_score = [] for completion_choice in completions.choices: @@ -469,71 +509,76 @@ class PairwiseJudgementEnv(BaseEnv): {"role": "assistant", "content": completion_choice.text} ] to_score.append((tuple(trajectory_messages), item[1])) - + scored_data = await self.score(to_score) - + # Add rollouts for wandb visualization if scored_data is not None: await self.add_rollouts_for_wandb(scored_data, item) - + return scored_data, [] async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]: """Score a group of rollout data.""" if not rollout_group_data: return None - + try: scores = ScoredDataGroup() scores["tokens"] = [] scores["masks"] = [] scores["scores"] = [] - + random.shuffle(rollout_group_data) - + for item in rollout_group_data: # Simplified validation if not item or len(item) < 2 or not item[0]: continue - + model_response = item[0][-1]["content"] ground_truth = item[1] - - predicted_answer = self.process_judgement(model_response, track_metrics=True) + + predicted_answer = self.process_judgement( + model_response, track_metrics=True + ) reward = 1.0 if predicted_answer == ground_truth else 0.0 - + # Track correct judgments per letter - if predicted_answer == ground_truth and predicted_answer != "format_error": + if ( + predicted_answer == ground_truth + and predicted_answer != "format_error" + ): self.judgment_letter_correct[predicted_answer] += 1 - + out_dict = tokenize_for_trainer(self.tokenizer, item[0]) tokens = out_dict["tokens"] masks = out_dict["masks"] - + # Skip obviously bad examples if len([1 for mask in masks if mask != -100]) < 10: continue - + scores["tokens"].append(tokens) scores["masks"].append(masks) scores["scores"].append(reward) # Use reward directly (1.0 or 0.0) - + if len(scores["tokens"]) >= self.config.group_size: break - + if not scores["tokens"]: return None - + # Update percent correct buffer for score in scores["scores"]: self.percent_correct_buffer.append(max(score, 0)) - + # Return None if all scores are the same (no learning signal) if len(set(scores["scores"])) == 1: return None - + return scores - + except Exception as e: print(f"Error in score method: {e}") return None @@ -544,35 +589,45 @@ class PairwiseJudgementEnv(BaseEnv): prompt, ground_truth = self.prepare_eval_item(test_item) if prompt is None: return {"score": 0.0, "sample": None} - + messages, prompt_text = self._prepare_completion_input(prompt) completion_params = self._get_eval_completion_params() - - completion = await self.server.completion(prompt=prompt_text, **completion_params) - + + completion = await self.server.completion( + prompt=prompt_text, **completion_params + ) + if not completion.choices: return {"score": 0.0, "sample": None} - + model_response = completion.choices[0].text - predicted_answer = self.process_judgement(model_response, track_metrics=False) - + predicted_answer = self.process_judgement( + model_response, track_metrics=False + ) + score = 1.0 if predicted_answer == ground_truth else 0.0 - + # Extract question and answer choices from the user message user_content = messages[1]["content"] 
question_match = self._question_pattern.search(user_content) - question = question_match.group(1).strip() if question_match else "Unknown question" - + question = ( + question_match.group(1).strip() + if question_match + else "Unknown question" + ) + # Extract individual answer choices for all configured letters answer_choices = {} for letter in self.choice_letters: match = self._answer_choice_patterns[letter].search(user_content) if match: answer_choices[letter] = match.group(1).strip() - + # Add full conversation including model response - full_messages = messages + [{"role": "assistant", "content": model_response}] - + full_messages = messages + [ + {"role": "assistant", "content": model_response} + ] + sample = { "messages": full_messages, "question": question, @@ -588,88 +643,101 @@ class PairwiseJudgementEnv(BaseEnv): "dataset_subset": test_item.get("subset", "unknown"), "num_choices": self.config.num_choices, } - + # Add thinking-specific parsing info if self.config.thinking_mode: if "</think>" in model_response: - sample["response_after_think"] = model_response.split("</think>")[-1].strip() - sample["thinking_content"] = self._thinking_extract_pattern.search(model_response) + sample["response_after_think"] = model_response.split("</think>")[ + -1 + ].strip() + sample["thinking_content"] = self._thinking_extract_pattern.search( + model_response + ) if sample["thinking_content"]: - sample["thinking_content"] = sample["thinking_content"].group(1).strip() + sample["thinking_content"] = ( + sample["thinking_content"].group(1).strip() + ) else: sample["response_after_think"] = model_response sample["thinking_content"] = None - + return {"score": score, "sample": sample} - + except Exception as e: print(f"Error in rollout_and_score_eval: {e}") return {"score": 0.0, "sample": None} - def _calculate_response_metrics(self, samples: List[dict], thinking_mode_used: bool) -> Tuple[List[int], int, Dict[str, int]]: + def _calculate_response_metrics( + self, samples: List[dict], thinking_mode_used: bool + ) -> Tuple[List[int], int, Dict[str, int]]: """Calculate response-related metrics from samples.""" response_lengths = [] thinking_utilization = 0 judgment_counts = {letter: 0 for letter in self.choice_letters} judgment_counts["format_error"] = 0 - + for sample in samples: if not sample: continue - + # Track response length messages = sample.get("messages", []) if messages: assistant_msg = messages[-1].get("content", "") response_lengths.append(len(assistant_msg)) - + # Track thinking utilization in thinking mode if thinking_mode_used: thinking_content = sample.get("thinking_content") if thinking_content: thinking_utilization += 1 - + # Track judgment distribution predicted_judgment = sample.get("predicted_judgment", "format_error") if predicted_judgment in judgment_counts: judgment_counts[predicted_judgment] += 1 - + return response_lengths, thinking_utilization, judgment_counts async def evaluate(self, *args, **kwargs) -> None: """Evaluate the model on the test dataset.""" start_time = time.time() - + try: - eval_tasks = [self.rollout_and_score_eval(test_item) for test_item in self.test] + eval_tasks = [ + self.rollout_and_score_eval(test_item) for test_item in self.test + ] results = await tqdm_asyncio.gather(*eval_tasks) - + # Filter valid results valid_results = [ - result for result in results - if not isinstance(result, Exception) and result and result.get("sample") is not None + result + for result in results + if not isinstance(result, Exception) + and result + and result.get("sample") is not None ] - + if not 
valid_results: print("Warning: No valid evaluation results obtained") return - + except Exception as e: print(f"Error during evaluation: {e}") return - + # Extract scores and samples from valid results scores = [result["score"] for result in valid_results] samples = [result["sample"] for result in valid_results] valid_scores = [s for s in scores if s is not None] - + if not valid_scores: print("Warning: No valid scores found during evaluation") return - + percent_correct = sum(valid_scores) / len(valid_scores) self.eval_metrics.append(("eval/percent_correct", percent_correct)) - + # Track performance by subset if available subset_scores = {} for i, sample in enumerate(samples): @@ -678,33 +746,42 @@ class PairwiseJudgementEnv(BaseEnv): if subset not in subset_scores: subset_scores[subset] = [] subset_scores[subset].append(scores[i]) - + # Log subset-specific metrics for subset, subset_score_list in subset_scores.items(): valid_subset_scores = [s for s in subset_score_list if s is not None] if valid_subset_scores: avg_score = sum(valid_subset_scores) / len(valid_subset_scores) self.eval_metrics.append((f"eval/percent_correct_{subset}", avg_score)) - + # Calculate additional metrics - format_compliant = sum(1 for sample in samples if sample.get("format_compliant", False)) + format_compliant = sum( + 1 for sample in samples if sample.get("format_compliant", False) + ) thinking_mode_used = self.config.thinking_mode - + # Get response metrics - response_lengths, thinking_utilization, judgment_counts = self._calculate_response_metrics(samples, thinking_mode_used) - + response_lengths, thinking_utilization, judgment_counts = ( + self._calculate_response_metrics(samples, thinking_mode_used) + ) + # Response length metrics if response_lengths: avg_response_length = sum(response_lengths) / len(response_lengths) - response_length_std = (sum((x - avg_response_length) ** 2 for x in response_lengths) / len(response_lengths)) ** 0.5 + response_length_std = ( + sum((x - avg_response_length) ** 2 for x in response_lengths) + / len(response_lengths) + ) ** 0.5 self.eval_metrics.append(("eval/avg_response_length", avg_response_length)) self.eval_metrics.append(("eval/response_length_std", response_length_std)) - + # Thinking utilization rate if thinking_mode_used and samples: thinking_utilization_rate = thinking_utilization / len(samples) - self.eval_metrics.append(("eval/thinking_utilization_rate", thinking_utilization_rate)) - + self.eval_metrics.append( + ("eval/thinking_utilization_rate", thinking_utilization_rate) + ) + # Judgment distribution metrics total_judgments = sum(judgment_counts.values()) if total_judgments > 0: @@ -715,54 +792,63 @@ class PairwiseJudgementEnv(BaseEnv): freq = count / total_judgments entropy -= freq * math.log(freq) self.eval_metrics.append(("eval/judgment_entropy", entropy)) - + # Most common judgment frequency (bias detection) max_judgment_count = max(judgment_counts.values()) most_common_judgment_freq = max_judgment_count / total_judgments - self.eval_metrics.append(("eval/most_common_judgment_freq", most_common_judgment_freq)) - + self.eval_metrics.append( + ("eval/most_common_judgment_freq", most_common_judgment_freq) + ) + # Format error rate format_error_rate = judgment_counts["format_error"] / total_judgments self.eval_metrics.append(("eval/format_error_rate", format_error_rate)) - + # Add overall dataset statistics self.eval_metrics.append(("eval/total_items", len(self.test))) self.eval_metrics.append(("eval/valid_scores", len(valid_scores))) 
self.eval_metrics.append(("eval/subset_count", len(subset_scores))) - self.eval_metrics.append(("eval/format_compliance_rate", format_compliant / len(samples) if samples else 0.0)) - + self.eval_metrics.append( + ( + "eval/format_compliance_rate", + format_compliant / len(samples) if samples else 0.0, + ) + ) + end_time = time.time() - + # Build evaluation metrics dict eval_metrics = { "eval/percent_correct": percent_correct, "eval/total_samples": len(samples), "eval/correct_samples": sum(valid_scores), - "eval/format_compliance_rate": format_compliant / len(samples) if samples else 0.0, + "eval/format_compliance_rate": ( + format_compliant / len(samples) if samples else 0.0 + ), } - + # Add response length metrics if response_lengths: eval_metrics["eval/avg_response_length"] = avg_response_length eval_metrics["eval/response_length_std"] = response_length_std - + # Add thinking utilization if thinking_mode_used and samples: eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate - + # Add judgment distribution metrics if total_judgments > 0: eval_metrics["eval/judgment_entropy"] = entropy eval_metrics["eval/most_common_judgment_freq"] = most_common_judgment_freq eval_metrics["eval/format_error_rate"] = format_error_rate - + # Add subset metrics for subset, subset_score_list in subset_scores.items(): valid_subset_scores = [s for s in subset_score_list if s is not None] if valid_subset_scores: avg_score = sum(valid_subset_scores) / len(valid_subset_scores) eval_metrics[f"eval/percent_correct_{subset}"] = avg_score - + try: await self.evaluate_log( metrics=eval_metrics, @@ -786,10 +872,10 @@ class PairwiseJudgementEnv(BaseEnv): """Add rollouts to wandb for visualization.""" if item is None or scored_data is None or not scored_data.get("tokens"): return - + # Extract ground truth and question info ground_truth = item[1] - + # Extract question from the item prompt question_info = "unknown_question" try: @@ -806,24 +892,24 @@ class PairwiseJudgementEnv(BaseEnv): except Exception: # Fallback to placeholder if extraction fails question_info = "extraction_failed" - + # Keep a reasonable number of rollouts num_keep = self.config.num_rollouts_per_group_for_logging if num_keep == -1: num_keep = self.config.group_size - + num_keep = min(num_keep, len(scored_data["tokens"])) - + current_rollouts = [] mode = "thinking" if self.config.thinking_mode else "direct" - + for i in range(num_keep): # Decode the full trajectory full_text = self.tokenizer.decode( scored_data["tokens"][i], skip_special_tokens=True ) score_val = scored_data["scores"][i] - + # Extract the model's judgment predicted_judgment = "unknown" try: @@ -834,22 +920,26 @@ class PairwiseJudgementEnv(BaseEnv): else: # Fallback to decoding tokens model_response = full_text - - predicted_judgment = self.process_judgement(model_response, track_metrics=False) + + predicted_judgment = self.process_judgement( + model_response, track_metrics=False + ) except Exception: predicted_judgment = "parse_error" - - current_rollouts.append(( - full_text, - score_val, - ground_truth, - predicted_judgment, - question_info, - mode - )) - + + current_rollouts.append( + ( + full_text, + score_val, + ground_truth, + predicted_judgment, + question_info, + mode, + ) + ) + self.rollouts_for_wandb.append(current_rollouts) - + # Keep only recent rollouts if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: self.rollouts_for_wandb.pop(0) @@ -858,23 +948,23 @@ class PairwiseJudgementEnv(BaseEnv): """Create wandb table for rollout 
visualization.""" if not self.rollouts_for_wandb: return wandb_metrics - + table = wandb.Table( columns=[ - "full_text", - "score", - "ground_truth", - "predicted_judgment", - "question_info", - "mode" + "full_text", + "score", + "ground_truth", + "predicted_judgment", + "question_info", + "mode", ] ) - + for group_rollouts in self.rollouts_for_wandb: for rollout_tuple in group_rollouts: if len(rollout_tuple) == 6: table.add_data(*rollout_tuple) - + wandb_metrics["train/rollouts"] = table self.rollouts_for_wandb = [] return wandb_metrics @@ -883,11 +973,13 @@ class PairwiseJudgementEnv(BaseEnv): """Log metrics to wandb.""" if wandb_metrics is None: wandb_metrics = {} - + # Basic accuracy metrics if self.percent_correct_buffer: - wandb_metrics["train/percent_correct"] = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer) - + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + # Judgment letter distribution and accuracy total_letters = sum(self.judgment_letter_counts.values()) if total_letters > 0: @@ -896,44 +988,54 @@ class PairwiseJudgementEnv(BaseEnv): for letter in self.choice_letters: letter_count = self.judgment_letter_counts[letter] letter_correct = self.judgment_letter_correct[letter] - + # Letter frequency and accuracy freq = letter_count / total_letters wandb_metrics[f"train/judgment_freq_{letter}"] = freq - wandb_metrics[f"train/judgment_acc_{letter}"] = letter_correct / letter_count if letter_count > 0 else 0.0 - + wandb_metrics[f"train/judgment_acc_{letter}"] = ( + letter_correct / letter_count if letter_count > 0 else 0.0 + ) + # Accumulate entropy if freq > 0: entropy -= freq * math.log(freq) - + wandb_metrics["train/judgment_entropy"] = entropy - wandb_metrics["train/judgment_balance"] = entropy / math.log(self.config.num_choices) # Normalized entropy - + wandb_metrics["train/judgment_balance"] = entropy / math.log( + self.config.num_choices + ) # Normalized entropy + # Error rate and other metrics if self.total_judgments > 0: wandb_metrics["train/error_rate"] = self.error_count / self.total_judgments - wandb_metrics["train/format_compliance_rate"] = 1.0 - (self.error_count / self.total_judgments) - + wandb_metrics["train/format_compliance_rate"] = 1.0 - ( + self.error_count / self.total_judgments + ) + # Configuration and mode metrics - wandb_metrics.update({ - "train/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0, - "train/total_judgments": self.total_judgments, - "config/group_size": self.config.group_size, - "config/max_token_length": self.config.max_token_length, - "config/num_choices": self.config.num_choices, - }) - + wandb_metrics.update( + { + "train/thinking_mode_enabled": ( + 1.0 if self.config.thinking_mode else 0.0 + ), + "train/total_judgments": self.total_judgments, + "config/group_size": self.config.group_size, + "config/max_token_length": self.config.max_token_length, + "config/num_choices": self.config.num_choices, + } + ) + # Reset training metrics self._reset_metrics() - + # Add evaluation metrics for metric_name, metric_value in self.eval_metrics: wandb_metrics[metric_name] = metric_value self.eval_metrics = [] - + # Add rollout table wandb_metrics = await self.create_rollout_table(wandb_metrics) - + await super().wandb_log(wandb_metrics)