This commit fixes compatibility issues with smolagents 1.22.0 ChatMessage objects and improves the documentation for easier setup.

Changes:
- Fix smolagents_model.py to handle ChatMessage objects (not just dicts) in _extract_user_message() and _format_chat_messages()
- Fix smolagents_env.py to handle ChatMessage objects in trajectory scoring and data group creation
- Update README.md with clearer installation instructions, a Quick Start section, and automatic GAIA dataset download documentation
- Add test_run.sh script for easy testing with OpenAI models

Tested with:
- smolagents 1.22.0
- gpt-4o-mini via the OpenAI API
- Tavily web search tools
- Automatic GAIA dataset download

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
"""
|
|
Process-safe implementation of the AtroposServerModel for SmolaGents.
|
|
"""
|
|
|
|
import logging
|
|
import traceback
|
|
|
|
from smolagents.models import ChatMessage, MessageRole, Model
|
|
|
|
from .server_proxy import ServerProxy
|
|
|
|
# Configure logger for the model class
|
|
logger = logging.getLogger(__name__)
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
class ProcessSafeAtroposServerModel(Model):
|
|
"""
|
|
A SmolaGents Model implementation that works with a server proxy.
|
|
|
|
This class is designed to be used in separate processes and
|
|
communicates with the Atropos server through a proxy mechanism.
|
|
"""

    def __init__(
        self,
        server_proxy: ServerProxy,
        use_chat_completion: bool = False,
        model_id: str | None = None,
        **kwargs,
    ):
        self.server_proxy = server_proxy
        self.use_chat_completion = use_chat_completion

        # Automatically enable the chat completion API for models that require it
        # ("o1"/"o3" cover the OpenAI reasoning-model ids)
        if model_id and any(
            name in model_id.lower()
            for name in ["gpt-4", "gpt-3.5-turbo", "claude", "gemini", "o1", "o3", "llama"]
        ):
            logger.info(
                f"Model {model_id} detected as a chat model. Forcing chat completion API."
            )
            self.use_chat_completion = True

        # Log the configuration
        logger.info(
            f"Initializing ProcessSafeAtroposServerModel with model_id={model_id}, "
            f"use_chat_completion={self.use_chat_completion}"
        )

        super().__init__(model_id=model_id, **kwargs)
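
    # For example, model_id="gpt-4o-mini" contains "gpt-4", so the check above
    # forces use_chat_completion=True even if the caller passed False.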

    def _prepare_completion_args(self, messages, stop_sequences=None, **kwargs):
        """
        Convert SmolaGents message format to Atropos server parameters.
        """
        # Always use chat completion if configured that way
        if self.use_chat_completion:
            # For chat completion, format the messages and omit the prompt field
            server_args = {
                "messages": self._format_chat_messages(messages),
                "max_tokens": kwargs.get("max_tokens", 2048),
                "temperature": kwargs.get("temperature", 0.0),
                "stop": stop_sequences,
            }
            logger.debug(
                f"Prepared chat completion args: messages count={len(server_args['messages'])}"
            )
            return server_args
        else:
            # Extract the last user message for the completion API
            prompt = self._extract_user_message(messages)
            server_args = {
                "prompt": prompt,
                "max_tokens": kwargs.get("max_tokens", 2048),
                "temperature": kwargs.get("temperature", 0.0),
                "stop": stop_sequences,
            }
            logger.debug(f"Prepared completion args: prompt length={len(prompt)}")
            return server_args
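
    # Payload shape, for reference: with use_chat_completion=True the proxy
    # receives {"messages": [...], "max_tokens": ..., "temperature": ...,
    # "stop": [...]}; otherwise the same sampling fields plus a "prompt" string.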

    def _extract_user_message(self, messages):
        """Extract content from the last user message."""
        for msg in reversed(messages):
            # Handle both dict and ChatMessage objects
            role = msg.role if hasattr(msg, "role") else msg["role"]
            role_str = role.value if hasattr(role, "value") else str(role)

            if role_str.lower() == "user":
                content = msg.content if hasattr(msg, "content") else msg["content"]
                if isinstance(content, list):
                    # Handle list format [{"type": "text", "text": "content"}]
                    return "\n".join(
                        item["text"] for item in content if item["type"] == "text"
                    )
                return content
        raise ValueError("No user message found")
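
    # For reference, list-form content is flattened to its text parts, so
    #     [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}]
    # yields "Hi", while plain-string content is returned unchanged.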

    def _format_chat_messages(self, messages):
        """Format messages for the chat completion API."""
        formatted_messages = []

        # For the OpenAI API, map every role onto one it supports
        for msg in messages:
            # Handle both dict and ChatMessage objects
            role = msg.role if hasattr(msg, "role") else msg["role"]
            content = msg.content if hasattr(msg, "content") else msg["content"]

            # Normalize the role to a lowercase string
            if isinstance(role, str):
                role_str = role.lower()
            elif hasattr(role, "value"):
                role_str = str(role.value).lower()
            else:
                role_str = str(role).lower()

            # Map to the OpenAI roles: system, user, or assistant
            if role_str in ("system", "user", "assistant"):
                openai_role = role_str
            elif role_str in (
                "tool_call",
                "tool_response",
                "function_call",
                "function_response",
            ):
                # Silently map tool and function calls/responses to the user role
                openai_role = "user"
            else:
                # Default everything else to user without logging
                openai_role = "user"

            # Extract text content if it's in the list format
            if isinstance(content, list):
                text_content = "\n".join(
                    item["text"] for item in content if item["type"] == "text"
                )
                formatted_messages.append(
                    {"role": openai_role, "content": text_content}
                )
            else:
                formatted_messages.append({"role": openai_role, "content": content})

        return formatted_messages
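
    # At a glance: any role outside system/user/assistant (tool and function
    # calls/responses included) is sent as "user", so for example
    #     {"role": "tool_response", "content": "result"}
    # becomes {"role": "user", "content": "result"} in the outgoing payload.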

    def generate(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list | None = None,
        **kwargs,
    ) -> ChatMessage:
        """
        Process the input messages and return the model's response.

        Uses the server proxy to communicate with the Atropos server.

        Parameters:
            messages: A list of message dictionaries to be processed.
            stop_sequences: A list of strings that will stop the generation if encountered.
            grammar: The grammar or formatting structure to use (not used with Atropos).
            tools_to_call_from: List of tools (not used with Atropos).
            **kwargs: Additional keyword arguments for the server.

        Returns:
            ChatMessage: A chat message object containing the model's response.
        """
        # Special handling for CodeAgent stop sequences
        if stop_sequences is None:
            stop_sequences = ["Observation:", "<end_code>", "Calling tools:"]

        logger.info(
            f"Generate called with {len(messages)} messages, "
            f"use_chat_completion={self.use_chat_completion}"
        )

        # Prepare completion arguments
        completion_kwargs = self._prepare_completion_args(
            messages=messages, stop_sequences=stop_sequences, **kwargs
        )

        # Discard any timeout from kwargs; it is not forwarded by this method
        kwargs.pop("timeout", 120)  # Default 2 minutes

        try:
            if self.use_chat_completion:
                logger.info("Using chat completion API through proxy")

                # Convert a bare prompt to the messages format if needed
                if (
                    "prompt" in completion_kwargs
                    and "messages" not in completion_kwargs
                ):
                    logger.info("Converting prompt to messages format")
                    completion_kwargs["messages"] = [
                        {"role": "user", "content": completion_kwargs.pop("prompt")}
                    ]

                # Call chat_completion via the proxy
                resp = self.server_proxy.chat_completion(**completion_kwargs)

                # Process the response
                if resp and hasattr(resp, "choices") and len(resp.choices) > 0:
                    content = resp.choices[0].message.content
                    logger.info(f"Got response with {len(content)} chars")
                else:
                    content = "No response content"
                    logger.warning("No content found in response")
            else:
                logger.info("Using completion API through proxy")

                # Call completion via the proxy
                resp = self.server_proxy.completion(**completion_kwargs)

                # Process the response
                if resp and hasattr(resp, "choices") and len(resp.choices) > 0:
                    content = resp.choices[0].text
                    logger.info(f"Got response with {len(content)} chars")
                else:
                    content = "No response content"
                    logger.warning("No content found in response")

            # Track token usage
            if hasattr(resp, "usage"):
                self.last_input_token_count = resp.usage.prompt_tokens
                self.last_output_token_count = resp.usage.completion_tokens
                logger.info(
                    f"Token usage: input={self.last_input_token_count}, "
                    f"output={self.last_output_token_count}"
                )

            # Return the result in SmolaGents format
            logger.info("Successfully returning ChatMessage")
            return ChatMessage(role=MessageRole.ASSISTANT, content=content, raw=resp)

        except Exception as e:
            # Provide more detailed error information
            error_msg = f"Error during server proxy call: {type(e).__name__}: {e}"
            logger.error(error_msg)

            # Log the full stack trace for debugging
            logger.error(f"Full traceback: {traceback.format_exc()}")

            # Re-raise with the original exception chained for easier debugging
            raise ValueError(error_msg) from e
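

# A minimal end-to-end sketch (hypothetical wiring: it assumes a running
# Atropos server, a ServerProxy built elsewhere as `proxy`, and the standard
# smolagents CodeAgent; nothing here is executed by this module):
#
#     from smolagents import CodeAgent
#
#     model = ProcessSafeAtroposServerModel(proxy, model_id="gpt-4o-mini")
#     agent = CodeAgent(tools=[], model=model)
#     agent.run("What is 2 + 2?")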