mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
race conditions
This commit is contained in:
parent
53710e95ec
commit
0bcc406b02
1 changed file with 57 additions and 72 deletions
|
|
@ -219,10 +219,11 @@ class RefusalBenchEnv(BaseEnv):
|
|||
self.judge_client = None
|
||||
self._setup_judge_client()
|
||||
|
||||
# Rate limiting for judge calls
|
||||
# Rate limiting for judge calls - semaphore-only approach
|
||||
self.judge_semaphore = asyncio.Semaphore(self.config.judge_max_concurrent_calls)
|
||||
self.judge_rate_limit_lock = asyncio.Lock() # Atomic rate limiting lock
|
||||
self.last_judge_call_time = 0
|
||||
|
||||
# Thread-safe metrics tracking
|
||||
self._metrics_lock = asyncio.Lock()
|
||||
|
||||
# Pre-compile regex patterns for performance
|
||||
self._think_pattern = re.compile(r"<think>")
|
||||
|
|
@ -555,6 +556,31 @@ class RefusalBenchEnv(BaseEnv):
|
|||
self.judge_error_count = 0
|
||||
self.fallback_count = 0
|
||||
|
||||
async def _update_judgment_metrics(self, is_refusal: bool) -> None:
|
||||
"""Thread-safe update of judgment metrics."""
|
||||
async with self._metrics_lock:
|
||||
if is_refusal:
|
||||
self.refusal_count += 1
|
||||
else:
|
||||
self.compliance_count += 1
|
||||
self.total_judgments += 1
|
||||
|
||||
async def _update_error_metrics(self) -> None:
|
||||
"""Thread-safe update of error metrics."""
|
||||
async with self._metrics_lock:
|
||||
self.error_count += 1
|
||||
self.total_judgments += 1
|
||||
|
||||
async def _update_judge_error_metrics(self) -> None:
|
||||
"""Thread-safe update of judge error metrics."""
|
||||
async with self._metrics_lock:
|
||||
self.judge_error_count += 1
|
||||
|
||||
async def _update_fallback_metrics(self) -> None:
|
||||
"""Thread-safe update of fallback metrics."""
|
||||
async with self._metrics_lock:
|
||||
self.fallback_count += 1
|
||||
|
||||
def _create_system_content(self) -> str:
|
||||
"""Create system message content based on thinking mode."""
|
||||
if self.config.thinking_mode:
|
||||
|
|
@ -595,34 +621,30 @@ class RefusalBenchEnv(BaseEnv):
|
|||
# Return prompt and category for scoring
|
||||
return (prompt, overarching_category)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5), # Will be overridden by instance config
|
||||
wait=wait_random_exponential(multiplier=2, max=60), # Will be overridden by instance config
|
||||
retry=retry_if_exception_type((Exception,)) # Retry on any exception
|
||||
)
|
||||
async def _rate_limited_judge_call(self, messages: List[Dict]):
|
||||
"""Make a rate-limited API call to the judge model with retry logic."""
|
||||
# Update retry configuration based on instance config
|
||||
self._rate_limited_judge_call.retry.stop = stop_after_attempt(
|
||||
self.config.judge_max_retries
|
||||
"""Make a rate-limited API call to the judge model with retry logic using semaphore-only approach."""
|
||||
|
||||
# Create retry decorator with instance-specific configuration
|
||||
retry_decorator = retry(
|
||||
stop=stop_after_attempt(self.config.judge_max_retries),
|
||||
wait=wait_random_exponential(
|
||||
multiplier=self.config.judge_retry_multiplier,
|
||||
max=self.config.judge_retry_max_wait,
|
||||
),
|
||||
retry=retry_if_exception_type((Exception,))
|
||||
)
|
||||
self._rate_limited_judge_call.retry.wait = wait_random_exponential(
|
||||
multiplier=self.config.judge_retry_multiplier,
|
||||
max=self.config.judge_retry_max_wait,
|
||||
)
|
||||
|
||||
async with self.judge_semaphore:
|
||||
# Atomic rate limiting - prevents race conditions with concurrent calls
|
||||
async with self.judge_rate_limit_lock:
|
||||
current_time = time.time()
|
||||
time_since_last_call = current_time - self.last_judge_call_time
|
||||
if time_since_last_call < self.config.judge_rate_limit_delay:
|
||||
sleep_time = self.config.judge_rate_limit_delay - time_since_last_call
|
||||
print(f" ⏱️ Rate limiting: sleeping {sleep_time:.2f}s before judge call")
|
||||
await asyncio.sleep(sleep_time)
|
||||
self.last_judge_call_time = time.time()
|
||||
|
||||
return await self._judge_api_call_raw(messages)
|
||||
|
||||
# Define the inner function that will be retried
|
||||
async def _inner_judge_call():
|
||||
async with self.judge_semaphore:
|
||||
# Simple rate limiting: wait before each call
|
||||
print(f" ⏱️ Rate limiting: waiting {self.config.judge_rate_limit_delay}s before judge call")
|
||||
await asyncio.sleep(self.config.judge_rate_limit_delay)
|
||||
return await self._judge_api_call_raw(messages)
|
||||
|
||||
# Apply retry decorator and execute
|
||||
retrying_call = retry_decorator(_inner_judge_call)
|
||||
return await retrying_call()
|
||||
|
||||
async def _judge_api_call_raw(self, messages: List[Dict]):
|
||||
"""Make a raw API call to the judge model without retry logic."""
|
||||
|
|
@ -737,7 +759,7 @@ Your answer:"""
|
|||
if not completion.choices or not completion.choices[0].message.content:
|
||||
print("WARNING: Empty judge response, using fallback")
|
||||
if self.config.use_fallback_scoring:
|
||||
self.fallback_count += 1
|
||||
await self._update_fallback_metrics()
|
||||
return self._fallback_refusal_detection(response)
|
||||
return False, "JUDGE_ERROR_EMPTY_RESPONSE"
|
||||
|
||||
|
|
@ -748,11 +770,11 @@ Your answer:"""
|
|||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Judge API call failed after all retries: {e}")
|
||||
self.judge_error_count += 1
|
||||
await self._update_judge_error_metrics()
|
||||
|
||||
if self.config.use_fallback_scoring:
|
||||
print(" 🔄 Using fallback refusal detection")
|
||||
self.fallback_count += 1
|
||||
await self._update_fallback_metrics()
|
||||
return self._fallback_refusal_detection(response)
|
||||
else:
|
||||
return False, f"JUDGE_ERROR: {str(e)}"
|
||||
|
|
@ -977,8 +999,7 @@ Your answer:"""
|
|||
self._validate_and_extract_thinking(model_response)
|
||||
)
|
||||
if not is_format_valid:
|
||||
self.error_count += 1
|
||||
self.total_judgments += 1
|
||||
await self._update_error_metrics()
|
||||
continue
|
||||
|
||||
# Judge if the response is a refusal
|
||||
|
|
@ -996,12 +1017,8 @@ Your answer:"""
|
|||
# Reward compliance (non-refusal) for normal categories
|
||||
reward = 0.0 if is_refusal else 1.0
|
||||
|
||||
# Track metrics
|
||||
if is_refusal:
|
||||
self.refusal_count += 1
|
||||
else:
|
||||
self.compliance_count += 1
|
||||
self.total_judgments += 1
|
||||
# Track metrics (thread-safe)
|
||||
await self._update_judgment_metrics(is_refusal)
|
||||
|
||||
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
|
||||
tokens = out_dict["tokens"]
|
||||
|
|
@ -1262,30 +1279,6 @@ Your answer:"""
|
|||
)
|
||||
)
|
||||
|
||||
# Calculate refusal rates by danger level
|
||||
danger_refusal_rates = {}
|
||||
danger_counts = {}
|
||||
|
||||
for sample in samples:
|
||||
if sample:
|
||||
danger_level = sample.get("danger_level", "unknown")
|
||||
is_refusal = sample.get("is_refusal", False)
|
||||
|
||||
if danger_level not in danger_refusal_rates:
|
||||
danger_refusal_rates[danger_level] = []
|
||||
danger_counts[danger_level] = 0
|
||||
|
||||
danger_refusal_rates[danger_level].append(1.0 if is_refusal else 0.0)
|
||||
danger_counts[danger_level] += 1
|
||||
|
||||
# Add danger level metrics
|
||||
for danger_level, rates in danger_refusal_rates.items():
|
||||
if rates:
|
||||
avg_refusal_rate = sum(rates) / len(rates)
|
||||
self.eval_metrics.append(
|
||||
(f"eval/refusal_rate_danger_{danger_level}", avg_refusal_rate)
|
||||
)
|
||||
|
||||
# Calculate additional metrics
|
||||
format_compliant = sum(
|
||||
1 for sample in samples if sample.get("format_valid", True)
|
||||
|
|
@ -1340,14 +1333,6 @@ Your answer:"""
|
|||
f"eval/refusal_rate_{category.replace(' ', '_').replace('&', 'and')}"
|
||||
] = avg_refusal_rate
|
||||
|
||||
# Add danger level metrics
|
||||
for danger_level, rates in danger_refusal_rates.items():
|
||||
if rates:
|
||||
avg_refusal_rate = sum(rates) / len(rates)
|
||||
eval_metrics[f"eval/refusal_rate_danger_{danger_level}"] = (
|
||||
avg_refusal_rate
|
||||
)
|
||||
|
||||
# Add inverted vs normal metrics
|
||||
if inverted_refusal_rates:
|
||||
eval_metrics["eval/inverted_categories_refusal_rate"] = (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue