fix imports and style issues

This commit is contained in:
Allan Niemerg 2025-05-27 11:00:35 -05:00
parent 7a653044a4
commit 013090579d
4 changed files with 176 additions and 146 deletions

View file

@ -12,22 +12,19 @@ from typing import Any, Dict
from smolagents import CodeAgent
# Import tools directly
from environments.smolagents_integration.server_proxy import ServerProxy
from environments.smolagents_integration.smolagents_model import ProcessSafeAtroposServerModel
from environments.smolagents_integration.tools.file_tools import (
append_to_file, read_file, write_file
)
from .server_proxy import ServerProxy
from .smolagents_model import ProcessSafeAtroposServerModel
from .tools.file_tools import append_to_file, read_file, write_file
# Conditionally import Tavily tools if API key is available
tavily_tools = []
if os.environ.get("TAVILY_API_KEY"):
try:
from environments.smolagents_integration.tools.tavily_tools import (
TavilyExtractTool, TavilySearchTool
)
from .tools.tavily_tools import TavilyExtractTool, TavilySearchTool
tavily_tools = [
TavilySearchTool(api_key=os.environ.get("TAVILY_API_KEY")),
TavilyExtractTool(api_key=os.environ.get("TAVILY_API_KEY"))
TavilyExtractTool(api_key=os.environ.get("TAVILY_API_KEY")),
]
except ImportError:
pass
@ -64,7 +61,9 @@ def run_agent_process(
try:
start_time = time.time()
process_id = os.getpid()
logger.info(f"Process {process_id} starting for task {task_metadata.get('task_id', 'unknown')}")
logger.info(
f"Process {process_id} starting for task {task_metadata.get('task_id', 'unknown')}"
)
# Create a model using the server proxy
model = ProcessSafeAtroposServerModel(
@ -85,14 +84,18 @@ def run_agent_process(
verbosity_level=agent_config.get("verbosity", 2),
)
logger.info(f"Process {process_id}: Running agent on prompt with {len(prompt)} chars")
logger.info(
f"Process {process_id}: Running agent on prompt with {len(prompt)} chars"
)
# Run the agent and get response
agent_response = agent.run(prompt)
# Ensure the response is properly formatted (convert sets, etc. to strings)
if not isinstance(agent_response, str):
logger.info(f"Converting non-string response of type {type(agent_response)} to string")
logger.info(
f"Converting non-string response of type {type(agent_response)} to string"
)
try:
if isinstance(agent_response, set):
# Convert sets to comma-separated strings
@ -103,7 +106,7 @@ def run_agent_process(
except Exception as e:
logger.error(f"Failed to convert agent_response to string: {e}")
agent_response = str(agent_response)
# Extract agent memory
agent_memory = getattr(agent, "memory", None)
if hasattr(agent, "write_memory_to_messages"):
@ -113,14 +116,16 @@ def run_agent_process(
execution_time = time.time() - start_time
# Prepare and send result
result_queue.put({
"status": "success",
"response": agent_response,
"task_id": task_metadata.get("task_id"),
"execution_time": execution_time,
"agent_memory": agent_memory,
"task_metadata": task_metadata,
})
result_queue.put(
{
"status": "success",
"response": agent_response,
"task_id": task_metadata.get("task_id"),
"execution_time": execution_time,
"agent_memory": agent_memory,
"task_metadata": task_metadata,
}
)
logger.info(f"Process {process_id}: Agent completed in {execution_time:.2f}s")
@ -128,14 +133,16 @@ def run_agent_process(
# Log the exception and put error result in queue
logger.error(f"Process {os.getpid()}: Error in agent execution: {e}")
logger.error(traceback.format_exc())
result_queue.put({
"status": "error",
"error_message": str(e),
"error_traceback": traceback.format_exc(),
"task_id": task_metadata.get("task_id"),
"task_metadata": task_metadata,
})
result_queue.put(
{
"status": "error",
"error_message": str(e),
"error_traceback": traceback.format_exc(),
"task_id": task_metadata.get("task_id"),
"task_metadata": task_metadata,
}
)
finally:
logger.info(f"Process {os.getpid()}: Cleanup complete")