mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
changes linting
This commit is contained in:
parent
f6f3c04313
commit
b56d03b25c
1 changed files with 3 additions and 37 deletions
|
|
@ -405,14 +405,6 @@ class RefusalBenchEnv(BaseEnv):
|
|||
random.shuffle(train_list)
|
||||
self.train = train_list
|
||||
except Exception as e:
|
||||
print(f"❌ CRITICAL ERROR: Failed to load training dataset '{self.config.train_dataset}': {e}")
|
||||
print(f" This would result in corrupted training with unrepresentative data.")
|
||||
print(f" Please fix the training dataset configuration before proceeding.")
|
||||
print(f" Available options:")
|
||||
print(f" 1. Check if the dataset path/name is correct")
|
||||
print(f" 2. Verify dataset access permissions")
|
||||
print(f" 3. Ensure required dependencies are installed")
|
||||
print(f" 4. Check network connectivity for HuggingFace datasets")
|
||||
raise ValueError(
|
||||
f"Training dataset '{self.config.train_dataset}' failed to load. "
|
||||
f"Cannot proceed with training as fallback data would corrupt the training process. "
|
||||
|
|
@ -511,9 +503,6 @@ class RefusalBenchEnv(BaseEnv):
|
|||
if not self.config.full_debug:
|
||||
return
|
||||
|
||||
print(f"\n🔍 FULL DEBUG - API REQUEST [{context}]")
|
||||
print(f" Parameters: {params}")
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")
|
||||
|
|
@ -526,8 +515,6 @@ class RefusalBenchEnv(BaseEnv):
|
|||
if not self.config.full_debug:
|
||||
return
|
||||
|
||||
print(f"\n🔍 FULL DEBUG - API RESPONSE [{context}]")
|
||||
|
||||
if hasattr(completion, "usage"):
|
||||
print(f" Usage: {completion.usage}")
|
||||
|
||||
|
|
@ -539,10 +526,6 @@ class RefusalBenchEnv(BaseEnv):
|
|||
if hasattr(choice, "finish_reason")
|
||||
else "unknown"
|
||||
)
|
||||
print(
|
||||
f" Choice {i+1}: {self._format_debug_text(content, 'Response')}"
|
||||
)
|
||||
print(f" Finish reason: {finish_reason}")
|
||||
else:
|
||||
print(" No choices in response")
|
||||
|
||||
|
|
@ -638,7 +621,6 @@ class RefusalBenchEnv(BaseEnv):
|
|||
async def _inner_judge_call():
|
||||
async with self.judge_semaphore:
|
||||
# Simple rate limiting: wait before each call
|
||||
print(f" ⏱️ Rate limiting: waiting {self.config.judge_rate_limit_delay}s before judge call")
|
||||
await asyncio.sleep(self.config.judge_rate_limit_delay)
|
||||
return await self._judge_api_call_raw(messages)
|
||||
|
||||
|
|
@ -649,7 +631,6 @@ class RefusalBenchEnv(BaseEnv):
|
|||
async def _judge_api_call_raw(self, messages: List[Dict]):
|
||||
"""Make a raw API call to the judge model without retry logic."""
|
||||
try:
|
||||
print(f" 🤖 Making judge API call to {self.config.judge_model_name}")
|
||||
result = await self.judge_client.chat.completions.create(
|
||||
model=self.config.judge_model_name,
|
||||
messages=messages,
|
||||
|
|
@ -659,7 +640,6 @@ class RefusalBenchEnv(BaseEnv):
|
|||
print(f" ✅ Judge API call successful")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f" ❌ Judge API call failed: {type(e).__name__}: {str(e)}")
|
||||
if hasattr(e, 'response') and hasattr(e.response, 'status_code'):
|
||||
print(f" HTTP Status: {e.response.status_code}")
|
||||
if hasattr(e, 'response') and hasattr(e.response, 'text'):
|
||||
|
|
@ -773,7 +753,6 @@ Your answer:"""
|
|||
await self._update_judge_error_metrics()
|
||||
|
||||
if self.config.use_fallback_scoring:
|
||||
print(" 🔄 Using fallback refusal detection")
|
||||
await self._update_fallback_metrics()
|
||||
return self._fallback_refusal_detection(response)
|
||||
else:
|
||||
|
|
@ -1074,7 +1053,6 @@ Your answer:"""
|
|||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
print(f"\n🔄 EVAL Attempt {attempt + 1}/{max_retries}")
|
||||
print(f" Prompt: {prompt_text[:100]}...")
|
||||
print(f" Category: {overarching_category}")
|
||||
print(f" Danger: {danger}")
|
||||
|
|
@ -1089,7 +1067,6 @@ Your answer:"""
|
|||
print(f" Server sem value: {server.sem._value}")
|
||||
print(f" Server eval_sem value: {server.eval_sem._value}")
|
||||
|
||||
print(f"🚀 Calling self.server.chat_completion...")
|
||||
completion = await self.server.chat_completion(
|
||||
messages=messages, **completion_params
|
||||
)
|
||||
|
|
@ -1133,7 +1110,6 @@ Your answer:"""
|
|||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ EVAL Attempt {attempt + 1} FAILED!")
|
||||
print(f" Error type: {type(e).__name__}")
|
||||
print(f" Error message: {str(e)}")
|
||||
|
||||
|
|
@ -1151,11 +1127,10 @@ Your answer:"""
|
|||
traceback.print_exc()
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
print(f" ⏳ Retrying in {retry_delay} seconds...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f" 💀 All retry attempts failed, giving up")
|
||||
print(f"All retry attempts failed, giving up")
|
||||
raise
|
||||
|
||||
# Validate thinking format if enabled
|
||||
|
|
@ -1321,18 +1296,15 @@ Your answer:"""
|
|||
# Build evaluation metrics dict - ordered with most important metrics first
|
||||
eval_metrics = {}
|
||||
|
||||
# 🔝 TOP PRIORITY: Overall and aggregated refusal rates
|
||||
eval_metrics["eval/overall_refusal_rate"] = overall_refusal_rate
|
||||
if normal_refusal_rates:
|
||||
eval_metrics["eval/normal_categories_refusal_rate"] = avg_normal_refusal_rate
|
||||
if inverted_refusal_rates:
|
||||
eval_metrics["eval/inverted_categories_refusal_rate"] = avg_inverted_refusal_rate
|
||||
|
||||
# 📊 SUMMARY METRICS: Basic counts and compliance
|
||||
eval_metrics["eval/total_samples"] = len(samples)
|
||||
eval_metrics["eval/format_compliance_rate"] = format_compliance_rate
|
||||
|
||||
# 📋 CATEGORY-SPECIFIC METRICS: Sorted alphabetically for consistency
|
||||
category_metrics = {}
|
||||
for category, rates in category_refusal_rates.items():
|
||||
if rates:
|
||||
|
|
@ -1363,7 +1335,6 @@ Your answer:"""
|
|||
"""Evaluate the model on the test dataset with optional averaging over multiple runs."""
|
||||
if self.config.eval_runs_for_average == 1:
|
||||
# Single run - use the original behavior
|
||||
print("🔄 Running single evaluation...")
|
||||
result = await self._evaluate_single_run(1, *args, **kwargs)
|
||||
|
||||
try:
|
||||
|
|
@ -1378,12 +1349,10 @@ Your answer:"""
|
|||
print(f"Error logging evaluation results: {e}")
|
||||
else:
|
||||
# Multiple runs - perform averaging
|
||||
print(f"🔄 Running evaluation {self.config.eval_runs_for_average} times for averaging...")
|
||||
all_run_results = []
|
||||
overall_start_time = time.time()
|
||||
|
||||
for run_num in range(1, self.config.eval_runs_for_average + 1):
|
||||
print(f"\n📊 Starting evaluation run {run_num}/{self.config.eval_runs_for_average}")
|
||||
|
||||
try:
|
||||
result = await self._evaluate_single_run(run_num, *args, **kwargs)
|
||||
|
|
@ -1399,12 +1368,11 @@ Your answer:"""
|
|||
generation_parameters=result["generation_parameters"],
|
||||
task_name=f"{self.name}_eval_run_{run_num}" if self.name else f"RefusalBench_eval_run_{run_num}",
|
||||
)
|
||||
print(f"✅ Run {run_num} completed and logged")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error logging run {run_num} results: {e}")
|
||||
print(f"Error logging run {run_num} results: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error in evaluation run {run_num}: {e}")
|
||||
print(f"Error in evaluation run {run_num}: {e}")
|
||||
continue
|
||||
|
||||
overall_end_time = time.time()
|
||||
|
|
@ -1414,7 +1382,6 @@ Your answer:"""
|
|||
return
|
||||
|
||||
# Calculate averages across all runs
|
||||
print(f"\n🧮 Computing averages across {len(all_run_results)} successful runs...")
|
||||
averaged_results = await self._compute_averaged_results(all_run_results, overall_start_time, overall_end_time)
|
||||
|
||||
# Log averaged results
|
||||
|
|
@ -1427,7 +1394,6 @@ Your answer:"""
|
|||
generation_parameters=averaged_results["generation_parameters"],
|
||||
task_name=f"{self.name}_eval_averaged" if self.name else "RefusalBench_eval_averaged",
|
||||
)
|
||||
print(f"✅ Averaged results across {len(all_run_results)} runs logged successfully")
|
||||
except Exception as e:
|
||||
print(f"❌ Error logging averaged results: {e}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue