diff --git a/environments/pairwise_judgement_environment.py b/environments/pairwise_judgement_environment.py
index 97600964..82c87311 100644
--- a/environments/pairwise_judgement_environment.py
+++ b/environments/pairwise_judgement_environment.py
@@ -23,44 +23,44 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class PairwiseJudgementConfig(BaseEnvConfig):
"""Configuration for PairwiseJudgementEnv with thinking mode and configurable options."""
-
+
thinking_mode: bool = Field(
default=False,
description="Whether to enable thinking mode with tags.",
)
-
+
num_choices: int = Field(
default=4,
ge=2,
le=26,
description="Number of choices for pairwise judgment (2-26, corresponding to A-Z).",
)
-
+
custom_thinking_prompt: Optional[str] = Field(
default=None,
description="Custom thinking prompt. If None, uses the default thinking prompt.",
)
-
+
custom_judgment_prompt: Optional[str] = Field(
default=None,
description="Custom judgment prompt. If None, uses the default judgment prompt.",
)
-
+
eval_temperature: float = Field(
default=0.6,
description="Temperature for evaluation completions.",
)
-
+
rollout_temperature: float = Field(
default=0.8,
description="Temperature for training rollout completions.",
)
-
+
eval_max_tokens: int = Field(
default=1024 * 16,
description="Maximum tokens for evaluation completions.",
)
-
+
train_max_tokens: int = Field(
default=1024 * 16,
description="Maximum tokens for training completions.",
@@ -70,7 +70,7 @@ class PairwiseJudgementConfig(BaseEnvConfig):
class PairwiseJudgementEnv(BaseEnv):
name = "pairwise_judgement"
env_config_cls = PairwiseJudgementConfig
-
+
def __init__(
self,
config: PairwiseJudgementConfig,
@@ -82,31 +82,38 @@ class PairwiseJudgementEnv(BaseEnv):
self.config: PairwiseJudgementConfig = config
self.percent_correct_buffer = []
self.eval_metrics = []
-
+
# Generate choice letters based on num_choices (A, B, C, D... up to Z)
self.choice_letters = [chr(65 + i) for i in range(self.config.num_choices)]
-
+
# Initialize detailed metrics tracking for all choice letters
self.judgment_letter_counts = {letter: 0 for letter in self.choice_letters}
self.judgment_letter_correct = {letter: 0 for letter in self.choice_letters}
self.error_count = 0 # Failed to follow format
self.total_judgments = 0
self.rollouts_for_wandb = []
-
+
# Pre-compile regex patterns for performance
-        self._think_pattern = re.compile(r'<think>')
-        self._think_close_pattern = re.compile(r'</think>')
-        self._think_content_pattern = re.compile(r'</think>\s*(.*)', re.DOTALL)
- self._question_pattern = re.compile(r'\[User Question\]\s*(.*?)\s*\[The Start of Assistant A', re.DOTALL)
-        self._thinking_extract_pattern = re.compile(r'<think>(.*?)</think>', re.DOTALL)
-
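+        # Regexes for locating <think>...</think> reasoning blocks and for
+        # pulling the question text back out of a judgment prompt.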
+        self._think_pattern = re.compile(r"<think>")
+        self._think_close_pattern = re.compile(r"</think>")
+        self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
+ self._question_pattern = re.compile(
+ r"\[User Question\]\s*(.*?)\s*\[The Start of Assistant A", re.DOTALL
+ )
+        self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
+
# Pre-compile choice patterns for each letter
- self._choice_patterns = {letter: re.compile(rf'\[\[{letter}\]\]') for letter in self.choice_letters}
+ self._choice_patterns = {
+ letter: re.compile(rf"\[\[{letter}\]\]") for letter in self.choice_letters
+ }
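+        # Patterns to extract each assistant's answer text, i.e. everything
+        # between "[The Start of Assistant X's Answer]" and
+        # "[The End of Assistant X's Answer]".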
self._answer_choice_patterns = {
- letter: re.compile(rf'\[The Start of Assistant {letter}\'s Answer\]\s*(.*?)\s*\[The End of Assistant {letter}\'s Answer\]', re.DOTALL)
+ letter: re.compile(
+ rf"\[The Start of Assistant {letter}\'s Answer\]\s*(.*?)\s*\[The End of Assistant {letter}\'s Answer\]",
+ re.DOTALL,
+ )
for letter in self.choice_letters
}
-
+
# System prompts (use custom ones if provided, otherwise defaults)
self.thinking_system_prompt = self._get_thinking_prompt()
self.judgment_system_prompt = self._get_judgment_prompt()
@@ -114,20 +121,26 @@ class PairwiseJudgementEnv(BaseEnv):
def _get_thinking_prompt(self) -> str:
"""Get thinking system prompt."""
return (
- self.config.custom_thinking_prompt if self.config.custom_thinking_prompt else
- "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
+ self.config.custom_thinking_prompt
+ if self.config.custom_thinking_prompt
+ else "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
"solution prior to answering. You should enclose your thoughts and internal monologue inside "
" tags, and then provide your solution or response to the problem."
)
-
+
def _get_judgment_prompt(self) -> str:
"""Get judgment system prompt."""
if self.config.custom_judgment_prompt:
return self.config.custom_judgment_prompt
-
- choice_format_examples = ", ".join([f'"[[{letter}]]" if assistant {letter} is best' for letter in self.choice_letters])
-
+
+ choice_format_examples = ", ".join(
+ [
+ f'"[[{letter}]]" if assistant {letter} is best'
+ for letter in self.choice_letters
+ ]
+ )
+
return (
f"Please act as an impartial judge and evaluate the quality of the responses provided by {self.config.num_choices} AI assistants to the user question displayed below. "
f"You should choose the assistant that follows the user's instructions and answers the user's question best. "
@@ -136,7 +149,7 @@ class PairwiseJudgementEnv(BaseEnv):
f"Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. "
f"Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. "
f"Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: "
- f'{choice_format_examples}.'
+ f"{choice_format_examples}."
)
 
     def _reset_metrics(self) -> None:
@@ -226,45 +239,55 @@ class PairwiseJudgementEnv(BaseEnv):
# Note: This isn't actually used since get_next_item() generates synthetic examples
self.train = [{"question": "What is 2+2?", "answer": "4"}] * 100
print(f"Using synthetic placeholder training data due to error: {e}")
-
+
# Load evaluation dataset - reward-bench-2 (MUST WORK OR CRASH)
- self.test = load_dataset("allenai/reward-bench-2", split="test", trust_remote_code=True)
+ self.test = load_dataset(
+ "allenai/reward-bench-2", split="test", trust_remote_code=True
+ )
print(f"Loaded reward-bench-2 eval dataset with {len(self.test)} examples")
-
+
# Debug: Show sample evaluation item structure
if len(self.test) > 0:
try:
sample_item = self.test[0]
print(f"\nSample eval item structure:")
print(f"- Available keys: {list(sample_item.keys())}")
-
+
# Handle different dataset structures
- if 'prompt' in sample_item:
+ if "prompt" in sample_item:
print(f"- Prompt: {sample_item['prompt'][:100]}...")
- elif 'chosen' in sample_item and isinstance(sample_item['chosen'], str):
+ elif "chosen" in sample_item and isinstance(sample_item["chosen"], str):
print(f"- Chosen (string): {sample_item['chosen'][:100]}...")
- elif 'rejected' in sample_item and isinstance(sample_item['rejected'], str):
+ elif "rejected" in sample_item and isinstance(
+ sample_item["rejected"], str
+ ):
print(f"- Rejected (string): {sample_item['rejected'][:100]}...")
-
- if 'chosen' in sample_item:
- if isinstance(sample_item['chosen'], list):
+
+ if "chosen" in sample_item:
+ if isinstance(sample_item["chosen"], list):
print(f"- Chosen responses: {len(sample_item['chosen'])}")
- if sample_item['chosen']:
- print(f"- First chosen (truncated): {sample_item['chosen'][0][:200]}...")
+ if sample_item["chosen"]:
+ print(
+ f"- First chosen (truncated): {sample_item['chosen'][0][:200]}..."
+ )
else:
print(f"- Chosen (string): {sample_item['chosen'][:200]}...")
-
- if 'rejected' in sample_item:
- if isinstance(sample_item['rejected'], list):
+
+ if "rejected" in sample_item:
+ if isinstance(sample_item["rejected"], list):
print(f"- Rejected responses: {len(sample_item['rejected'])}")
- if sample_item['rejected']:
- print(f"- First rejected (truncated): {sample_item['rejected'][0][:200]}...")
+ if sample_item["rejected"]:
+ print(
+ f"- First rejected (truncated): {sample_item['rejected'][0][:200]}..."
+ )
else:
- print(f"- Rejected (string): {sample_item['rejected'][:200]}...")
-
+ print(
+ f"- Rejected (string): {sample_item['rejected'][:200]}..."
+ )
+
except Exception as e:
print(f"Warning: Could not display sample item structure: {e}")
-
+
self.iter = 0
 
     def save_checkpoint(self, step: int, data: Optional[Dict] = None) -> None:
@@ -280,13 +303,13 @@ class PairwiseJudgementEnv(BaseEnv):
# Check for exactly one pair of think tags using pre-compiled patterns
think_open_count = len(self._think_pattern.findall(judgment))
think_close_count = len(self._think_close_pattern.findall(judgment))
-
+
if think_open_count != 1 or think_close_count != 1:
if track_metrics:
self.error_count += 1
self.total_judgments += 1
return "format_error"
-
+
        # Parse only content after </think> tags
match = self._think_content_pattern.search(judgment)
if match:
@@ -296,17 +319,17 @@ class PairwiseJudgementEnv(BaseEnv):
self.error_count += 1
self.total_judgments += 1
return "format_error"
-
+
if track_metrics:
self.total_judgments += 1
-
+
# Check for each possible choice letter using pre-compiled patterns
for letter in self.choice_letters:
if self._choice_patterns[letter].search(judgment):
if track_metrics:
self.judgment_letter_counts[letter] += 1
return letter
-
+
# No valid judgment found
if track_metrics:
self.error_count += 1
@@ -315,23 +338,25 @@ class PairwiseJudgementEnv(BaseEnv):
def create_judgment_prompt(self, question: str, answers: List[str]) -> str:
"""Create the user prompt for judgment task."""
if len(answers) != self.config.num_choices:
- raise ValueError(f"Need exactly {self.config.num_choices} answers for judgment, got {len(answers)}")
-
+ raise ValueError(
+ f"Need exactly {self.config.num_choices} answers for judgment, got {len(answers)}"
+ )
+
prompt = f"[User Question]\n{question}\n\n"
-
+
for i, answer in enumerate(answers):
letter = self.choice_letters[i]
prompt += f"[The Start of Assistant {letter}'s Answer]\n{answer}\n[The End of Assistant {letter}'s Answer]\n\n"
-
+
return prompt.strip()
 
     async def get_next_item(self) -> Item:
"""Generate next training item with synthetic data."""
self.iter += 1
-
+
# Create system message
system_content = self._create_system_content()
-
+
# Create varied placeholder judgment tasks
examples = [
{
@@ -343,8 +368,8 @@ class PairwiseJudgementEnv(BaseEnv):
"I don't know the answer to this question.",
"France doesn't have a capital city.",
"The capital changes every year in France.",
- "Paris is just a city, not a capital."
- ]
+ "Paris is just a city, not a capital.",
+ ],
},
{
"question": "How do you fix a memory leak in Python?",
@@ -355,8 +380,8 @@ class PairwiseJudgementEnv(BaseEnv):
"You need to reinstall Python to fix memory leaks.",
"Memory leaks don't exist in Python because it's interpreted.",
"Use more RAM to solve memory leaks.",
- "Delete the Python installation and use a different language."
- ]
+ "Delete the Python installation and use a different language.",
+ ],
},
{
"question": "Explain the difference between machine learning and artificial intelligence.",
@@ -367,89 +392,102 @@ class PairwiseJudgementEnv(BaseEnv):
"AI is only about robots, while machine learning is only about statistics.",
"Machine learning came before AI historically.",
"AI is a subset of machine learning, not the other way around.",
- "There is no difference; they are marketing terms for the same technology."
- ]
- }
+ "There is no difference; they are marketing terms for the same technology.",
+ ],
+ },
]
-
+
# Select random example
example = random.choice(examples)
-
+
# Create list with correct and incorrect answers, ensuring we have enough
- incorrect_answers = example["incorrect"][:self.config.num_choices - 1] # Take enough incorrect answers
+ incorrect_answers = example["incorrect"][
+ : self.config.num_choices - 1
+ ] # Take enough incorrect answers
all_answers = [example["correct"]] + incorrect_answers
-
+
# If we don't have enough incorrect answers, pad with generic ones
while len(all_answers) < self.config.num_choices:
- all_answers.append("I don't have enough information to answer this question.")
-
+ all_answers.append(
+ "I don't have enough information to answer this question."
+ )
+
random.shuffle(all_answers)
-
+
# Find where correct answer ended up
correct_index = all_answers.index(example["correct"])
correct_answer = self.choice_letters[correct_index]
-
+
user_content = self.create_judgment_prompt(example["question"], all_answers)
-
- prompt = tuple([
- frozenset({"role": "system", "content": system_content}.items()),
- frozenset({"role": "user", "content": user_content}.items())
- ])
-
+
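+        # Messages are stored as a tuple of frozensets so the prompt is
+        # hashable (plain message dicts are not).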
+ prompt = tuple(
+ [
+ frozenset({"role": "system", "content": system_content}.items()),
+ frozenset({"role": "user", "content": user_content}.items()),
+ ]
+ )
+
return (prompt, correct_answer)
 
     def prepare_eval_item(self, item: dict) -> Tuple[Optional[Tuple], Optional[str]]:
"""
Prepare an evaluation item from the reward-bench-2 dataset.
-
+
Dataset structure:
- chosen: list with 1 element (the best response)
- - rejected: list with 3+ elements (worse responses)
+ - rejected: list with 3+ elements (worse responses)
- We take chosen[0] + rejected[:num_choices-1] to create judgment with configured number of choices
"""
try:
question = item.get("prompt", "")
chosen_responses = item.get("chosen", [])
rejected_responses = item.get("rejected", [])
-
+
# Validate required fields
if not question:
return None, None
-
+
# Take one chosen response and (num_choices-1) rejected responses
required_rejected = self.config.num_choices - 1
- if len(chosen_responses) == 0 or len(rejected_responses) < required_rejected:
+ if (
+ len(chosen_responses) == 0
+ or len(rejected_responses) < required_rejected
+ ):
return None, None
-
+
chosen = chosen_responses[0]
rejected = rejected_responses[:required_rejected]
-
+
# Validate response content
if not chosen or not all(rejected):
return None, None
-
+
# Create list with answer and whether it's correct
data = [(chosen, True)] + [(r, False) for r in rejected]
random.shuffle(data)
-
+
# Extract shuffled answers and find correct position
shuffled_answers = [item[0] for item in data]
- correct_index = next(i for i, (_, is_correct) in enumerate(data) if is_correct)
+ correct_index = next(
+ i for i, (_, is_correct) in enumerate(data) if is_correct
+ )
correct_answer = self.choice_letters[correct_index]
-
+
# Create system message
system_content = self._create_system_content()
-
+
# Create user prompt
user_content = self.create_judgment_prompt(question, shuffled_answers)
-
- prompt = tuple([
- frozenset({"role": "system", "content": system_content}.items()),
- frozenset({"role": "user", "content": user_content}.items())
- ])
-
+
+ prompt = tuple(
+ [
+ frozenset({"role": "system", "content": system_content}.items()),
+ frozenset({"role": "user", "content": user_content}.items()),
+ ]
+ )
+
return prompt, correct_answer
-
+
except Exception as e:
print(f"Error preparing evaluation item: {e}")
return None, None
@@ -459,8 +497,10 @@ class PairwiseJudgementEnv(BaseEnv):
messages, prompt_text = self._prepare_completion_input(item[0])
completion_params = self._get_train_completion_params()
- completions = await self.server.completion(prompt=prompt_text, **completion_params)
-
+ completions = await self.server.completion(
+ prompt=prompt_text, **completion_params
+ )
+
# Build trajectories without duplicating message construction
to_score = []
for completion_choice in completions.choices:
@@ -469,71 +509,76 @@ class PairwiseJudgementEnv(BaseEnv):
{"role": "assistant", "content": completion_choice.text}
]
to_score.append((tuple(trajectory_messages), item[1]))
-
+
scored_data = await self.score(to_score)
-
+
# Add rollouts for wandb visualization
if scored_data is not None:
await self.add_rollouts_for_wandb(scored_data, item)
-
+
return scored_data, []
 
     async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]:
"""Score a group of rollout data."""
if not rollout_group_data:
return None
-
+
try:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
-
+
random.shuffle(rollout_group_data)
-
+
for item in rollout_group_data:
# Simplified validation
if not item or len(item) < 2 or not item[0]:
continue
-
+
model_response = item[0][-1]["content"]
ground_truth = item[1]
-
- predicted_answer = self.process_judgement(model_response, track_metrics=True)
+
+ predicted_answer = self.process_judgement(
+ model_response, track_metrics=True
+ )
reward = 1.0 if predicted_answer == ground_truth else 0.0
-
+
# Track correct judgments per letter
- if predicted_answer == ground_truth and predicted_answer != "format_error":
+ if (
+ predicted_answer == ground_truth
+ and predicted_answer != "format_error"
+ ):
self.judgment_letter_correct[predicted_answer] += 1
-
+
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
-
+
            # Skip examples with fewer than 10 trainable (unmasked) tokens
if len([1 for mask in masks if mask != -100]) < 10:
continue
-
+
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(reward) # Use reward directly (1.0 or 0.0)
-
+
if len(scores["tokens"]) >= self.config.group_size:
break
-
+
if not scores["tokens"]:
return None
-
+
# Update percent correct buffer
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
-
+
# Return None if all scores are the same (no learning signal)
if len(set(scores["scores"])) == 1:
return None
-
+
return scores
-
+
except Exception as e:
print(f"Error in score method: {e}")
return None
@@ -544,35 +589,45 @@ class PairwiseJudgementEnv(BaseEnv):
prompt, ground_truth = self.prepare_eval_item(test_item)
if prompt is None:
return {"score": 0.0, "sample": None}
-
+
messages, prompt_text = self._prepare_completion_input(prompt)
completion_params = self._get_eval_completion_params()
-
- completion = await self.server.completion(prompt=prompt_text, **completion_params)
-
+
+ completion = await self.server.completion(
+ prompt=prompt_text, **completion_params
+ )
+
if not completion.choices:
return {"score": 0.0, "sample": None}
-
+
model_response = completion.choices[0].text
- predicted_answer = self.process_judgement(model_response, track_metrics=False)
-
+ predicted_answer = self.process_judgement(
+ model_response, track_metrics=False
+ )
+
score = 1.0 if predicted_answer == ground_truth else 0.0
-
+
# Extract question and answer choices from the user message
user_content = messages[1]["content"]
question_match = self._question_pattern.search(user_content)
- question = question_match.group(1).strip() if question_match else "Unknown question"
-
+ question = (
+ question_match.group(1).strip()
+ if question_match
+ else "Unknown question"
+ )
+
# Extract individual answer choices for all configured letters
answer_choices = {}
for letter in self.choice_letters:
match = self._answer_choice_patterns[letter].search(user_content)
if match:
answer_choices[letter] = match.group(1).strip()
-
+
# Add full conversation including model response
- full_messages = messages + [{"role": "assistant", "content": model_response}]
-
+ full_messages = messages + [
+ {"role": "assistant", "content": model_response}
+ ]
+
sample = {
"messages": full_messages,
"question": question,
@@ -588,88 +643,101 @@ class PairwiseJudgementEnv(BaseEnv):
"dataset_subset": test_item.get("subset", "unknown"),
"num_choices": self.config.num_choices,
}
-
+
# Add thinking-specific parsing info
if self.config.thinking_mode:
if "" in model_response:
- sample["response_after_think"] = model_response.split("")[-1].strip()
- sample["thinking_content"] = self._thinking_extract_pattern.search(model_response)
+                    sample["response_after_think"] = model_response.split("</think>")[
+                        -1
+                    ].strip()
+ sample["thinking_content"] = self._thinking_extract_pattern.search(
+ model_response
+ )
if sample["thinking_content"]:
- sample["thinking_content"] = sample["thinking_content"].group(1).strip()
+ sample["thinking_content"] = (
+ sample["thinking_content"].group(1).strip()
+ )
else:
sample["response_after_think"] = model_response
sample["thinking_content"] = None
-
+
return {"score": score, "sample": sample}
-
+
except Exception as e:
print(f"Error in rollout_and_score_eval: {e}")
return {"score": 0.0, "sample": None}
 
-    def _calculate_response_metrics(self, samples: List[dict], thinking_mode_used: bool) -> Tuple[List[int], int, Dict[str, int]]:
+ def _calculate_response_metrics(
+ self, samples: List[dict], thinking_mode_used: bool
+ ) -> Tuple[List[int], int, Dict[str, int]]:
"""Calculate response-related metrics from samples."""
response_lengths = []
thinking_utilization = 0
judgment_counts = {letter: 0 for letter in self.choice_letters}
judgment_counts["format_error"] = 0
-
+
for sample in samples:
if not sample:
continue
-
+
# Track response length
messages = sample.get("messages", [])
if messages:
assistant_msg = messages[-1].get("content", "")
response_lengths.append(len(assistant_msg))
-
+
# Track thinking utilization in thinking mode
if thinking_mode_used:
thinking_content = sample.get("thinking_content")
if thinking_content:
thinking_utilization += 1
-
+
# Track judgment distribution
predicted_judgment = sample.get("predicted_judgment", "format_error")
if predicted_judgment in judgment_counts:
judgment_counts[predicted_judgment] += 1
-
+
return response_lengths, thinking_utilization, judgment_counts
 
     async def evaluate(self, *args, **kwargs) -> None:
"""Evaluate the model on the test dataset."""
start_time = time.time()
-
+
try:
- eval_tasks = [self.rollout_and_score_eval(test_item) for test_item in self.test]
+ eval_tasks = [
+ self.rollout_and_score_eval(test_item) for test_item in self.test
+ ]
results = await tqdm_asyncio.gather(*eval_tasks)
-
+
# Filter valid results
valid_results = [
- result for result in results
- if not isinstance(result, Exception) and result and result.get("sample") is not None
+ result
+ for result in results
+ if not isinstance(result, Exception)
+ and result
+ and result.get("sample") is not None
]
-
+
if not valid_results:
print("Warning: No valid evaluation results obtained")
return
-
+
except Exception as e:
print(f"Error during evaluation: {e}")
return
-
+
# Extract scores and samples from valid results
scores = [result["score"] for result in valid_results]
samples = [result["sample"] for result in valid_results]
valid_scores = [s for s in scores if s is not None]
-
+
if not valid_scores:
print("Warning: No valid scores found during evaluation")
return
-
+
percent_correct = sum(valid_scores) / len(valid_scores)
self.eval_metrics.append(("eval/percent_correct", percent_correct))
-
+
# Track performance by subset if available
subset_scores = {}
for i, sample in enumerate(samples):
@@ -678,33 +746,42 @@ class PairwiseJudgementEnv(BaseEnv):
if subset not in subset_scores:
subset_scores[subset] = []
subset_scores[subset].append(scores[i])
-
+
# Log subset-specific metrics
for subset, subset_score_list in subset_scores.items():
valid_subset_scores = [s for s in subset_score_list if s is not None]
if valid_subset_scores:
avg_score = sum(valid_subset_scores) / len(valid_subset_scores)
self.eval_metrics.append((f"eval/percent_correct_{subset}", avg_score))
-
+
# Calculate additional metrics
- format_compliant = sum(1 for sample in samples if sample.get("format_compliant", False))
+ format_compliant = sum(
+ 1 for sample in samples if sample.get("format_compliant", False)
+ )
thinking_mode_used = self.config.thinking_mode
-
+
# Get response metrics
- response_lengths, thinking_utilization, judgment_counts = self._calculate_response_metrics(samples, thinking_mode_used)
-
+ response_lengths, thinking_utilization, judgment_counts = (
+ self._calculate_response_metrics(samples, thinking_mode_used)
+ )
+
# Response length metrics
if response_lengths:
avg_response_length = sum(response_lengths) / len(response_lengths)
- response_length_std = (sum((x - avg_response_length) ** 2 for x in response_lengths) / len(response_lengths)) ** 0.5
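+            # Population standard deviation of response lengths (in characters).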
+ response_length_std = (
+ sum((x - avg_response_length) ** 2 for x in response_lengths)
+ / len(response_lengths)
+ ) ** 0.5
self.eval_metrics.append(("eval/avg_response_length", avg_response_length))
self.eval_metrics.append(("eval/response_length_std", response_length_std))
-
+
# Thinking utilization rate
if thinking_mode_used and samples:
thinking_utilization_rate = thinking_utilization / len(samples)
- self.eval_metrics.append(("eval/thinking_utilization_rate", thinking_utilization_rate))
-
+ self.eval_metrics.append(
+ ("eval/thinking_utilization_rate", thinking_utilization_rate)
+ )
+
# Judgment distribution metrics
total_judgments = sum(judgment_counts.values())
if total_judgments > 0:
@@ -715,54 +792,63 @@ class PairwiseJudgementEnv(BaseEnv):
freq = count / total_judgments
entropy -= freq * math.log(freq)
self.eval_metrics.append(("eval/judgment_entropy", entropy))
-
+
# Most common judgment frequency (bias detection)
max_judgment_count = max(judgment_counts.values())
most_common_judgment_freq = max_judgment_count / total_judgments
- self.eval_metrics.append(("eval/most_common_judgment_freq", most_common_judgment_freq))
-
+ self.eval_metrics.append(
+ ("eval/most_common_judgment_freq", most_common_judgment_freq)
+ )
+
# Format error rate
format_error_rate = judgment_counts["format_error"] / total_judgments
self.eval_metrics.append(("eval/format_error_rate", format_error_rate))
-
+
# Add overall dataset statistics
self.eval_metrics.append(("eval/total_items", len(self.test)))
self.eval_metrics.append(("eval/valid_scores", len(valid_scores)))
self.eval_metrics.append(("eval/subset_count", len(subset_scores)))
- self.eval_metrics.append(("eval/format_compliance_rate", format_compliant / len(samples) if samples else 0.0))
-
+ self.eval_metrics.append(
+ (
+ "eval/format_compliance_rate",
+ format_compliant / len(samples) if samples else 0.0,
+ )
+ )
+
end_time = time.time()
-
+
# Build evaluation metrics dict
eval_metrics = {
"eval/percent_correct": percent_correct,
"eval/total_samples": len(samples),
"eval/correct_samples": sum(valid_scores),
- "eval/format_compliance_rate": format_compliant / len(samples) if samples else 0.0,
+ "eval/format_compliance_rate": (
+ format_compliant / len(samples) if samples else 0.0
+ ),
}
-
+
# Add response length metrics
if response_lengths:
eval_metrics["eval/avg_response_length"] = avg_response_length
eval_metrics["eval/response_length_std"] = response_length_std
-
+
# Add thinking utilization
if thinking_mode_used and samples:
eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate
-
+
# Add judgment distribution metrics
if total_judgments > 0:
eval_metrics["eval/judgment_entropy"] = entropy
eval_metrics["eval/most_common_judgment_freq"] = most_common_judgment_freq
eval_metrics["eval/format_error_rate"] = format_error_rate
-
+
# Add subset metrics
for subset, subset_score_list in subset_scores.items():
valid_subset_scores = [s for s in subset_score_list if s is not None]
if valid_subset_scores:
avg_score = sum(valid_subset_scores) / len(valid_subset_scores)
eval_metrics[f"eval/percent_correct_{subset}"] = avg_score
-
+
try:
await self.evaluate_log(
metrics=eval_metrics,
@@ -786,10 +872,10 @@ class PairwiseJudgementEnv(BaseEnv):
"""Add rollouts to wandb for visualization."""
if item is None or scored_data is None or not scored_data.get("tokens"):
return
-
+
# Extract ground truth and question info
ground_truth = item[1]
-
+
# Extract question from the item prompt
question_info = "unknown_question"
try:
@@ -806,24 +892,24 @@ class PairwiseJudgementEnv(BaseEnv):
except Exception:
# Fallback to placeholder if extraction fails
question_info = "extraction_failed"
-
+
# Keep a reasonable number of rollouts
num_keep = self.config.num_rollouts_per_group_for_logging
if num_keep == -1:
num_keep = self.config.group_size
-
+
num_keep = min(num_keep, len(scored_data["tokens"]))
-
+
current_rollouts = []
mode = "thinking" if self.config.thinking_mode else "direct"
-
+
for i in range(num_keep):
# Decode the full trajectory
full_text = self.tokenizer.decode(
scored_data["tokens"][i], skip_special_tokens=True
)
score_val = scored_data["scores"][i]
-
+
# Extract the model's judgment
predicted_judgment = "unknown"
try:
@@ -834,22 +920,26 @@ class PairwiseJudgementEnv(BaseEnv):
else:
# Fallback to decoding tokens
model_response = full_text
-
- predicted_judgment = self.process_judgement(model_response, track_metrics=False)
+
+ predicted_judgment = self.process_judgement(
+ model_response, track_metrics=False
+ )
except Exception:
predicted_judgment = "parse_error"
-
- current_rollouts.append((
- full_text,
- score_val,
- ground_truth,
- predicted_judgment,
- question_info,
- mode
- ))
-
+
+ current_rollouts.append(
+ (
+ full_text,
+ score_val,
+ ground_truth,
+ predicted_judgment,
+ question_info,
+ mode,
+ )
+ )
+
self.rollouts_for_wandb.append(current_rollouts)
-
+
# Keep only recent rollouts
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
@@ -858,23 +948,23 @@ class PairwiseJudgementEnv(BaseEnv):
"""Create wandb table for rollout visualization."""
if not self.rollouts_for_wandb:
return wandb_metrics
-
+
table = wandb.Table(
columns=[
- "full_text",
- "score",
- "ground_truth",
- "predicted_judgment",
- "question_info",
- "mode"
+ "full_text",
+ "score",
+ "ground_truth",
+ "predicted_judgment",
+ "question_info",
+ "mode",
]
)
-
+
for group_rollouts in self.rollouts_for_wandb:
for rollout_tuple in group_rollouts:
if len(rollout_tuple) == 6:
table.add_data(*rollout_tuple)
-
+
wandb_metrics["train/rollouts"] = table
self.rollouts_for_wandb = []
return wandb_metrics
@@ -883,11 +973,13 @@ class PairwiseJudgementEnv(BaseEnv):
"""Log metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
-
+
# Basic accuracy metrics
if self.percent_correct_buffer:
- wandb_metrics["train/percent_correct"] = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
-
+ wandb_metrics["train/percent_correct"] = sum(
+ self.percent_correct_buffer
+ ) / len(self.percent_correct_buffer)
+
# Judgment letter distribution and accuracy
total_letters = sum(self.judgment_letter_counts.values())
if total_letters > 0:
@@ -896,44 +988,54 @@ class PairwiseJudgementEnv(BaseEnv):
for letter in self.choice_letters:
letter_count = self.judgment_letter_counts[letter]
letter_correct = self.judgment_letter_correct[letter]
-
+
# Letter frequency and accuracy
freq = letter_count / total_letters
wandb_metrics[f"train/judgment_freq_{letter}"] = freq
- wandb_metrics[f"train/judgment_acc_{letter}"] = letter_correct / letter_count if letter_count > 0 else 0.0
-
+ wandb_metrics[f"train/judgment_acc_{letter}"] = (
+ letter_correct / letter_count if letter_count > 0 else 0.0
+ )
+
# Accumulate entropy
if freq > 0:
entropy -= freq * math.log(freq)
-
+
wandb_metrics["train/judgment_entropy"] = entropy
- wandb_metrics["train/judgment_balance"] = entropy / math.log(self.config.num_choices) # Normalized entropy
-
+ wandb_metrics["train/judgment_balance"] = entropy / math.log(
+ self.config.num_choices
+ ) # Normalized entropy
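+            # entropy / log(num_choices) lies in [0, 1]; 1.0 means judgments
+            # are spread uniformly across the choices.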
+
# Error rate and other metrics
if self.total_judgments > 0:
wandb_metrics["train/error_rate"] = self.error_count / self.total_judgments
- wandb_metrics["train/format_compliance_rate"] = 1.0 - (self.error_count / self.total_judgments)
-
+ wandb_metrics["train/format_compliance_rate"] = 1.0 - (
+ self.error_count / self.total_judgments
+ )
+
# Configuration and mode metrics
- wandb_metrics.update({
- "train/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
- "train/total_judgments": self.total_judgments,
- "config/group_size": self.config.group_size,
- "config/max_token_length": self.config.max_token_length,
- "config/num_choices": self.config.num_choices,
- })
-
+ wandb_metrics.update(
+ {
+ "train/thinking_mode_enabled": (
+ 1.0 if self.config.thinking_mode else 0.0
+ ),
+ "train/total_judgments": self.total_judgments,
+ "config/group_size": self.config.group_size,
+ "config/max_token_length": self.config.max_token_length,
+ "config/num_choices": self.config.num_choices,
+ }
+ )
+
# Reset training metrics
self._reset_metrics()
-
+
# Add evaluation metrics
for metric_name, metric_value in self.eval_metrics:
wandb_metrics[metric_name] = metric_value
self.eval_metrics = []
-
+
# Add rollout table
wandb_metrics = await self.create_rollout_table(wandb_metrics)
-
+
await super().wandb_log(wandb_metrics)