fix imports and style issues

This commit is contained in:
Allan Niemerg 2025-05-27 11:00:35 -05:00
parent 7a653044a4
commit 013090579d
4 changed files with 176 additions and 146 deletions

View file

@ -78,7 +78,8 @@ python -m environments.smolagents_integration.smolagents_env process \
--env.max_concurrent_processes 8 \
--env.use_chat_completion true \
--openai.model_name "gpt-4o" \
--openai.base_url "https://api.openai.com/v1"
--openai.base_url "https://api.openai.com/v1" \
--openai.api_key x
```
```bash
@ -91,7 +92,8 @@ python -m environments.smolagents_integration.smolagents_env process \
--env.max_concurrent_processes 8 \
--env.use_chat_completion true \
--openai.model_name "gpt-4o" \
--openai.base_url "https://api.openai.com/v1"
--openai.base_url "https://api.openai.com/v1" \
--openai.api_key x
```
Note: The command syntax uses dots (`.`) to separate namespaces. Also, the OpenAI API key should be set in your environment variables as `OPENAI_API_KEY` or in a `.env` file in the project root.
@ -118,8 +120,8 @@ python -m environments.smolagents_integration.smolagents_env serve \
--env.use_chat_completion true \
--env.max_concurrent_processes 5 \
--env.group_size 8 \
--openai.model_name "gpt-4o" \
--openai.base_url "https://api.openai.com/v1"
--openai.model_name "your-model-name" \
--openai.base_url "http://localhost:8000/v1"
```
## How It Works
@ -204,4 +206,4 @@ Each task includes:
- **Web tool errors**: If Tavily tools aren't working, make sure you have set the `TAVILY_API_KEY` environment variable and have installed the `tavily-python` package.
- **Tool import errors**: If you see errors about missing tool modules, ensure your working directory allows proper imports of the tools folder.
- **Permission errors with file tools**: Ensure your process has the correct permissions to read/write files in the directories being accessed.
- **Memory issues**: If you encounter memory usage problems, try lowering the `max_concurrent_processes` parameter.
- **Memory issues**: If you encounter memory usage problems, try lowering the `max_concurrent_processes` parameter.

View file

