changes linting

This commit is contained in:
Jai Suphavadeeprasit 2025-08-28 03:53:12 -04:00
parent f6f3c04313
commit b56d03b25c

View file

@ -405,14 +405,6 @@ class RefusalBenchEnv(BaseEnv):
random.shuffle(train_list)
self.train = train_list
except Exception as e:
print(f"❌ CRITICAL ERROR: Failed to load training dataset '{self.config.train_dataset}': {e}")
print(f" This would result in corrupted training with unrepresentative data.")
print(f" Please fix the training dataset configuration before proceeding.")
print(f" Available options:")
print(f" 1. Check if the dataset path/name is correct")
print(f" 2. Verify dataset access permissions")
print(f" 3. Ensure required dependencies are installed")
print(f" 4. Check network connectivity for HuggingFace datasets")
raise ValueError(
f"Training dataset '{self.config.train_dataset}' failed to load. "
f"Cannot proceed with training as fallback data would corrupt the training process. "
@ -511,9 +503,6 @@ class RefusalBenchEnv(BaseEnv):
if not self.config.full_debug:
return
print(f"\n🔍 FULL DEBUG - API REQUEST [{context}]")
print(f" Parameters: {params}")
for i, message in enumerate(messages):
role = message.get("role", "unknown")
content = message.get("content", "")
@ -526,8 +515,6 @@ class RefusalBenchEnv(BaseEnv):
if not self.config.full_debug:
return
print(f"\n🔍 FULL DEBUG - API RESPONSE [{context}]")
if hasattr(completion, "usage"):
print(f" Usage: {completion.usage}")
@ -539,10 +526,6 @@ class RefusalBenchEnv(BaseEnv):
if hasattr(choice, "finish_reason")
else "unknown"
)
print(
f" Choice {i+1}: {self._format_debug_text(content, 'Response')}"
)
print(f" Finish reason: {finish_reason}")
else:
print(" No choices in response")
@ -638,7 +621,6 @@ class RefusalBenchEnv(BaseEnv):
async def _inner_judge_call():
async with self.judge_semaphore:
# Simple rate limiting: wait before each call
print(f" ⏱️ Rate limiting: waiting {self.config.judge_rate_limit_delay}s before judge call")
await asyncio.sleep(self.config.judge_rate_limit_delay)
return await self._judge_api_call_raw(messages)
@ -649,7 +631,6 @@ class RefusalBenchEnv(BaseEnv):
async def _judge_api_call_raw(self, messages: List[Dict]):
"""Make a raw API call to the judge model without retry logic."""
try:
print(f" 🤖 Making judge API call to {self.config.judge_model_name}")
result = await self.judge_client.chat.completions.create(
model=self.config.judge_model_name,
messages=messages,
@ -659,7 +640,6 @@ class RefusalBenchEnv(BaseEnv):
print(f" ✅ Judge API call successful")
return result
except Exception as e:
print(f" ❌ Judge API call failed: {type(e).__name__}: {str(e)}")
if hasattr(e, 'response') and hasattr(e.response, 'status_code'):
print(f" HTTP Status: {e.response.status_code}")
if hasattr(e, 'response') and hasattr(e.response, 'text'):
@ -773,7 +753,6 @@ Your answer:"""
await self._update_judge_error_metrics()
if self.config.use_fallback_scoring:
print(" 🔄 Using fallback refusal detection")
await self._update_fallback_metrics()
return self._fallback_refusal_detection(response)
else:
@ -1074,7 +1053,6 @@ Your answer:"""
for attempt in range(max_retries):
try:
print(f"\n🔄 EVAL Attempt {attempt + 1}/{max_retries}")
print(f" Prompt: {prompt_text[:100]}...")
print(f" Category: {overarching_category}")
print(f" Danger: {danger}")
@ -1089,7 +1067,6 @@ Your answer:"""
print(f" Server sem value: {server.sem._value}")
print(f" Server eval_sem value: {server.eval_sem._value}")
print(f"🚀 Calling self.server.chat_completion...")
completion = await self.server.chat_completion(
messages=messages, **completion_params
)
@ -1133,7 +1110,6 @@ Your answer:"""
break
except Exception as e:
print(f"\n❌ EVAL Attempt {attempt + 1} FAILED!")
print(f" Error type: {type(e).__name__}")
print(f" Error message: {str(e)}")
@ -1151,11 +1127,10 @@ Your answer:"""
traceback.print_exc()
if attempt < max_retries - 1:
print(f" ⏳ Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
continue
else:
print(f" 💀 All retry attempts failed, giving up")
print(f"All retry attempts failed, giving up")
raise
# Validate thinking format if enabled
@ -1321,18 +1296,15 @@ Your answer:"""
# Build evaluation metrics dict - ordered with most important metrics first
eval_metrics = {}
# 🔝 TOP PRIORITY: Overall and aggregated refusal rates
eval_metrics["eval/overall_refusal_rate"] = overall_refusal_rate
if normal_refusal_rates:
eval_metrics["eval/normal_categories_refusal_rate"] = avg_normal_refusal_rate
if inverted_refusal_rates:
eval_metrics["eval/inverted_categories_refusal_rate"] = avg_inverted_refusal_rate
# 📊 SUMMARY METRICS: Basic counts and compliance
eval_metrics["eval/total_samples"] = len(samples)
eval_metrics["eval/format_compliance_rate"] = format_compliance_rate
# 📋 CATEGORY-SPECIFIC METRICS: Sorted alphabetically for consistency
category_metrics = {}
for category, rates in category_refusal_rates.items():
if rates:
@ -1363,7 +1335,6 @@ Your answer:"""
"""Evaluate the model on the test dataset with optional averaging over multiple runs."""
if self.config.eval_runs_for_average == 1:
# Single run - use the original behavior
print("🔄 Running single evaluation...")
result = await self._evaluate_single_run(1, *args, **kwargs)
try:
@ -1378,12 +1349,10 @@ Your answer:"""
print(f"Error logging evaluation results: {e}")
else:
# Multiple runs - perform averaging
print(f"🔄 Running evaluation {self.config.eval_runs_for_average} times for averaging...")
all_run_results = []
overall_start_time = time.time()
for run_num in range(1, self.config.eval_runs_for_average + 1):
print(f"\n📊 Starting evaluation run {run_num}/{self.config.eval_runs_for_average}")
try:
result = await self._evaluate_single_run(run_num, *args, **kwargs)
@ -1399,12 +1368,11 @@ Your answer:"""
generation_parameters=result["generation_parameters"],
task_name=f"{self.name}_eval_run_{run_num}" if self.name else f"RefusalBench_eval_run_{run_num}",
)
print(f"✅ Run {run_num} completed and logged")
except Exception as e:
print(f"⚠️ Error logging run {run_num} results: {e}")
print(f"Error logging run {run_num} results: {e}")
except Exception as e:
print(f"Error in evaluation run {run_num}: {e}")
print(f"Error in evaluation run {run_num}: {e}")
continue
overall_end_time = time.time()
@ -1414,7 +1382,6 @@ Your answer:"""
return
# Calculate averages across all runs
print(f"\n🧮 Computing averages across {len(all_run_results)} successful runs...")
averaged_results = await self._compute_averaged_results(all_run_results, overall_start_time, overall_end_time)
# Log averaged results
@ -1427,7 +1394,6 @@ Your answer:"""
generation_parameters=averaged_results["generation_parameters"],
task_name=f"{self.name}_eval_averaged" if self.name else "RefusalBench_eval_averaged",
)
print(f"✅ Averaged results across {len(all_run_results)} runs logged successfully")
except Exception as e:
print(f"❌ Error logging averaged results: {e}")