changes linting

This commit is contained in:
Jai Suphavadeeprasit 2025-08-28 03:53:12 -04:00
parent f6f3c04313
commit b56d03b25c

View file

@ -405,14 +405,6 @@ class RefusalBenchEnv(BaseEnv):
random.shuffle(train_list) random.shuffle(train_list)
self.train = train_list self.train = train_list
except Exception as e: except Exception as e:
print(f"❌ CRITICAL ERROR: Failed to load training dataset '{self.config.train_dataset}': {e}")
print(f" This would result in corrupted training with unrepresentative data.")
print(f" Please fix the training dataset configuration before proceeding.")
print(f" Available options:")
print(f" 1. Check if the dataset path/name is correct")
print(f" 2. Verify dataset access permissions")
print(f" 3. Ensure required dependencies are installed")
print(f" 4. Check network connectivity for HuggingFace datasets")
raise ValueError( raise ValueError(
f"Training dataset '{self.config.train_dataset}' failed to load. " f"Training dataset '{self.config.train_dataset}' failed to load. "
f"Cannot proceed with training as fallback data would corrupt the training process. " f"Cannot proceed with training as fallback data would corrupt the training process. "
@ -511,9 +503,6 @@ class RefusalBenchEnv(BaseEnv):
if not self.config.full_debug: if not self.config.full_debug:
return return
print(f"\n🔍 FULL DEBUG - API REQUEST [{context}]")
print(f" Parameters: {params}")
for i, message in enumerate(messages): for i, message in enumerate(messages):
role = message.get("role", "unknown") role = message.get("role", "unknown")
content = message.get("content", "") content = message.get("content", "")
@ -526,8 +515,6 @@ class RefusalBenchEnv(BaseEnv):
if not self.config.full_debug: if not self.config.full_debug:
return return
print(f"\n🔍 FULL DEBUG - API RESPONSE [{context}]")
if hasattr(completion, "usage"): if hasattr(completion, "usage"):
print(f" Usage: {completion.usage}") print(f" Usage: {completion.usage}")
@ -539,10 +526,6 @@ class RefusalBenchEnv(BaseEnv):
if hasattr(choice, "finish_reason") if hasattr(choice, "finish_reason")
else "unknown" else "unknown"
) )
print(
f" Choice {i+1}: {self._format_debug_text(content, 'Response')}"
)
print(f" Finish reason: {finish_reason}")
else: else:
print(" No choices in response") print(" No choices in response")
@ -638,7 +621,6 @@ class RefusalBenchEnv(BaseEnv):
async def _inner_judge_call(): async def _inner_judge_call():
async with self.judge_semaphore: async with self.judge_semaphore:
# Simple rate limiting: wait before each call # Simple rate limiting: wait before each call
print(f" ⏱️ Rate limiting: waiting {self.config.judge_rate_limit_delay}s before judge call")
await asyncio.sleep(self.config.judge_rate_limit_delay) await asyncio.sleep(self.config.judge_rate_limit_delay)
return await self._judge_api_call_raw(messages) return await self._judge_api_call_raw(messages)
@ -649,7 +631,6 @@ class RefusalBenchEnv(BaseEnv):
async def _judge_api_call_raw(self, messages: List[Dict]): async def _judge_api_call_raw(self, messages: List[Dict]):
"""Make a raw API call to the judge model without retry logic.""" """Make a raw API call to the judge model without retry logic."""
try: try:
print(f" 🤖 Making judge API call to {self.config.judge_model_name}")
result = await self.judge_client.chat.completions.create( result = await self.judge_client.chat.completions.create(
model=self.config.judge_model_name, model=self.config.judge_model_name,
messages=messages, messages=messages,
@ -659,7 +640,6 @@ class RefusalBenchEnv(BaseEnv):
print(f" ✅ Judge API call successful") print(f" ✅ Judge API call successful")
return result return result
except Exception as e: except Exception as e:
print(f" ❌ Judge API call failed: {type(e).__name__}: {str(e)}")
if hasattr(e, 'response') and hasattr(e.response, 'status_code'): if hasattr(e, 'response') and hasattr(e.response, 'status_code'):
print(f" HTTP Status: {e.response.status_code}") print(f" HTTP Status: {e.response.status_code}")
if hasattr(e, 'response') and hasattr(e.response, 'text'): if hasattr(e, 'response') and hasattr(e.response, 'text'):
@ -773,7 +753,6 @@ Your answer:"""
await self._update_judge_error_metrics() await self._update_judge_error_metrics()
if self.config.use_fallback_scoring: if self.config.use_fallback_scoring:
print(" 🔄 Using fallback refusal detection")
await self._update_fallback_metrics() await self._update_fallback_metrics()
return self._fallback_refusal_detection(response) return self._fallback_refusal_detection(response)
else: else:
@ -1074,7 +1053,6 @@ Your answer:"""
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
print(f"\n🔄 EVAL Attempt {attempt + 1}/{max_retries}")
print(f" Prompt: {prompt_text[:100]}...") print(f" Prompt: {prompt_text[:100]}...")
print(f" Category: {overarching_category}") print(f" Category: {overarching_category}")
print(f" Danger: {danger}") print(f" Danger: {danger}")
@ -1089,7 +1067,6 @@ Your answer:"""
print(f" Server sem value: {server.sem._value}") print(f" Server sem value: {server.sem._value}")
print(f" Server eval_sem value: {server.eval_sem._value}") print(f" Server eval_sem value: {server.eval_sem._value}")
print(f"🚀 Calling self.server.chat_completion...")
completion = await self.server.chat_completion( completion = await self.server.chat_completion(
messages=messages, **completion_params messages=messages, **completion_params
) )
@ -1133,7 +1110,6 @@ Your answer:"""
break break
except Exception as e: except Exception as e:
print(f"\n❌ EVAL Attempt {attempt + 1} FAILED!")
print(f" Error type: {type(e).__name__}") print(f" Error type: {type(e).__name__}")
print(f" Error message: {str(e)}") print(f" Error message: {str(e)}")
@ -1151,11 +1127,10 @@ Your answer:"""
traceback.print_exc() traceback.print_exc()
if attempt < max_retries - 1: if attempt < max_retries - 1:
print(f" ⏳ Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay) await asyncio.sleep(retry_delay)
continue continue
else: else:
print(f" 💀 All retry attempts failed, giving up") print(f"All retry attempts failed, giving up")
raise raise
# Validate thinking format if enabled # Validate thinking format if enabled
@ -1321,18 +1296,15 @@ Your answer:"""
# Build evaluation metrics dict - ordered with most important metrics first # Build evaluation metrics dict - ordered with most important metrics first
eval_metrics = {} eval_metrics = {}
# 🔝 TOP PRIORITY: Overall and aggregated refusal rates
eval_metrics["eval/overall_refusal_rate"] = overall_refusal_rate eval_metrics["eval/overall_refusal_rate"] = overall_refusal_rate
if normal_refusal_rates: if normal_refusal_rates:
eval_metrics["eval/normal_categories_refusal_rate"] = avg_normal_refusal_rate eval_metrics["eval/normal_categories_refusal_rate"] = avg_normal_refusal_rate
if inverted_refusal_rates: if inverted_refusal_rates:
eval_metrics["eval/inverted_categories_refusal_rate"] = avg_inverted_refusal_rate eval_metrics["eval/inverted_categories_refusal_rate"] = avg_inverted_refusal_rate
# 📊 SUMMARY METRICS: Basic counts and compliance
eval_metrics["eval/total_samples"] = len(samples) eval_metrics["eval/total_samples"] = len(samples)
eval_metrics["eval/format_compliance_rate"] = format_compliance_rate eval_metrics["eval/format_compliance_rate"] = format_compliance_rate
# 📋 CATEGORY-SPECIFIC METRICS: Sorted alphabetically for consistency
category_metrics = {} category_metrics = {}
for category, rates in category_refusal_rates.items(): for category, rates in category_refusal_rates.items():
if rates: if rates:
@ -1363,7 +1335,6 @@ Your answer:"""
"""Evaluate the model on the test dataset with optional averaging over multiple runs.""" """Evaluate the model on the test dataset with optional averaging over multiple runs."""
if self.config.eval_runs_for_average == 1: if self.config.eval_runs_for_average == 1:
# Single run - use the original behavior # Single run - use the original behavior
print("🔄 Running single evaluation...")
result = await self._evaluate_single_run(1, *args, **kwargs) result = await self._evaluate_single_run(1, *args, **kwargs)
try: try:
@ -1378,12 +1349,10 @@ Your answer:"""
print(f"Error logging evaluation results: {e}") print(f"Error logging evaluation results: {e}")
else: else:
# Multiple runs - perform averaging # Multiple runs - perform averaging
print(f"🔄 Running evaluation {self.config.eval_runs_for_average} times for averaging...")
all_run_results = [] all_run_results = []
overall_start_time = time.time() overall_start_time = time.time()
for run_num in range(1, self.config.eval_runs_for_average + 1): for run_num in range(1, self.config.eval_runs_for_average + 1):
print(f"\n📊 Starting evaluation run {run_num}/{self.config.eval_runs_for_average}")
try: try:
result = await self._evaluate_single_run(run_num, *args, **kwargs) result = await self._evaluate_single_run(run_num, *args, **kwargs)
@ -1399,12 +1368,11 @@ Your answer:"""
generation_parameters=result["generation_parameters"], generation_parameters=result["generation_parameters"],
task_name=f"{self.name}_eval_run_{run_num}" if self.name else f"RefusalBench_eval_run_{run_num}", task_name=f"{self.name}_eval_run_{run_num}" if self.name else f"RefusalBench_eval_run_{run_num}",
) )
print(f"✅ Run {run_num} completed and logged")
except Exception as e: except Exception as e:
print(f"⚠️ Error logging run {run_num} results: {e}") print(f"Error logging run {run_num} results: {e}")
except Exception as e: except Exception as e:
print(f"Error in evaluation run {run_num}: {e}") print(f"Error in evaluation run {run_num}: {e}")
continue continue
overall_end_time = time.time() overall_end_time = time.time()
@ -1414,7 +1382,6 @@ Your answer:"""
return return
# Calculate averages across all runs # Calculate averages across all runs
print(f"\n🧮 Computing averages across {len(all_run_results)} successful runs...")
averaged_results = await self._compute_averaged_results(all_run_results, overall_start_time, overall_end_time) averaged_results = await self._compute_averaged_results(all_run_results, overall_start_time, overall_end_time)
# Log averaged results # Log averaged results
@ -1427,7 +1394,6 @@ Your answer:"""
generation_parameters=averaged_results["generation_parameters"], generation_parameters=averaged_results["generation_parameters"],
task_name=f"{self.name}_eval_averaged" if self.name else "RefusalBench_eval_averaged", task_name=f"{self.name}_eval_averaged" if self.name else "RefusalBench_eval_averaged",
) )
print(f"✅ Averaged results across {len(all_run_results)} runs logged successfully")
except Exception as e: except Exception as e:
print(f"❌ Error logging averaged results: {e}") print(f"❌ Error logging averaged results: {e}")