# yc-bench/src/yc_bench/agent/runtime/litellm_runtime.py
#
# (File-viewer metadata at time of capture: 318 lines, 12 KiB, Python.)

from __future__ import annotations
import json
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from dataclasses import dataclass, field
import litellm
from .base import AgentRuntime
from .schemas import RuntimeTurnResult
from ..prompt import SYSTEM_PROMPT
from ..tools.run_command_schema import normalize_result
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Global LiteLLM switches; they are set at import time and therefore affect
# every LiteLLM call made in this process, not just calls from this module.
# NOTE(review): suppress_debug_info presumably silences LiteLLM's own
# informational/debug banner output — confirm against the LiteLLM docs.
litellm.suppress_debug_info = True
litellm.drop_params = True # silently drop unsupported params (e.g. tool_choice for mini/nano models)
# Tool schema passed to the LLM on every call
_RUN_COMMAND_TOOL = {
"type": "function",
"function": {
"name": "run_command",
"description": (
"Execute one benchmark CLI command inside the sandbox "
"and return structured execution output."
),
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The full yc-bench CLI command to execute.",
}
},
"required": ["command"],
},
},
}
@dataclass
class _Session:
messages: list = field(default_factory=list)
scratchpad: str | None = None
turn_logs: list = field(default_factory=list) # full API call logs per turn
class LiteLLMRuntime(AgentRuntime):
    """Agent runtime that performs one LiteLLM completion per turn.

    Every session id owns a ``_Session`` with its message history. A turn
    appends the user input, requests a completion with the ``run_command``
    tool attached, executes each requested command through
    ``command_executor`` and records the tool results in the history.
    """

    def __init__(self, settings, command_executor):
        """Capture configuration and validate the retry/timeout settings.

        Args:
            settings: Configuration object. This class reads ``model``,
                ``system_prompt``, ``request_timeout_seconds``,
                ``retry_max_attempts``, ``retry_backoff_seconds`` and
                ``history_keep_rounds`` from it.
            command_executor: Callable invoked with a command string; its
                return value is passed to ``normalize_result``.

        Raises:
            ValueError: If the timeout, attempt count or backoff is not
                strictly positive.
        """
        self._settings = settings
        self._command_executor = command_executor
        self._sessions: dict[str, _Session] = {}
        self._request_timeout_seconds = settings.request_timeout_seconds
        self._retry_max_attempts = settings.retry_max_attempts
        self._retry_backoff_seconds = settings.retry_backoff_seconds
        if self._request_timeout_seconds <= 0:
            raise ValueError("request_timeout_seconds must be > 0")
        if self._retry_max_attempts <= 0:
            raise ValueError("retry_max_attempts must be > 0")
        if self._retry_backoff_seconds <= 0:
            raise ValueError("retry_backoff_seconds must be > 0")
        # API key: first non-empty value among the provider-specific env vars.
        # It is only recorded here; _do_turn intentionally does NOT forward it
        # and lets LiteLLM resolve the key matching the model's provider.
        self._api_key = (
            os.environ.get("ANTHROPIC_API_KEY")
            or os.environ.get("OPENAI_API_KEY")
            or os.environ.get("OPENROUTER_API_KEY")
            or None
        )
        # Base URL: only needed for raw OpenAI-compatible endpoints.
        # openrouter/ model prefix is handled natively by LiteLLM without this.
        self._api_base = os.environ.get("OPENAI_BASE_URL") or None
        self._history_keep_rounds = settings.history_keep_rounds
        logger.info(
            "LiteLLMRuntime configured: model=%s api_base=%s history_keep_rounds=%d",
            self._settings.model,
            self._api_base or "default",
            self._history_keep_rounds,
        )

    # ------------------------------------------------------------------
    # AgentRuntime interface
    # ------------------------------------------------------------------
    def run_turn(self, request):
        """Run one turn for ``request.session_id`` with retries.

        Reads ``session_id``, ``scratchpad`` and ``user_input`` from
        ``request``. Old rounds are truncated first, then the user input is
        appended and the turn is attempted up to ``retry_max_attempts``
        times with exponential backoff between failures.

        Returns:
            RuntimeTurnResult with the final output, the executed tool
            calls, token counts of the latest logged call, the resume
            payload (if any) and the cost of the turn.

        Raises:
            RuntimeError: If no attempt produced a result.
        """
        session = self._get_or_create_session(request.session_id)
        # Update scratchpad on session (appended to system prompt, not message history)
        session.scratchpad = request.scratchpad
        # Proactively drop old rounds before appending new input.
        self._proactive_truncate(session)
        session.messages.append({"role": "user", "content": request.user_input})

        result = None
        last_err = None
        for attempt in range(1, self._retry_max_attempts + 1):
            try:
                result = self._run_with_timeout(session)
                break
            except Exception as e:  # boundary: every failure is retried or re-raised below
                last_err = e
                if self._is_context_length_error(e):
                    # Retried immediately (no backoff) with the same history;
                    # running out of attempts falls through to the
                    # "failed without result" error after the loop.
                    logger.warning(
                        "Context-length error on attempt %d despite proactive truncation "
                        "(history_keep_rounds=%d). Consider reducing YC_BENCH_HISTORY_KEEP_ROUNDS.",
                        attempt,
                        self._history_keep_rounds,
                    )
                    continue
                logger.warning("Turn attempt %d failed: %s", attempt, e)
                if attempt >= self._retry_max_attempts:
                    raise RuntimeError(
                        f"Failed to run turn after {self._retry_max_attempts} attempts"
                    ) from last_err
                time.sleep(self._retry_backoff_seconds * (2 ** (attempt - 1)))

        if result is None:
            raise RuntimeError("run_turn failed without result") from last_err

        final_output, tool_calls_made, resume_payload, turn_cost = result
        # Include latest turn log for transcript saving
        latest_log = session.turn_logs[-1] if session.turn_logs else {}
        return RuntimeTurnResult(
            final_output=final_output,
            raw_result={
                "tool_calls": tool_calls_made,
                "prompt_tokens": latest_log.get("prompt_tokens", 0),
                "completion_tokens": latest_log.get("completion_tokens", 0),
            },
            checkpoint_advanced=resume_payload is not None,
            resume_payload=resume_payload,
            turn_cost_usd=turn_cost,
        )

    def clear_session(self, session_id):
        """Forget the session stored under ``session_id`` (no-op if unknown)."""
        self._sessions.pop(session_id, None)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _run_with_timeout(self, session):
        """Run ``_do_turn`` in a worker thread, bounded by the request timeout.

        The attempt runs against a private copy of the session and is merged
        back only when it succeeds. A failed attempt therefore never leaves
        a half-written history behind (for example an assistant tool-call
        message without its tool results), and a timed-out worker that is
        still running cannot modify the live session afterwards.

        Returns:
            The tuple returned by ``_do_turn``.

        Raises:
            TimeoutError: If the attempt does not finish in time. The worker
                thread cannot be interrupted; it is abandoned and may still
                finish the command it is currently executing.
        """
        scratch = _Session(
            messages=list(session.messages),
            scratchpad=session.scratchpad,
        )
        # No ``with`` block: leaving it would call shutdown(wait=True) and
        # block until the worker finishes, defeating the timeout entirely.
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(self._do_turn, scratch)
            try:
                outcome = future.result(timeout=self._request_timeout_seconds)
            except FuturesTimeoutError as exc:
                future.cancel()
                raise TimeoutError(
                    f"LiteLLM call timed out after {self._request_timeout_seconds}s"
                ) from exc
        finally:
            executor.shutdown(wait=False)
        # Success: publish the attempt's history and logs in the caller thread.
        session.messages = scratch.messages
        session.turn_logs.extend(scratch.turn_logs)
        return outcome

    def _do_turn(self, session):
        """One LLM call + tool execution. Returns (final_output, tool_calls_made, resume_payload, cost_usd)."""
        system_prompt = self._settings.system_prompt or SYSTEM_PROMPT
        # Append scratchpad to system prompt (avoids duplication in message history)
        if session.scratchpad:
            system_prompt = system_prompt + f"\n\n## Your Scratchpad Notes\n{session.scratchpad}"
        messages = [{"role": "system", "content": system_prompt}] + session.messages
        kwargs = dict(
            model=self._settings.model,
            messages=messages,
            tools=[_RUN_COMMAND_TOOL],
            tool_choice="auto",
            timeout=self._request_timeout_seconds,
        )
        if self._api_base:
            kwargs["api_base"] = self._api_base
        # Let LiteLLM resolve API keys from provider-specific env vars
        # (ANTHROPIC_API_KEY, GEMINI_API_KEY, OPENROUTER_API_KEY, etc.)
        # rather than passing a single key that may not match the provider.
        response = litellm.completion(**kwargs)

        # Log token usage and cost per call
        turn_cost = 0.0
        prompt_tokens = 0
        completion_tokens = 0
        usage = getattr(response, "usage", None)
        if usage:
            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
            completion_tokens = getattr(usage, "completion_tokens", 0) or 0
            # ``or {}`` also covers a _hidden_params attribute that is None.
            hidden = getattr(response, "_hidden_params", None) or {}
            turn_cost = float(hidden.get("response_cost") or 0)
            logger.info(
                "LLM call: prompt_tokens=%s completion_tokens=%s cost=$%.6f",
                prompt_tokens, completion_tokens, turn_cost,
            )

        message = response.choices[0].message
        # Save full turn log for analysis
        session.turn_logs.append({
            "messages_sent": messages,  # full messages array as sent to API
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "cost_usd": turn_cost,
        })

        tool_calls = getattr(message, "tool_calls", None) or []
        tool_calls_made = []
        resume_payload = None
        if not tool_calls:
            content = message.content or ""
            session.messages.append({"role": "assistant", "content": content})
            return content, tool_calls_made, resume_payload, turn_cost

        # Persist assistant message with tool calls
        session.messages.append({
            "role": "assistant",
            "content": message.content,
            "tool_calls": [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in tool_calls
            ],
        })
        for tc in tool_calls:
            command = self._extract_command(tc.function.arguments)
            raw = self._command_executor(command)
            normalized = normalize_result(raw)
            tool_result_str = json.dumps(normalized.__dict__)
            tool_calls_made.append({"command": command, "result": tool_result_str})
            # Extract resume payload when the agent advances simulation time
            if command.startswith("yc-bench sim resume"):
                payload = self._parse_resume_payload(normalized.__dict__.get("stdout", ""))
                if payload is not None:
                    resume_payload = payload
            session.messages.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "content": tool_result_str,
            })
        cmds = [entry["command"] for entry in tool_calls_made]
        final_output = f"Executed {len(tool_calls)} tool call(s): {', '.join(cmds)}"
        return final_output, tool_calls_made, resume_payload, turn_cost

    @staticmethod
    def _extract_command(raw_arguments) -> str:
        """Return the ``command`` string from a tool call's JSON arguments.

        Malformed JSON, a non-object payload, or a missing/non-string
        ``command`` all yield ``""`` so the caller still produces a tool
        result for the call instead of crashing.
        """
        try:
            parsed = json.loads(raw_arguments)
        except (TypeError, ValueError):
            return ""
        if not isinstance(parsed, dict):
            return ""
        command = parsed.get("command", "")
        return command if isinstance(command, str) else ""

    @staticmethod
    def _parse_resume_payload(stdout):
        """Return ``stdout`` parsed as a JSON object, or None if it is not one."""
        if not isinstance(stdout, str) or not stdout.strip():
            return None
        try:
            payload = json.loads(stdout)
        except (ValueError, RecursionError):
            # Best effort: unparsable output just means no payload.
            return None
        return payload if isinstance(payload, dict) else None

    def _get_or_create_session(self, session_id: str) -> _Session:
        """Return the session for ``session_id``, creating it on first use."""
        if session_id not in self._sessions:
            self._sessions[session_id] = _Session()
        return self._sessions[session_id]

    def _is_context_length_error(self, err: Exception) -> bool:
        """Return True if the error text looks like a context-window overflow."""
        text = str(err).lower()
        patterns = (
            "context length",
            "maximum context",
            "max context",
            "too many tokens",
            "prompt is too long",
            "token limit",
            "context window",
        )
        return any(p in text for p in patterns)

    def _round_start_indices(self, messages: list) -> list[int]:
        """Return indices of user messages — each marks the start of a round."""
        return [i for i, m in enumerate(messages) if m.get("role") == "user"]

    def _proactive_truncate(self, session: _Session) -> None:
        """Drop oldest rounds before each turn so at most history_keep_rounds remain.

        When rounds are dropped, a user-role marker message is placed at the
        front of the history. A non-positive ``history_keep_rounds`` keeps no
        earlier rounds at all (only the marker).
        """
        keep = max(self._history_keep_rounds, 0)
        messages = session.messages
        user_indices = self._round_start_indices(messages)
        if len(user_indices) <= keep:
            return
        # ``user_indices[-0]`` would be index 0 (keep everything), so the
        # keep == 0 case needs an explicit cutoff past the end.
        cutoff = user_indices[-keep] if keep else len(messages)
        marker = {
            "role": "user",
            "content": (
                f"[Earlier turns removed. Only the last {keep} "
                "turns are retained in this context window.]"
            ),
        }
        session.messages = [marker] + messages[cutoff:]
        logger.info(
            "Proactive truncation: kept last %d rounds (%d -> %d messages).",
            keep,
            len(messages),
            len(session.messages),
        )
# Public API of this module: only the runtime class is exported.
__all__ = ["LiteLLMRuntime"]