mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
9e9f1cd88e
commit
269fb71713
4 changed files with 32 additions and 19 deletions
|
|
@ -557,7 +557,7 @@ class AGIEvalEnv(BaseEnv):
|
|||
# Get model completion with retry logic
|
||||
model_response = None
|
||||
finish_reason = None
|
||||
|
||||
|
||||
# Build completion kwargs - only include max_tokens if > 0
|
||||
# (0 means "use model default", so we don't pass the parameter)
|
||||
completion_kwargs = {
|
||||
|
|
@ -568,19 +568,27 @@ class AGIEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
if self.config.full_debug:
|
||||
print(f" Making API request (attempt {attempt + 1}/{self.config.max_retries})...")
|
||||
print(
|
||||
f" Making API request (attempt {attempt + 1}/{self.config.max_retries})..."
|
||||
)
|
||||
try:
|
||||
model_name = self.server.servers[0].config.model_name if hasattr(self.server, 'servers') else 'unknown'
|
||||
model_name = (
|
||||
self.server.servers[0].config.model_name
|
||||
if hasattr(self.server, "servers")
|
||||
else "unknown"
|
||||
)
|
||||
except Exception:
|
||||
model_name = 'unknown'
|
||||
model_name = "unknown"
|
||||
print(f" Model: {model_name}")
|
||||
print(f" Temperature: {self.config.eval_temperature}")
|
||||
print(f" Max tokens: {self.config.eval_max_tokens if self.config.eval_max_tokens > 0 else 'model default'}")
|
||||
|
||||
print(
|
||||
f" Max tokens: {self.config.eval_max_tokens if self.config.eval_max_tokens > 0 else 'model default'}"
|
||||
)
|
||||
|
||||
completion = await self.server.chat_completion(**completion_kwargs)
|
||||
|
||||
if completion.choices and completion.choices[0].message.content:
|
||||
|
|
@ -604,11 +612,16 @@ class AGIEvalEnv(BaseEnv):
|
|||
actual_error = e
|
||||
error_chain = []
|
||||
while actual_error is not None:
|
||||
error_chain.append(f"{type(actual_error).__name__}: {actual_error}")
|
||||
error_chain.append(
|
||||
f"{type(actual_error).__name__}: {actual_error}"
|
||||
)
|
||||
# Try to get the underlying cause
|
||||
if hasattr(actual_error, '__cause__') and actual_error.__cause__ is not None:
|
||||
if (
|
||||
hasattr(actual_error, "__cause__")
|
||||
and actual_error.__cause__ is not None
|
||||
):
|
||||
actual_error = actual_error.__cause__
|
||||
elif hasattr(actual_error, 'last_attempt'):
|
||||
elif hasattr(actual_error, "last_attempt"):
|
||||
# tenacity RetryError stores the last attempt's exception
|
||||
try:
|
||||
actual_error = actual_error.last_attempt.exception()
|
||||
|
|
@ -616,18 +629,18 @@ class AGIEvalEnv(BaseEnv):
|
|||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
# Always log API errors to help diagnose issues
|
||||
print(
|
||||
f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
|
||||
)
|
||||
|
||||
|
||||
# Print the full error chain for debugging
|
||||
if len(error_chain) > 1:
|
||||
print(" Error chain:")
|
||||
for i, err in enumerate(error_chain):
|
||||
print(f" {' ' * i}-> {err}")
|
||||
|
||||
|
||||
if hasattr(e, "response"):
|
||||
try:
|
||||
resp_text = (
|
||||
|
|
|
|||
|
|
@ -479,7 +479,7 @@ class GPQAEvalEnv(BaseEnv):
|
|||
# Get model completion with retry logic
|
||||
model_response = None
|
||||
finish_reason = None
|
||||
|
||||
|
||||
# Build completion kwargs - only include max_tokens if > 0
|
||||
# (0 means "use model default", so we don't pass the parameter)
|
||||
completion_kwargs = {
|
||||
|
|
@ -490,7 +490,7 @@ class GPQAEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
completion = await self.server.chat_completion(**completion_kwargs)
|
||||
|
|
|
|||
|
|
@ -775,7 +775,7 @@ class MMLUEvalEnv(BaseEnv):
|
|||
# Get model completion with retry logic
|
||||
model_response = None
|
||||
finish_reason = None
|
||||
|
||||
|
||||
# Build completion kwargs - only include max_tokens if > 0
|
||||
# (0 means "use model default", so we don't pass the parameter)
|
||||
completion_kwargs = {
|
||||
|
|
@ -786,7 +786,7 @@ class MMLUEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
completion = await self.server.chat_completion(**completion_kwargs)
|
||||
|
|
|
|||
|
|
@ -557,7 +557,7 @@ class MMLUProEvalEnv(BaseEnv):
|
|||
# Get model completion with retry logic
|
||||
model_response = None
|
||||
finish_reason = None
|
||||
|
||||
|
||||
# Build completion kwargs - only include max_tokens if > 0
|
||||
# (0 means "use model default", so we don't pass the parameter)
|
||||
completion_kwargs = {
|
||||
|
|
@ -568,7 +568,7 @@ class MMLUProEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
completion = await self.server.chat_completion(**completion_kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue