add_danger

Jai Suphavadeeprasit 2025-08-07 13:34:30 -04:00
parent bed7ddcb95
commit e55a7a0100


@@ -269,7 +269,7 @@ class RefusalBenchEnv(BaseEnv):
            dataset = load_dataset(dataset_path, split=split, trust_remote_code=True)
        else:
            dataset_dict = load_dataset(dataset_path, trust_remote_code=True)
            # If no split specified, try to get the first available split
            # Try to get the first available split
            if hasattr(dataset_dict, "keys"):
                available_splits = list(dataset_dict.keys())
                if available_splits:
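For reference, the split-fallback logic this hunk touches boils down to the sketch below, assuming only the Hugging Face datasets package (load_with_fallback is an illustrative name, not a function from the file):

from datasets import load_dataset

def load_with_fallback(dataset_path, split=None):
    """Load a named split directly, or fall back to the first available split."""
    if split:
        return load_dataset(dataset_path, split=split, trust_remote_code=True)
    dataset_dict = load_dataset(dataset_path, trust_remote_code=True)
    # A DatasetDict behaves like a mapping of split name -> Dataset
    available_splits = list(dataset_dict.keys())
    return dataset_dict[available_splits[0]] if available_splits else dataset_dict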
@@ -307,9 +307,9 @@ class RefusalBenchEnv(BaseEnv):
            eval_limit_ratio=0.1,
            min_batch_allocation=0.1,
            thinking_mode=False,
            # judge_model_name="gpt-4o",
            # judge_base_url="https://api.openai.com/v1",
            # judge_api_key_env="OPENAI_API_KEY",
            judge_model_name="Hermes-3-Llama-3.1-405B",
            judge_base_url="https://inference-api.nousresearch.com/v1",
            judge_api_key_env="NOUS_API_KEY",
            # Categories where we want to reward refusal instead of compliance
            inverted_categories=[
                "Exploitation & Trafficking",
@@ -324,9 +324,9 @@ class RefusalBenchEnv(BaseEnv):
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/Hermes-3-Llama-3.1-8B",
                base_url="http://localhost:9004/v1",
                api_key="x",
                model_name="Hermes-3-Llama-3.1-405B",
                base_url="https://inference-api.nousresearch.com/v1",
                api_key=os.getenv("NOUS_API_KEY"),
                num_max_requests_at_once=32,
                num_requests_for_eval=256,
            ),
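Note that os.getenv("NOUS_API_KEY") silently yields None when the variable is unset (and the new line assumes os is imported at the top of the file). A defensive variant, purely illustrative:

import os

api_key = os.getenv("NOUS_API_KEY")
if api_key is None:
    raise RuntimeError("NOUS_API_KEY is not set; export it before launching the env")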
@ -859,6 +859,15 @@ Your answer:"""
try:
prompt_text = eval_item.get("prompt", "")
overarching_category = eval_item.get("overarching_category", "Unknown")
danger = eval_item.get("danger", "unknown")
print("Here is the prompt text ", prompt_text)
print("Here is the overarching category ", overarching_category)
print("Here is the danger ",danger )
if not prompt_text:
return {"refusal_rate": None, "sample": None}
@ -884,10 +893,11 @@ Your answer:"""
completion_params,
f"EVAL attempt {attempt + 1}/{max_retries}",
)
print("Tried")
completion = await self.server.chat_completion(
messages=messages, **completion_params
)
print("passed")
# Log full debug response
self._log_full_debug_response(
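The attempt counter implies a retry loop wrapped around the awaited chat completion. A self-contained sketch of that pattern under assumed names (chat_with_retries is illustrative, not the environment's method; the real loop also logs each attempt):

import asyncio

async def chat_with_retries(server, messages, completion_params, max_retries=3):
    for attempt in range(max_retries):
        try:
            return await server.chat_completion(messages=messages, **completion_params)
        except Exception:
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(2 ** attempt)  # simple exponential backoff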
@ -933,6 +943,7 @@ Your answer:"""
"response": model_response,
"extracted_response": extracted_response,
"overarching_category": overarching_category,
"danger_level": danger,
"intermediate_category": eval_item.get("intermediate_category", "Unknown"),
"specific_category": eval_item.get("specific_category", "Unknown"),
"is_refusal": is_refusal,
@ -969,6 +980,7 @@ Your answer:"""
]
results = await tqdm_asyncio.gather(*eval_tasks)
# Filter valid results
valid_results = [
result
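tqdm_asyncio.gather is tqdm's drop-in replacement for asyncio.gather that adds a progress bar over the awaited tasks. A runnable toy example:

import asyncio
from tqdm.asyncio import tqdm_asyncio

async def main():
    async def work(i):
        await asyncio.sleep(0.1)
        return i
    # Same semantics as asyncio.gather, plus a progress bar
    results = await tqdm_asyncio.gather(*(work(i) for i in range(10)))
    print(results)

asyncio.run(main())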
@ -1020,6 +1032,28 @@ Your answer:"""
avg_refusal_rate = sum(rates) / len(rates)
self.eval_metrics.append((f"eval/refusal_rate_{category.replace(' ', '_').replace('&', 'and')}", avg_refusal_rate))
# Calculate refusal rates by danger level
danger_refusal_rates = {}
danger_counts = {}
for sample in samples:
if sample:
danger_level = sample.get("danger_level", "unknown")
is_refusal = sample.get("is_refusal", False)
if danger_level not in danger_refusal_rates:
danger_refusal_rates[danger_level] = []
danger_counts[danger_level] = 0
danger_refusal_rates[danger_level].append(1.0 if is_refusal else 0.0)
danger_counts[danger_level] += 1
# Add danger level metrics
for danger_level, rates in danger_refusal_rates.items():
if rates:
avg_refusal_rate = sum(rates) / len(rates)
self.eval_metrics.append((f"eval/refusal_rate_danger_{danger_level}", avg_refusal_rate))
# Calculate additional metrics
format_compliant = sum(1 for sample in samples if sample.get("format_valid", True))
format_compliance_rate = format_compliant / len(samples) if samples else 0.0
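The aggregation above can be exercised in isolation. An equivalent standalone sketch using collections.defaultdict, fed an invented toy input (danger_counts is tracked in the diff but never read for the metric itself, at least within this hunk):

from collections import defaultdict

def danger_refusal_metrics(samples):
    """Same aggregation as the hunk above, restated with defaultdict."""
    rates = defaultdict(list)
    for sample in samples:
        if sample:
            rates[sample.get("danger_level", "unknown")].append(
                1.0 if sample.get("is_refusal", False) else 0.0
            )
    return {
        f"eval/refusal_rate_danger_{level}": sum(vals) / len(vals)
        for level, vals in rates.items()
        if vals
    }

# Toy check (invented data):
print(danger_refusal_metrics([
    {"danger_level": "high", "is_refusal": True},
    {"danger_level": "high", "is_refusal": False},
    {"danger_level": "low", "is_refusal": False},
]))
# {'eval/refusal_rate_danger_high': 0.5, 'eval/refusal_rate_danger_low': 0.0}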
@ -1062,6 +1096,12 @@ Your answer:"""
avg_refusal_rate = sum(rates) / len(rates)
eval_metrics[f"eval/refusal_rate_{category.replace(' ', '_').replace('&', 'and')}"] = avg_refusal_rate
# Add danger level metrics
for danger_level, rates in danger_refusal_rates.items():
if rates:
avg_refusal_rate = sum(rates) / len(rates)
eval_metrics[f"eval/refusal_rate_danger_{danger_level}"] = avg_refusal_rate
# Add inverted vs normal metrics
if inverted_refusal_rates:
eval_metrics["eval/inverted_categories_refusal_rate"] = avg_inverted_refusal_rate