atropos/environments/smolagents_integration/smolagents_model.py
hjc-puro dddfb30c5b Fix smolagents ChatMessage compatibility and improve documentation
This commit fixes compatibility issues with smolagents 1.22.0 ChatMessage
objects and improves the documentation for easier setup.

Changes:
- Fix smolagents_model.py to handle ChatMessage objects (not just dicts)
  in _extract_user_message() and _format_chat_messages()
- Fix smolagents_env.py to handle ChatMessage objects in trajectory
  scoring and data group creation
- Update README.md with clearer installation instructions, Quick Start
  section, and automatic GAIA dataset download documentation
- Add test_run.sh script for easy testing with OpenAI models

Tested with:
- smolagents 1.22.0
- gpt-4o-mini via OpenAI API
- Tavily web search tools
- Automatic GAIA dataset download

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-29 21:27:52 +00:00

247 lines
9.6 KiB
Python

"""
Process-safe implementation of the AtroposServerModel for SmolaGents.
"""
import logging
import traceback
from smolagents.models import ChatMessage, MessageRole, Model
from .server_proxy import ServerProxy
# Configure logger for the model class
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class ProcessSafeAtroposServerModel(Model):
    """
    A SmolaGents Model implementation that works with a server proxy.

    This class is designed to be used in separate processes and
    communicates with the Atropos server through a proxy mechanism.
    """

    # Model-family substrings that only support the chat-completion API.
    # The OpenAI "o" series (o1, o3, ...) is handled separately in
    # _is_chat_model because it needs whole-token matching.
    _CHAT_MODEL_NAMES = ("gpt-4", "gpt-3.5-turbo", "claude", "gemini", "llama")

    def __init__(
        self,
        server_proxy: ServerProxy,
        use_chat_completion: bool = False,
        model_id: str = None,
        **kwargs,
    ):
        """
        Args:
            server_proxy: Proxy used to communicate with the Atropos server.
            use_chat_completion: Force the chat-completion API. Automatically
                enabled for known chat-only model families (see _is_chat_model).
            model_id: Optional model identifier, used for chat-model detection
                and forwarded to the smolagents ``Model`` base class.
            **kwargs: Forwarded to the smolagents ``Model`` base class.
        """
        self.server_proxy = server_proxy
        self.use_chat_completion = use_chat_completion
        # Automatically set chat completion for model families that require it
        # (GPT-4, GPT-3.5-turbo, Claude, Gemini, o-series, Llama).
        if model_id and self._is_chat_model(model_id):
            logger.info(
                f"Model {model_id} detected as a chat model. Forcing chat completion API."
            )
            self.use_chat_completion = True
        # Log the configuration
        logger.info(
            f"Initializing ProcessSafeAtroposServerModel with model_id={model_id}, "
            f"use_chat_completion={self.use_chat_completion}"
        )
        super().__init__(model_id=model_id, **kwargs)

    @classmethod
    def _is_chat_model(cls, model_id):
        """Return True if ``model_id`` names a chat-only model family.

        Known family names are matched as substrings. The OpenAI "o" series
        (o, o1, o3, o4, ...) is matched as a whole token: the previous
        implementation used ``"o" in model_id.lower()``, which misclassified
        nearly every model id containing the letter "o".
        """
        lowered = model_id.lower()
        if any(name in lowered for name in cls._CHAT_MODEL_NAMES):
            return True
        # Tokenize on common separators, then look for "o" / "o<digits>".
        normalized = lowered
        for sep in ("_", "/", ".", ":"):
            normalized = normalized.replace(sep, "-")
        return any(
            token == "o" or (token.startswith("o") and token[1:].isdigit())
            for token in normalized.split("-")
        )

    @staticmethod
    def _msg_field(msg, field):
        """Read ``field`` from either a ChatMessage object or a plain dict."""
        return getattr(msg, field) if hasattr(msg, field) else msg[field]

    @staticmethod
    def _role_str(role):
        """Normalize a role (plain string or MessageRole enum) to lowercase."""
        if isinstance(role, str):
            return role.lower()
        if hasattr(role, "value"):
            return str(role.value).lower()
        return str(role).lower()

    @staticmethod
    def _content_text(content):
        """Flatten list-form content [{"type": "text", "text": ...}] to a string.

        Non-list content is returned unchanged.
        """
        if isinstance(content, list):
            return "\n".join(
                item["text"] for item in content if item["type"] == "text"
            )
        return content

    def _prepare_completion_args(self, messages, stop_sequences=None, **kwargs):
        """
        Convert SmolaGents message format to Atropos server parameters.

        Returns a dict of keyword arguments for either the chat-completion
        API (``messages`` key) or the plain completion API (``prompt`` key),
        depending on ``self.use_chat_completion``.
        """
        if self.use_chat_completion:
            # For chat completion, we format messages and don't use prompt
            server_args = {
                "messages": self._format_chat_messages(messages),
                "max_tokens": kwargs.get("max_tokens", 2048),
                "temperature": kwargs.get("temperature", 0.0),
                "stop": stop_sequences,
            }
            logger.debug(
                f"Prepared chat completion args: messages count={len(server_args['messages'])}"
            )
        else:
            # Extract the user message for completion API
            prompt = self._extract_user_message(messages)
            server_args = {
                "prompt": prompt,
                "max_tokens": kwargs.get("max_tokens", 2048),
                "temperature": kwargs.get("temperature", 0.0),
                "stop": stop_sequences,
            }
            logger.debug(f"Prepared completion args: prompt length={len(prompt)}")
        return server_args

    def _extract_user_message(self, messages):
        """Extract content from the last user message.

        Handles both dicts and ChatMessage objects, and both string and
        list-form content.

        Raises:
            ValueError: If no user-role message is present.
        """
        for msg in reversed(messages):
            if self._role_str(self._msg_field(msg, "role")) == "user":
                return self._content_text(self._msg_field(msg, "content"))
        raise ValueError("No user message found")

    def _format_chat_messages(self, messages):
        """Format messages for the chat completion API.

        Roles other than system/user/assistant (tool calls, tool responses,
        function calls/responses, anything else) are silently mapped to
        "user", since the OpenAI-style chat API only accepts the three
        standard roles here.
        """
        formatted_messages = []
        for msg in messages:
            role_str = self._role_str(self._msg_field(msg, "role"))
            openai_role = (
                role_str if role_str in ("system", "user", "assistant") else "user"
            )
            content = self._content_text(self._msg_field(msg, "content"))
            formatted_messages.append({"role": openai_role, "content": content})
        return formatted_messages

    @staticmethod
    def _response_content(resp, chat):
        """Pull the generated text out of a server response.

        Args:
            resp: The raw response object returned by the proxy.
            chat: True for chat-completion responses (text lives at
                ``choices[0].message.content``); False for plain completion
                (``choices[0].text``).

        Returns:
            The generated text, or the placeholder "No response content"
            when the response carries no choices.
        """
        if resp and hasattr(resp, "choices") and len(resp.choices) > 0:
            choice = resp.choices[0]
            content = choice.message.content if chat else choice.text
            logger.info(f"Got response with {len(content)} chars")
            return content
        logger.warning("No content found in response")
        return "No response content"

    def generate(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list | None = None,
        **kwargs,
    ) -> ChatMessage:
        """
        Process the input messages and return the model's response.

        Uses the server proxy to communicate with the Atropos server.

        Parameters:
            messages: A list of message dictionaries to be processed.
            stop_sequences: Strings that stop generation when encountered.
                Defaults to the CodeAgent stop markers.
            grammar: The grammar or formatting structure to use (not used with Atropos).
            tools_to_call_from: List of tools (not used with Atropos).
            **kwargs: Additional keyword arguments for the server.

        Returns:
            ChatMessage: A chat message object containing the model's response.

        Raises:
            ValueError: If the server proxy call fails; the original
                exception is chained as the cause.
        """
        # Special handling for CodeAgent stop sequences
        if stop_sequences is None:
            stop_sequences = ["Observation:", "<end_code>", "Calling tools:"]
        logger.info(
            f"Generate called with {len(messages)} messages, use_chat_completion={self.use_chat_completion}"
        )
        # Prepare completion arguments
        completion_kwargs = self._prepare_completion_args(
            messages=messages, stop_sequences=stop_sequences, **kwargs
        )
        try:
            if self.use_chat_completion:
                logger.info("Using chat completion API through proxy")
                # Defensive: convert a stray prompt to messages format if needed
                if (
                    "prompt" in completion_kwargs
                    and "messages" not in completion_kwargs
                ):
                    logger.info("Converting prompt to messages format")
                    completion_kwargs["messages"] = [
                        {"role": "user", "content": completion_kwargs.pop("prompt")}
                    ]
                resp = self.server_proxy.chat_completion(**completion_kwargs)
                content = self._response_content(resp, chat=True)
            else:
                logger.info("Using completion API through proxy")
                resp = self.server_proxy.completion(**completion_kwargs)
                content = self._response_content(resp, chat=False)
            # Track token usage when the server reports it. ``usage`` may be
            # absent or None on some responses, so guard both cases instead of
            # the previous hasattr-only check.
            usage = getattr(resp, "usage", None)
            if usage is not None:
                self.last_input_token_count = usage.prompt_tokens
                self.last_output_token_count = usage.completion_tokens
                logger.info(
                    f"Token usage: input={self.last_input_token_count}, "
                    f"output={self.last_output_token_count}"
                )
            # Return result in SmolaGents format
            logger.info("Successfully returning ChatMessage")
            return ChatMessage(role=MessageRole.ASSISTANT, content=content, raw=resp)
        except Exception as e:
            # Provide more detailed error information
            error_msg = f"Error during server proxy call: {type(e).__name__}: {str(e)}"
            logger.error(error_msg)
            # Print full stack trace for debugging
            logger.error(f"Full traceback: {traceback.format_exc()}")
            # Chain the original exception so callers can see the root cause.
            raise ValueError(error_msg) from e