@ -12,22 +12,19 @@ from typing import Any, Dict
from smolagents import CodeAgent
# Import tools directly
from environments.smolagents_integration.server_proxy import ServerProxy
from environments.smolagents_integration.smolagents_model import ProcessSafeAtroposServerModel
from environments.smolagents_integration.tools.file_tools import (
append_to_file, read_file, write_file
)
from .server_proxy import ServerProxy
from .smolagents_model import ProcessSafeAtroposServerModel
from .tools.file_tools import append_to_file, read_file, write_file
# Conditionally import Tavily tools if API key is available
tavily_tools = []
if os.environ.get("TAVILY_API_KEY"):
try:
from environments.smolagents_integration.tools.tavily_tools import (
TavilyExtractTool, TavilySearchTool
)
from .tools.tavily_tools import TavilyExtractTool, TavilySearchTool
tavily_tools = [
TavilySearchTool(api_key=os.environ.get("TAVILY_API_KEY")),
TavilyExtractTool(api_key=os.environ.get("TAVILY_API_KEY"))
TavilyExtractTool(api_key=os.environ.get("TAVILY_API_KEY")),
]
except ImportError:
pass
@ -64,7 +61,9 @@ def run_agent_process(
try:
start_time = time.time()
process_id = os.getpid()
logger.info(f"Process {process_id} starting for task {task_metadata.get('task_id', 'unknown')}")
logger.info(
f"Process {process_id} starting for task {task_metadata.get('task_id', 'unknown')}"
)
# Create a model using the server proxy
model = ProcessSafeAtroposServerModel(
@ -85,14 +84,18 @@ def run_agent_process(
verbosity_level=agent_config.get("verbosity", 2),
)
logger.info(f"Process {process_id}: Running agent on prompt with {len(prompt)} chars")
logger.info(
f"Process {process_id}: Running agent on prompt with {len(prompt)} chars"
)
# Run the agent and get response
agent_response = agent.run(prompt)
# Ensure the response is properly formatted (convert sets, etc. to strings)
if not isinstance(agent_response, str):
logger.info(f"Converting non-string response of type {type(agent_response)} to string")
logger.info(
f"Converting non-string response of type {type(agent_response)} to string"
)
try:
if isinstance(agent_response, set):
# Convert sets to comma-separated strings
@ -103,7 +106,7 @@ def run_agent_process(
except Exception as e:
logger.error(f"Failed to convert agent_response to string: {e}")
agent_response = str(agent_response)
# Extract agent memory
agent_memory = getattr(agent, "memory", None)
if hasattr(agent, "write_memory_to_messages"):
@ -113,14 +116,16 @@ def run_agent_process(
execution_time = time.time() - start_time
# Prepare and send result
result_queue.put({
"status": "success",
"response": agent_response,
"task_id": task_metadata.get("task_id"),
"execution_time": execution_time,
"agent_memory": agent_memory,
"task_metadata": task_metadata,
})
result_queue.put(
{
"status": "success",
"response": agent_response,
"task_id": task_metadata.get("task_id"),
"execution_time": execution_time,
"agent_memory": agent_memory,
"task_metadata": task_metadata,
}
)
logger.info(f"Process {process_id}: Agent completed in {execution_time:.2f}s")
@ -128,14 +133,16 @@ def run_agent_process(
# Log the exception and put error result in queue
logger.error(f"Process {os.getpid()}: Error in agent execution: {e}")
logger.error(traceback.format_exc())
result_queue.put({
"status": "error",
"error_message": str(e),
"error_traceback": traceback.format_exc(),
"task_id": task_metadata.get("task_id"),
"task_metadata": task_metadata,
})
result_queue.put(
{
"status": "error",
"error_message": str(e),
"error_traceback": traceback.format_exc(),
"task_id": task_metadata.get("task_id"),
"task_metadata": task_metadata,
}
)
finally:
logger.info(f"Process {os.getpid()}: Cleanup complete")

View file

@ -3,35 +3,21 @@ SmolagentsEnv - Environment for creating high-quality agent trajectories
for training language models using the SmolaGents agent framework.
"""
import asyncio
import json
import logging
import multiprocessing
import os
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from pydantic import BaseModel, Field
from smolagents import CodeAgent, Tool
from smolagents.tools import tool # Import the tool decorator
from pydantic import Field
import wandb
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from environments.smolagents_integration.agent_process_runner import run_agent_process
from environments.smolagents_integration.server_proxy import ServerProxyManager
from environments.smolagents_integration.tools.file_tools import (
append_to_file,
read_file,
write_file,
)
from .agent_process_runner import run_agent_process
from .server_proxy import ServerProxyManager
@dataclass
@ -73,7 +59,6 @@ class SmolagentsEnvConfig(BaseEnvConfig):
default="combined",
description="Scoring strategy: basic, correctness, or combined",
)
# Removed max_concurrent_agents as we only use process-based execution
length_penalty_weight: float = Field(
default=0.1, description="Weight for length penalty in scoring (0.0 to disable)"
)
@ -240,8 +225,6 @@ class SmolagentsEnv(BaseEnv):
logger.info("SmolagentsEnv setup complete")
# _create_tools method removed - tools are created directly in agent_process_runner.py
async def get_next_item(self) -> Item:
"""Get the next item from the GAIA dataset."""
if not self.examples:
@ -414,9 +397,9 @@ class SmolagentsEnv(BaseEnv):
)
else:
# Just log the error and omit this example from training
error_message = result.get('error_message', 'Unknown error')
task_id = result.get('task_id', 'Unknown task')
error_message = result.get("error_message", "Unknown error")
task_id = result.get("task_id", "Unknown task")
logger.warning(
f"Omitting failed task {task_id} from training batch: {error_message}"
)
@ -493,7 +476,7 @@ class SmolagentsEnv(BaseEnv):
"""
Score the agent trajectory based on multiple criteria:
- Answer correctness using GAIA scoring
- Message format adherence
- Message format adherence
- Final answer tool usage
- Execution success (detection of errors)
- Efficiency (steps only)
@ -504,30 +487,35 @@ class SmolagentsEnv(BaseEnv):
true_answer: The ground truth answer
agent_memory: The memory trace of the agent's steps
execution_time: Time taken for execution (not used in scoring)
Returns:
float: A score between 0.0 and 1.0
"""
# Import all scoring functions upfront
from environments.smolagents_integration.evaluations.smolagent_integrations.rubrics.gaia_scorer import (
question_scorer, check_close_call
from .evaluations.smolagent_integrations.rubrics.gaia_scorer import (
check_close_call,
question_scorer,
)
from environments.smolagents_integration.evaluations.smolagent_integrations.smolagents_scorer import (
check_format_adherence, check_final_answer_usage,
calculate_execution_score, calculate_efficiency_score
from .evaluations.smolagent_integrations.smolagents_scorer import (
calculate_efficiency_score,
calculate_execution_score,
check_final_answer_usage,
check_format_adherence,
)
# Initialize component scores
correctness_score = 0.0
format_score = 0.0
final_answer_score = 0.0
execution_score = 0.0
efficiency_score = 0.0
# 1. Calculate correctness score using GAIA scorer
# Ensure agent_response is a string before scoring
if not isinstance(agent_response, str):
logger.warning(f"agent_response is not a string, it's a {type(agent_response)}: {agent_response}")
logger.warning(
f"agent_response is not a string, it's a {type(agent_response)}: {agent_response}"
)
try:
if isinstance(agent_response, set):
# Convert sets to comma-separated strings
@ -538,15 +526,15 @@ class SmolagentsEnv(BaseEnv):
except Exception as e:
logger.error(f"Failed to convert agent_response to string: {e}")
agent_response = ""
is_correct = question_scorer(agent_response, true_answer)
is_near_correct = check_close_call(agent_response, true_answer, is_correct)
if is_correct:
correctness_score = 1.0
elif is_near_correct:
correctness_score = 0.5
# 2. Check format adherence
if agent_memory:
format_scores = []
@ -555,10 +543,12 @@ class SmolagentsEnv(BaseEnv):
format_scores.append(check_format_adherence(step["content"]))
elif "model_output" in step and isinstance(step["model_output"], str):
format_scores.append(check_format_adherence(step["model_output"]))
# Average the format scores across all steps
format_score = sum(format_scores) / len(format_scores) if format_scores else 0.0
format_score = (
sum(format_scores) / len(format_scores) if format_scores else 0.0
)
# 3. Check for final_answer tool usage
final_answer_used = False
for step in agent_memory:
@ -570,41 +560,44 @@ class SmolagentsEnv(BaseEnv):
if check_final_answer_usage(step["model_output"]):
final_answer_used = True
break
final_answer_score = 1.0 if final_answer_used else 0.0
# 4. Check for execution errors and calculate execution score
execution_score = calculate_execution_score(agent_memory)
# 5. Calculate efficiency score
steps_count = len(agent_memory)
efficiency_score = calculate_efficiency_score(
steps_count=steps_count,
max_steps=self.max_steps
steps_count=steps_count, max_steps=self.max_steps
)
# Component weights - can be adjusted to emphasize different aspects
correctness_weight = 0.50 # 50% of score - correctness matters most
format_weight = 0.20 # 20% - format adherence is important
format_weight = 0.20 # 20% - format adherence is important
final_answer_weight = 0.10 # 10% - using final_answer tool properly
execution_weight = 0.10 # 10% - avoiding errors
efficiency_weight = 0.10 # 10% - being efficient
execution_weight = 0.10 # 10% - avoiding errors
efficiency_weight = 0.10 # 10% - being efficient
# Calculate combined score with weights
combined_score = (
correctness_score * correctness_weight +
format_score * format_weight +
final_answer_score * final_answer_weight +
execution_score * execution_weight +
efficiency_score * efficiency_weight
correctness_score * correctness_weight
+ format_score * format_weight
+ final_answer_score * final_answer_weight
+ execution_score * execution_weight
+ efficiency_score * efficiency_weight
)
# Apply length penalty if configured
length_penalty = 0.0
if self.config.length_penalty_weight > 0:
# agent_response should already be a string from the previous conversion
# but double-check just to be sure
response_to_measure = agent_response if isinstance(agent_response, str) else str(agent_response)
response_to_measure = (
agent_response
if isinstance(agent_response, str)
else str(agent_response)
)
response_length = len(response_to_measure)
# Penalize very long responses
if response_length > 2000:
@ -613,57 +606,75 @@ class SmolagentsEnv(BaseEnv):
self.config.length_penalty_weight * (response_length - 2000) / 1000,
)
combined_score = max(0.0, combined_score - length_penalty)
# Debug logging for score calculation
if self.debug_scoring:
logger.info("=== SCORE CALCULATION (detailed) ===")
logger.info(f"1. Correctness component:")
logger.info("1. Correctness component:")
logger.info(f" - True answer: '{true_answer}'")
logger.info(f" - Agent answer: '{agent_response}'")
logger.info(f" - Is correct: {is_correct}")
logger.info(f" - Is near correct: {is_near_correct}")
logger.info(f" - Raw score: {correctness_score:.3f}")
logger.info(f" - Weight: {correctness_weight}")
logger.info(f" - Weighted score: {correctness_score * correctness_weight:.3f}")
logger.info(f"2. Format adherence component:")
logger.info(f" - Format scores: {format_scores if 'format_scores' in locals() else []}")
logger.info(
f" - Weighted score: {correctness_score * correctness_weight:.3f}"
)
logger.info("2. Format adherence component:")
logger.info(
f" - Format scores: {format_scores if 'format_scores' in locals() else []}"
)
logger.info(f" - Raw score: {format_score:.3f}")
logger.info(f" - Weight: {format_weight}")
logger.info(f" - Weighted score: {format_score * format_weight:.3f}")
logger.info(f"3. Final answer tool usage:")
logger.info(f" - Final answer tool used: {final_answer_used if 'final_answer_used' in locals() else False}")
logger.info("3. Final answer tool usage:")
logger.info(
f" - Final answer tool used: {final_answer_used if 'final_answer_used' in locals() else False}"
)
logger.info(f" - Raw score: {final_answer_score:.3f}")
logger.info(f" - Weight: {final_answer_weight}")
logger.info(f" - Weighted score: {final_answer_score * final_answer_weight:.3f}")
logger.info(f"4. Execution component:")
logger.info(
f" - Weighted score: {final_answer_score * final_answer_weight:.3f}"
)
logger.info("4. Execution component:")
logger.info(f" - Steps count: {len(agent_memory) if agent_memory else 0}")
logger.info(f" - Raw score: {execution_score:.3f}")
logger.info(f" - Weight: {execution_weight}")
logger.info(f" - Weighted score: {execution_score * execution_weight:.3f}")
logger.info(f"5. Efficiency component:")
logger.info(
f" - Weighted score: {execution_score * execution_weight:.3f}"
)
logger.info("5. Efficiency component:")
logger.info(f" - Steps count: {len(agent_memory) if agent_memory else 0}")
logger.info(f" - Max steps: {self.max_steps}")
logger.info(f" - Raw score: {efficiency_score:.3f}")
logger.info(f" - Weight: {efficiency_weight}")
logger.info(f" - Weighted score: {efficiency_score * efficiency_weight:.3f}")
logger.info(f"6. Length penalty:")
logger.info(f" - Response length: {len(agent_response) if isinstance(agent_response, str) else 'N/A'}")
logger.info(
f" - Weighted score: {efficiency_score * efficiency_weight:.3f}"
)
logger.info("6. Length penalty:")
logger.info(
f" - Response length: {len(agent_response) if isinstance(agent_response, str) else 'N/A'}"
)
logger.info(f" - Penalty: {length_penalty:.3f}")
logger.info(f"7. Final score calculation:")
logger.info(f" - Correctness: {correctness_score * correctness_weight:.3f}")
logger.info("7. Final score calculation:")
logger.info(
f" - Correctness: {correctness_score * correctness_weight:.3f}"
)
logger.info(f" - Format adherence: {format_score * format_weight:.3f}")
logger.info(f" - Final answer tool: {final_answer_score * final_answer_weight:.3f}")
logger.info(
f" - Final answer tool: {final_answer_score * final_answer_weight:.3f}"
)
logger.info(f" - Execution: {execution_score * execution_weight:.3f}")
logger.info(f" - Efficiency: {efficiency_score * efficiency_weight:.3f}")
logger.info(f" - Length penalty: -{length_penalty:.3f}")
logger.info(f" - FINAL SCORE: {combined_score:.3f}")
return combined_score
def _create_scored_data_group(
@ -704,76 +715,82 @@ class SmolagentsEnv(BaseEnv):
# Create a comprehensive markdown document for HTML compatibility
# Format similar to the Wikipedia environment
complete_conversation = []
# Add task information at the top
task_type = item.metadata.get("task", "Unknown task")
task_id = item.metadata.get("task_id", "Unknown ID")
complete_conversation.append(f"# GAIA Task: {task_type} (ID: {task_id})")
# Process each message in the conversation
for i, msg in enumerate(messages):
role = msg.get("role", "unknown")
content = msg.get("content", "")
# Skip empty messages
if not content:
continue
# Add a header for each role
complete_conversation.append(f"## {role.upper()}")
# Handle different content formats
if isinstance(content, list):
content_text = []
for item in content:
if isinstance(item, dict) and 'text' in item:
content_text.append(item['text'])
if isinstance(item, dict) and "text" in item:
content_text.append(item["text"])
else:
content_text.append(str(item))
content = '\n'.join(content_text)
content = "\n".join(content_text)
elif not isinstance(content, str):
# Handle non-string content
content = str(content)
# Add the message content
complete_conversation.append(content)
# If this message has agent memory, show it
if role == "assistant" and i > 1 and self.config.save_full_traces:
# Agent memory with thinking is often present in assistant messages
if "<thinking>" in content:
complete_conversation.append("### Thinking Process")
# The thinking is already in the content
# Add tool usage information if present
if "tool_usage" in msg:
tool_name = msg.get("tool_usage", {}).get("name", "unknown tool")
tool_name = msg.get("tool_usage", {}).get(
"name", "unknown tool"
)
complete_conversation.append(f"### 🛠️ Tool Used: {tool_name}")
tool_args = msg.get("tool_usage", {}).get("args", {})
if tool_args:
complete_conversation.append("```json")
complete_conversation.append(json.dumps(tool_args, indent=2))
complete_conversation.append(
json.dumps(tool_args, indent=2)
)
complete_conversation.append("```")
tool_result = msg.get("tool_usage", {}).get("result", None)
if tool_result:
complete_conversation.append("### Tool Result")
complete_conversation.append("```")
complete_conversation.append(str(tool_result))
complete_conversation.append("```")
# Add score information at the end
complete_conversation.append(f"\n## Score: {scored_data['score']:.4f}")
# Join everything into a single string with double newlines between sections
full_conversation_markdown = "\n\n".join(complete_conversation)
# Create the ScoredDataGroup with a single comprehensive markdown document
scored_group = ScoredDataGroup(
tokens=[self.tokenizer.encode(json.dumps(messages))],
masks=[[1] * len(self.tokenizer.encode(json.dumps(messages)))],
scores=[scored_data["score"]],
messages=[full_conversation_markdown], # Use single markdown document for HTML
messages=[
full_conversation_markdown
], # Use single markdown document for HTML
_original_messages=[messages], # Keep original for trainer API
)
@ -787,16 +804,16 @@ class SmolagentsEnv(BaseEnv):
{"role": "user", "content": item.prompt},
{"role": "assistant", "content": scored_data["response"]},
]
# Use the standard tokenize_for_trainer utility
# Tokenize using the standard utility (only trains on assistant messages)
tokenized = tokenize_for_trainer(
self.tokenizer,
self.tokenizer,
messages,
train_on_all_assistant_turns=True # Train on all assistant turns if present
train_on_all_assistant_turns=True, # Train on all assistant turns if present
)
# Create the ScoredDataGroup
scored_group = ScoredDataGroup(
tokens=[tokenized["tokens"]],
@ -807,7 +824,6 @@ class SmolagentsEnv(BaseEnv):
return scored_group
async def evaluate(self, **kwargs):
"""
Evaluate the agent on a subset of the GAIA benchmark.

View file

@ -4,11 +4,10 @@ Process-safe implementation of the AtroposServerModel for SmolaGents.
import logging
import traceback
from typing import Any, Dict, List, Optional
from smolagents.models import ChatMessage, MessageRole, Model
from environments.smolagents_integration.server_proxy import ServerProxy
from .server_proxy import ServerProxy
# Configure logger for the model class
logger = logging.getLogger(__name__)
@ -45,7 +44,8 @@ class ProcessSafeAtroposServerModel(Model):
# Log the configuration
logger.info(
f"Initializing ProcessSafeAtroposServerModel with model_id={model_id}, use_chat_completion={self.use_chat_completion}"
f"Initializing ProcessSafeAtroposServerModel with model_id={model_id}, "
f"use_chat_completion={self.use_chat_completion}"
)
super().__init__(model_id=model_id, **kwargs)
@ -117,7 +117,12 @@ class ProcessSafeAtroposServerModel(Model):
openai_role = "user"
elif role_str == "assistant":
openai_role = "assistant"
elif role_str in ("tool_call", "tool_response", "function_call", "function_response"):
elif role_str in (
"tool_call",
"tool_response",
"function_call",
"function_response",
):
# Silently map tool and function calls/responses to user roles
openai_role = "user"
else:
@ -172,8 +177,8 @@ class ProcessSafeAtroposServerModel(Model):
messages=messages, stop_sequences=stop_sequences, **kwargs
)
# Extract timeout from kwargs or use default
timeout = kwargs.pop("timeout", 120) # Default 2 minutes
# Extract timeout from kwargs or use default (but not used in this method)
kwargs.pop("timeout", 120) # Default 2 minutes
try:
# Use chat_completion if configured