mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
2f37714e84
commit
ffc210e470
1 changed files with 170 additions and 82 deletions
|
|
@ -225,45 +225,60 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
"""Format text for debug output (first 100 + last 100 chars)."""
|
||||
if not text:
|
||||
return f"{label}: <empty>"
|
||||
|
||||
|
||||
text_clean = text.strip()
|
||||
if len(text_clean) <= 200:
|
||||
return f"{label}: '{text_clean}'"
|
||||
|
||||
|
||||
first_100 = text_clean[:100]
|
||||
last_100 = text_clean[-100:]
|
||||
return f"{label}: '{first_100}...{last_100}' (total {len(text_clean)} chars)"
|
||||
|
||||
def _log_full_debug_request(self, messages: List[Dict], params: Dict, category: str = "unknown", item_id: str = "unknown", context: str = ""):
|
||||
def _log_full_debug_request(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
params: Dict,
|
||||
category: str = "unknown",
|
||||
item_id: str = "unknown",
|
||||
context: str = "",
|
||||
):
|
||||
"""Log full debug information for API requests."""
|
||||
if not self.config.full_debug:
|
||||
return
|
||||
|
||||
|
||||
print(f"\n🔍 FULL DEBUG - API REQUEST [{context}]")
|
||||
print(f" Category: {category}")
|
||||
print(f" Item ID: {item_id}")
|
||||
print(f" Parameters: {params}")
|
||||
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")
|
||||
print(f" Message {i+1} ({role}): {self._format_debug_text(content, 'Content')}")
|
||||
print(
|
||||
f" Message {i+1} ({role}): {self._format_debug_text(content, 'Content')}"
|
||||
)
|
||||
|
||||
def _log_full_debug_response(self, completion, context: str = ""):
|
||||
"""Log full debug information for API responses."""
|
||||
if not self.config.full_debug:
|
||||
return
|
||||
|
||||
|
||||
print(f"\n🔍 FULL DEBUG - API RESPONSE [{context}]")
|
||||
|
||||
if hasattr(completion, 'usage'):
|
||||
|
||||
if hasattr(completion, "usage"):
|
||||
print(f" Usage: {completion.usage}")
|
||||
|
||||
if hasattr(completion, 'choices') and completion.choices:
|
||||
|
||||
if hasattr(completion, "choices") and completion.choices:
|
||||
for i, choice in enumerate(completion.choices):
|
||||
content = choice.message.content if hasattr(choice, 'message') else ""
|
||||
finish_reason = choice.finish_reason if hasattr(choice, 'finish_reason') else "unknown"
|
||||
print(f" Choice {i+1}: {self._format_debug_text(content, 'Response')}")
|
||||
content = choice.message.content if hasattr(choice, "message") else ""
|
||||
finish_reason = (
|
||||
choice.finish_reason
|
||||
if hasattr(choice, "finish_reason")
|
||||
else "unknown"
|
||||
)
|
||||
print(
|
||||
f" Choice {i+1}: {self._format_debug_text(content, 'Response')}"
|
||||
)
|
||||
print(f" Finish reason: {finish_reason}")
|
||||
else:
|
||||
print(f" No choices in response")
|
||||
|
|
@ -408,12 +423,20 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
|
||||
# Show debug mode status
|
||||
if self.config.full_debug:
|
||||
print(f"\n🔍 FULL DEBUG MODE ENABLED - Will log all API requests and responses")
|
||||
print(f" 📊 Will show: category, item ID, first/last 100 chars of prompts and responses")
|
||||
print(f" ⚙️ Retry settings: max_retries={self.config.max_retries}, retry_delay={self.config.retry_delay}s")
|
||||
print(
|
||||
f"\n🔍 FULL DEBUG MODE ENABLED - Will log all API requests and responses"
|
||||
)
|
||||
print(
|
||||
f" 📊 Will show: category, item ID, first/last 100 chars of prompts and responses"
|
||||
)
|
||||
print(
|
||||
f" ⚙️ Retry settings: max_retries={self.config.max_retries}, retry_delay={self.config.retry_delay}s"
|
||||
)
|
||||
print(f" 📏 Min response length: {self.config.min_response_length} chars")
|
||||
else:
|
||||
print(f"\n🔍 Full debug mode disabled - Use full_debug=True to enable detailed logging")
|
||||
print(
|
||||
f"\n🔍 Full debug mode disabled - Use full_debug=True to enable detailed logging"
|
||||
)
|
||||
|
||||
# Debug: Show sample evaluation item structure
|
||||
if len(self.test) > 0:
|
||||
|
|
@ -688,73 +711,98 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
# Retry logic for training trajectories
|
||||
max_retries = self.config.max_retries
|
||||
retry_delay = self.config.retry_delay
|
||||
|
||||
|
||||
# Get category info for debug logging (this is synthetic training data)
|
||||
category = "synthetic_training"
|
||||
item_id = f"train_{self.iter if hasattr(self, 'iter') else 'unknown'}"
|
||||
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Log full debug request
|
||||
self._log_full_debug_request(
|
||||
messages, completion_params, category, item_id,
|
||||
f"TRAINING attempt {attempt + 1}/{max_retries}"
|
||||
messages,
|
||||
completion_params,
|
||||
category,
|
||||
item_id,
|
||||
f"TRAINING attempt {attempt + 1}/{max_retries}",
|
||||
)
|
||||
|
||||
|
||||
completions = await self.server.chat_completion(
|
||||
messages=messages, **completion_params
|
||||
)
|
||||
|
||||
|
||||
# Log full debug response
|
||||
self._log_full_debug_response(completions, f"TRAINING attempt {attempt + 1}/{max_retries}")
|
||||
self._log_full_debug_response(
|
||||
completions, f"TRAINING attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
|
||||
# Check if we got valid completions
|
||||
if not completions.choices:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: No choices in collect_trajectories (attempt {attempt + 1}/{max_retries})")
|
||||
print(
|
||||
f"DEBUG: No choices in collect_trajectories (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f"DEBUG: No choices in collect_trajectories after {max_retries} attempts")
|
||||
print(
|
||||
f"DEBUG: No choices in collect_trajectories after {max_retries} attempts"
|
||||
)
|
||||
return None, []
|
||||
|
||||
|
||||
# Check if any completion has None content
|
||||
valid_completions = []
|
||||
for completion_choice in completions.choices:
|
||||
if (completion_choice.message.content is not None
|
||||
if (
|
||||
completion_choice.message.content is not None
|
||||
and isinstance(completion_choice.message.content, str)
|
||||
and len(completion_choice.message.content.strip()) >= self.config.min_response_length):
|
||||
and len(completion_choice.message.content.strip())
|
||||
>= self.config.min_response_length
|
||||
):
|
||||
valid_completions.append(completion_choice)
|
||||
|
||||
|
||||
# If we don't have enough valid completions, retry
|
||||
if len(valid_completions) < len(completions.choices) // 2: # If less than half are valid
|
||||
if (
|
||||
len(valid_completions) < len(completions.choices) // 2
|
||||
): # If less than half are valid
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: Only {len(valid_completions)}/{len(completions.choices)} valid completions (attempt {attempt + 1}/{max_retries})")
|
||||
print(
|
||||
f"DEBUG: Only {len(valid_completions)}/{len(completions.choices)} valid completions (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f"DEBUG: Only {len(valid_completions)}/{len(completions.choices)} valid completions after {max_retries} attempts")
|
||||
print(
|
||||
f"DEBUG: Only {len(valid_completions)}/{len(completions.choices)} valid completions after {max_retries} attempts"
|
||||
)
|
||||
# Continue with what we have
|
||||
|
||||
|
||||
# Build trajectories using valid completions
|
||||
to_score = []
|
||||
for completion_choice in valid_completions:
|
||||
# Add assistant response to existing messages
|
||||
trajectory_messages = messages + [
|
||||
{"role": "assistant", "content": completion_choice.message.content}
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": completion_choice.message.content,
|
||||
}
|
||||
]
|
||||
to_score.append((tuple(trajectory_messages), item[1]))
|
||||
|
||||
|
||||
# Success - we got at least some valid trajectories
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: collect_trajectories API call failed (attempt {attempt + 1}/{max_retries}): {e}")
|
||||
print(
|
||||
f"DEBUG: collect_trajectories API call failed (attempt {attempt + 1}/{max_retries}): {e}"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f"DEBUG: collect_trajectories API call failed after {max_retries} attempts: {e}")
|
||||
print(
|
||||
f"DEBUG: collect_trajectories API call failed after {max_retries} attempts: {e}"
|
||||
)
|
||||
return None, []
|
||||
|
||||
scored_data = await self.score(to_score)
|
||||
|
|
@ -871,29 +919,36 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
# Retry logic for failed API calls
|
||||
max_retries = self.config.max_retries
|
||||
retry_delay = self.config.retry_delay
|
||||
|
||||
|
||||
# Get category and item info for debug logging
|
||||
category = test_item.get("subset", "unknown")
|
||||
item_id = test_item.get("id", "unknown")
|
||||
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Log full debug request
|
||||
self._log_full_debug_request(
|
||||
messages, completion_params, category, item_id,
|
||||
f"CHOICE_EVAL attempt {attempt + 1}/{max_retries}"
|
||||
messages,
|
||||
completion_params,
|
||||
category,
|
||||
item_id,
|
||||
f"CHOICE_EVAL attempt {attempt + 1}/{max_retries}",
|
||||
)
|
||||
|
||||
|
||||
completion = await self.server.chat_completion(
|
||||
messages=messages, **completion_params
|
||||
)
|
||||
|
||||
|
||||
# Log full debug response
|
||||
self._log_full_debug_response(completion, f"CHOICE_EVAL attempt {attempt + 1}/{max_retries}")
|
||||
self._log_full_debug_response(
|
||||
completion, f"CHOICE_EVAL attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: No choices in completion (attempt {attempt + 1}/{max_retries})")
|
||||
print(
|
||||
f"DEBUG: No choices in completion (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
|
|
@ -901,51 +956,69 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
return {"score": 0.0, "sample": None}
|
||||
|
||||
model_response = completion.choices[0].message.content
|
||||
|
||||
|
||||
# Check for None content or very short responses (likely just EOS token)
|
||||
if model_response is None:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: model_response is None (attempt {attempt + 1}/{max_retries})")
|
||||
print(
|
||||
f"DEBUG: model_response is None (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
print(f"DEBUG: Completion: {completion}")
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f"DEBUG: model_response is None after {max_retries} attempts")
|
||||
print(
|
||||
f"DEBUG: model_response is None after {max_retries} attempts"
|
||||
)
|
||||
print(f"DEBUG: Final completion: {completion}")
|
||||
return {"score": 0.0, "sample": None}
|
||||
|
||||
|
||||
if not isinstance(model_response, str):
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: model_response is not a string. Type: {type(model_response)}, Value: {model_response} (attempt {attempt + 1}/{max_retries})")
|
||||
print(
|
||||
f"DEBUG: model_response is not a string. Type: {type(model_response)}, Value: {model_response} (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f"DEBUG: model_response is not a string after {max_retries} attempts. Type: {type(model_response)}, Value: {model_response}")
|
||||
print(
|
||||
f"DEBUG: model_response is not a string after {max_retries} attempts. Type: {type(model_response)}, Value: {model_response}"
|
||||
)
|
||||
return {"score": 0.0, "sample": None}
|
||||
|
||||
|
||||
# Check for very short responses (likely just EOS token)
|
||||
if len(model_response.strip()) < self.config.min_response_length:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: Very short response (likely EOS token only): '{model_response}' (attempt {attempt + 1}/{max_retries})")
|
||||
print(f"DEBUG: Completion tokens: {completion.usage.completion_tokens if hasattr(completion, 'usage') else 'unknown'}")
|
||||
print(
|
||||
f"DEBUG: Very short response (likely EOS token only): '{model_response}' (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
print(
|
||||
f"DEBUG: Completion tokens: {completion.usage.completion_tokens if hasattr(completion, 'usage') else 'unknown'}"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f"DEBUG: Very short response after {max_retries} attempts: '{model_response}'")
|
||||
print(
|
||||
f"DEBUG: Very short response after {max_retries} attempts: '{model_response}'"
|
||||
)
|
||||
return {"score": 0.0, "sample": None}
|
||||
|
||||
|
||||
# Success - we got a valid response
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: API call failed (attempt {attempt + 1}/{max_retries}): {e}")
|
||||
print(
|
||||
f"DEBUG: API call failed (attempt {attempt + 1}/{max_retries}): {e}"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f"DEBUG: API call failed after {max_retries} attempts: {e}")
|
||||
print(
|
||||
f"DEBUG: API call failed after {max_retries} attempts: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
predicted_answer = self.process_judgement(
|
||||
model_response, track_metrics=False
|
||||
)
|
||||
|
|
@ -1069,7 +1142,7 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
# Get category and item info for debug logging
|
||||
category = test_item.get("subset", "unknown")
|
||||
item_id = test_item.get("id", "unknown")
|
||||
|
||||
|
||||
for prompt, response_text, is_correct in prompts_and_responses:
|
||||
messages = self._prepare_completion_input(prompt)
|
||||
completion_params = self._get_eval_completion_params()
|
||||
|
|
@ -1078,52 +1151,63 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
max_retries = self.config.max_retries
|
||||
retry_delay = self.config.retry_delay
|
||||
success = False
|
||||
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Log full debug request
|
||||
self._log_full_debug_request(
|
||||
messages, completion_params, category, item_id,
|
||||
f"TIES_EVAL attempt {attempt + 1}/{max_retries}"
|
||||
messages,
|
||||
completion_params,
|
||||
category,
|
||||
item_id,
|
||||
f"TIES_EVAL attempt {attempt + 1}/{max_retries}",
|
||||
)
|
||||
|
||||
|
||||
completion = await self.server.chat_completion(
|
||||
messages=messages, **completion_params
|
||||
)
|
||||
|
||||
|
||||
# Log full debug response
|
||||
self._log_full_debug_response(completion, f"TIES_EVAL attempt {attempt + 1}/{max_retries}")
|
||||
self._log_full_debug_response(
|
||||
completion, f"TIES_EVAL attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
|
||||
if not completion.choices:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: No choices in ties completion (attempt {attempt + 1}/{max_retries})")
|
||||
print(
|
||||
f"DEBUG: No choices in ties completion (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
break # Failed after all retries
|
||||
|
||||
|
||||
model_response = completion.choices[0].message.content
|
||||
|
||||
|
||||
# Check for None content or very short responses
|
||||
if model_response is None:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: ties model_response is None (attempt {attempt + 1}/{max_retries})")
|
||||
print(
|
||||
f"DEBUG: ties model_response is None (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
break # Failed after all retries
|
||||
|
||||
|
||||
if not isinstance(model_response, str):
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: ties model_response is not a string. Type: {type(model_response)} (attempt {attempt + 1}/{max_retries})")
|
||||
print(
|
||||
f"DEBUG: ties model_response is not a string. Type: {type(model_response)} (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
break # Failed after all retries
|
||||
|
||||
|
||||
# For ties evaluation, don't check response format - invalid ratings are part of normal evaluation
|
||||
# Only retry for technical failures (None content, API errors, etc.)
|
||||
|
||||
|
||||
# Success - process the rating
|
||||
rating = self._process_rating_judgment(model_response)
|
||||
ratings.append(rating)
|
||||
|
|
@ -1138,16 +1222,20 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
)
|
||||
success = True
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"DEBUG: ties API call failed (attempt {attempt + 1}/{max_retries}): {e}")
|
||||
print(
|
||||
f"DEBUG: ties API call failed (attempt {attempt + 1}/{max_retries}): {e}"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
print(f"DEBUG: ties API call failed after {max_retries} attempts: {e}")
|
||||
print(
|
||||
f"DEBUG: ties API call failed after {max_retries} attempts: {e}"
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
# If we failed after all retries, add error rating
|
||||
if not success:
|
||||
ratings.append(-1) # Error rating
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue