Mirror of https://github.com/NousResearch/atropos.git (synced 2026-05-01 17:45:16 +00:00)
changes linting
parent f6f3c04313
commit b56d03b25c
1 changed file with 3 additions and 37 deletions
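The deletions below are almost entirely emoji-heavy debug print statements and section-header comments (several of the prints are f-strings with no placeholders, and a number of lines run past common line-length limits); the three rewritten lines only drop the emoji prefix from messages that are kept. The commit message does not name the lint rules involved, so purely as an illustration, here is a small stdlib-only sketch of the two checks such a cleanup typically targets. The toy_lint helper, the 100-character limit, and the F541/E501 labels are assumptions for this example, not code or configuration from the repository.

import re

# Illustration only: a toy stand-in for the two lint checks this commit most
# plausibly addresses. Real projects would run flake8 or ruff instead; the
# 100-character limit and the rule labels below are assumptions, not this
# repository's configuration.
MAX_LINE_LENGTH = 100
FSTRING_NO_PLACEHOLDER = re.compile(r'f"[^"{}]*"')  # f-string literal containing no {...}

def toy_lint(source: str) -> list[str]:
    problems = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        if len(line) > MAX_LINE_LENGTH:
            problems.append(f"line {lineno}: E501-style, line too long ({len(line)} chars)")
        if FSTRING_NO_PLACEHOLDER.search(line):
            problems.append(f"line {lineno}: F541-style, f-string without placeholders")
    return problems

# Two lines of the kind this commit deletes, used as sample input.
sample = """
            print(f"❌ CRITICAL ERROR: Failed to load training dataset '{self.config.train_dataset}': {e}")
            print(f" Available options:")
"""
for problem in toy_lint(sample):
    print(problem)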
@@ -405,14 +405,6 @@ class RefusalBenchEnv(BaseEnv):
             random.shuffle(train_list)
             self.train = train_list
         except Exception as e:
-            print(f"❌ CRITICAL ERROR: Failed to load training dataset '{self.config.train_dataset}': {e}")
-            print(f" This would result in corrupted training with unrepresentative data.")
-            print(f" Please fix the training dataset configuration before proceeding.")
-            print(f" Available options:")
-            print(f" 1. Check if the dataset path/name is correct")
-            print(f" 2. Verify dataset access permissions")
-            print(f" 3. Ensure required dependencies are installed")
-            print(f" 4. Check network connectivity for HuggingFace datasets")
             raise ValueError(
                 f"Training dataset '{self.config.train_dataset}' failed to load. "
                 f"Cannot proceed with training as fallback data would corrupt the training process. "
@@ -511,9 +503,6 @@ class RefusalBenchEnv(BaseEnv):
         if not self.config.full_debug:
             return

-        print(f"\n🔍 FULL DEBUG - API REQUEST [{context}]")
-        print(f" Parameters: {params}")
-
         for i, message in enumerate(messages):
             role = message.get("role", "unknown")
             content = message.get("content", "")
@@ -526,8 +515,6 @@ class RefusalBenchEnv(BaseEnv):
         if not self.config.full_debug:
             return

-        print(f"\n🔍 FULL DEBUG - API RESPONSE [{context}]")
-
         if hasattr(completion, "usage"):
             print(f" Usage: {completion.usage}")

@@ -539,10 +526,6 @@ class RefusalBenchEnv(BaseEnv):
                     if hasattr(choice, "finish_reason")
                     else "unknown"
                 )
-                print(
-                    f" Choice {i+1}: {self._format_debug_text(content, 'Response')}"
-                )
-                print(f" Finish reason: {finish_reason}")
         else:
             print(" No choices in response")

@@ -638,7 +621,6 @@ class RefusalBenchEnv(BaseEnv):
         async def _inner_judge_call():
             async with self.judge_semaphore:
                 # Simple rate limiting: wait before each call
-                print(f" ⏱️ Rate limiting: waiting {self.config.judge_rate_limit_delay}s before judge call")
                 await asyncio.sleep(self.config.judge_rate_limit_delay)
                 return await self._judge_api_call_raw(messages)

@@ -649,7 +631,6 @@ class RefusalBenchEnv(BaseEnv):
     async def _judge_api_call_raw(self, messages: List[Dict]):
         """Make a raw API call to the judge model without retry logic."""
         try:
-            print(f" 🤖 Making judge API call to {self.config.judge_model_name}")
             result = await self.judge_client.chat.completions.create(
                 model=self.config.judge_model_name,
                 messages=messages,
@@ -659,7 +640,6 @@ class RefusalBenchEnv(BaseEnv):
             print(f" ✅ Judge API call successful")
             return result
         except Exception as e:
-            print(f" ❌ Judge API call failed: {type(e).__name__}: {str(e)}")
             if hasattr(e, 'response') and hasattr(e.response, 'status_code'):
                 print(f" HTTP Status: {e.response.status_code}")
             if hasattr(e, 'response') and hasattr(e.response, 'text'):
@@ -773,7 +753,6 @@ Your answer:"""
             await self._update_judge_error_metrics()

             if self.config.use_fallback_scoring:
-                print(" 🔄 Using fallback refusal detection")
                 await self._update_fallback_metrics()
                 return self._fallback_refusal_detection(response)
             else:
@@ -1074,7 +1053,6 @@ Your answer:"""

         for attempt in range(max_retries):
             try:
-                print(f"\n🔄 EVAL Attempt {attempt + 1}/{max_retries}")
                 print(f" Prompt: {prompt_text[:100]}...")
                 print(f" Category: {overarching_category}")
                 print(f" Danger: {danger}")
@@ -1089,7 +1067,6 @@ Your answer:"""
                 print(f" Server sem value: {server.sem._value}")
                 print(f" Server eval_sem value: {server.eval_sem._value}")

-                print(f"🚀 Calling self.server.chat_completion...")
                 completion = await self.server.chat_completion(
                     messages=messages, **completion_params
                 )
@@ -1133,7 +1110,6 @@ Your answer:"""
                     break

             except Exception as e:
-                print(f"\n❌ EVAL Attempt {attempt + 1} FAILED!")
                 print(f" Error type: {type(e).__name__}")
                 print(f" Error message: {str(e)}")

@@ -1151,11 +1127,10 @@ Your answer:"""
                     traceback.print_exc()

                 if attempt < max_retries - 1:
-                    print(f" ⏳ Retrying in {retry_delay} seconds...")
                     await asyncio.sleep(retry_delay)
                     continue
                 else:
-                    print(f" 💀 All retry attempts failed, giving up")
+                    print(f"All retry attempts failed, giving up")
                     raise

         # Validate thinking format if enabled
@@ -1321,18 +1296,15 @@ Your answer:"""
         # Build evaluation metrics dict - ordered with most important metrics first
         eval_metrics = {}

-        # 🔝 TOP PRIORITY: Overall and aggregated refusal rates
         eval_metrics["eval/overall_refusal_rate"] = overall_refusal_rate
         if normal_refusal_rates:
             eval_metrics["eval/normal_categories_refusal_rate"] = avg_normal_refusal_rate
         if inverted_refusal_rates:
             eval_metrics["eval/inverted_categories_refusal_rate"] = avg_inverted_refusal_rate

-        # 📊 SUMMARY METRICS: Basic counts and compliance
         eval_metrics["eval/total_samples"] = len(samples)
         eval_metrics["eval/format_compliance_rate"] = format_compliance_rate

-        # 📋 CATEGORY-SPECIFIC METRICS: Sorted alphabetically for consistency
         category_metrics = {}
         for category, rates in category_refusal_rates.items():
             if rates:
@@ -1363,7 +1335,6 @@ Your answer:"""
         """Evaluate the model on the test dataset with optional averaging over multiple runs."""
         if self.config.eval_runs_for_average == 1:
             # Single run - use the original behavior
-            print("🔄 Running single evaluation...")
             result = await self._evaluate_single_run(1, *args, **kwargs)

             try:
@@ -1378,12 +1349,10 @@ Your answer:"""
                 print(f"Error logging evaluation results: {e}")
         else:
             # Multiple runs - perform averaging
-            print(f"🔄 Running evaluation {self.config.eval_runs_for_average} times for averaging...")
             all_run_results = []
             overall_start_time = time.time()

             for run_num in range(1, self.config.eval_runs_for_average + 1):
-                print(f"\n📊 Starting evaluation run {run_num}/{self.config.eval_runs_for_average}")

                 try:
                     result = await self._evaluate_single_run(run_num, *args, **kwargs)
@@ -1399,12 +1368,11 @@ Your answer:"""
                             generation_parameters=result["generation_parameters"],
                             task_name=f"{self.name}_eval_run_{run_num}" if self.name else f"RefusalBench_eval_run_{run_num}",
                         )
-                        print(f"✅ Run {run_num} completed and logged")
                     except Exception as e:
-                        print(f"⚠️ Error logging run {run_num} results: {e}")
+                        print(f"Error logging run {run_num} results: {e}")

                 except Exception as e:
-                    print(f"❌ Error in evaluation run {run_num}: {e}")
+                    print(f"Error in evaluation run {run_num}: {e}")
                     continue

             overall_end_time = time.time()
@@ -1414,7 +1382,6 @@ Your answer:"""
                 return

             # Calculate averages across all runs
-            print(f"\n🧮 Computing averages across {len(all_run_results)} successful runs...")
             averaged_results = await self._compute_averaged_results(all_run_results, overall_start_time, overall_end_time)

             # Log averaged results
@@ -1427,7 +1394,6 @@ Your answer:"""
                     generation_parameters=averaged_results["generation_parameters"],
                     task_name=f"{self.name}_eval_averaged" if self.name else "RefusalBench_eval_averaged",
                 )
-                print(f"✅ Averaged results across {len(all_run_results)} runs logged successfully")
             except Exception as e:
                 print(f"❌ Error logging averaged results: {e}")

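The code paths these prints were attached to are otherwise unchanged; the hunks at original lines 638 and 649 show the rate-limited judge call they decorated, where a semaphore bounds concurrency and a fixed sleep runs before every raw call. Below is a minimal, self-contained sketch of that pattern. The delay value, the concurrency limit, and the stub coroutine standing in for judge_client.chat.completions.create are assumptions for illustration only, not the repository's actual configuration.

import asyncio

JUDGE_RATE_LIMIT_DELAY = 0.5  # assumed seconds to wait before each judge call


async def _judge_api_call_raw(messages):
    # Stand-in for the real OpenAI-compatible judge call; returns a dummy verdict.
    await asyncio.sleep(0.1)
    return {"verdict": "REFUSAL", "n_messages": len(messages)}


async def judge_call(semaphore, messages):
    async with semaphore:
        # Simple rate limiting: wait before each call, as in the hunk at line 638.
        await asyncio.sleep(JUDGE_RATE_LIMIT_DELAY)
        return await _judge_api_call_raw(messages)


async def main():
    semaphore = asyncio.Semaphore(2)  # assumed judge concurrency limit
    prompts = [[{"role": "user", "content": f"sample {i}"}] for i in range(4)]
    results = await asyncio.gather(*(judge_call(semaphore, p) for p in prompts))
    print(f"{len(results)} judge calls completed")


if __name__ == "__main__":
    asyncio.run(main())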