mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix eval ctx len
This commit is contained in:
parent
85296c519e
commit
c871f6a56a
4 changed files with 86 additions and 28 deletions
|
|
@ -557,15 +557,31 @@ class AGIEvalEnv(BaseEnv):
|
|||
# Get model completion with retry logic
|
||||
model_response = None
|
||||
finish_reason = None
|
||||
|
||||
# Build completion kwargs - only include max_tokens if > 0
|
||||
# (0 means "use model default", so we don't pass the parameter)
|
||||
completion_kwargs = {
|
||||
"messages": messages,
|
||||
"n": 1,
|
||||
"temperature": self.config.eval_temperature,
|
||||
"split": "eval",
|
||||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
completion = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=1,
|
||||
temperature=self.config.eval_temperature,
|
||||
max_tokens=self.config.eval_max_tokens,
|
||||
split="eval",
|
||||
)
|
||||
if self.config.full_debug:
|
||||
print(f" Making API request (attempt {attempt + 1}/{self.config.max_retries})...")
|
||||
try:
|
||||
model_name = self.server.servers[0].config.model_name if hasattr(self.server, 'servers') else 'unknown'
|
||||
except Exception:
|
||||
model_name = 'unknown'
|
||||
print(f" Model: {model_name}")
|
||||
print(f" Temperature: {self.config.eval_temperature}")
|
||||
print(f" Max tokens: {self.config.eval_max_tokens if self.config.eval_max_tokens > 0 else 'model default'}")
|
||||
|
||||
completion = await self.server.chat_completion(**completion_kwargs)
|
||||
|
||||
if completion.choices and completion.choices[0].message.content:
|
||||
model_response = completion.choices[0].message.content
|
||||
|
|
@ -584,10 +600,34 @@ class AGIEvalEnv(BaseEnv):
|
|||
await asyncio.sleep(self.config.retry_delay)
|
||||
|
||||
except Exception as e:
|
||||
# Extract the underlying error from RetryError if present
|
||||
actual_error = e
|
||||
error_chain = []
|
||||
while actual_error is not None:
|
||||
error_chain.append(f"{type(actual_error).__name__}: {actual_error}")
|
||||
# Try to get the underlying cause
|
||||
if hasattr(actual_error, '__cause__') and actual_error.__cause__ is not None:
|
||||
actual_error = actual_error.__cause__
|
||||
elif hasattr(actual_error, 'last_attempt'):
|
||||
# tenacity RetryError stores the last attempt's exception
|
||||
try:
|
||||
actual_error = actual_error.last_attempt.exception()
|
||||
except Exception:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
# Always log API errors to help diagnose issues
|
||||
print(
|
||||
f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
|
||||
)
|
||||
|
||||
# Print the full error chain for debugging
|
||||
if len(error_chain) > 1:
|
||||
print(" Error chain:")
|
||||
for i, err in enumerate(error_chain):
|
||||
print(f" {' ' * i}-> {err}")
|
||||
|
||||
if hasattr(e, "response"):
|
||||
try:
|
||||
resp_text = (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue