# atropos/environments/smolagents_integration/smolagents_env.py
# Listing metadata: 2025-09-30 14:03:43 +00:00 · 1086 lines · 42 KiB · Python
"""
SmolagentsEnv - Environment for creating high-quality agent trajectories
for training language models using the SmolaGents agent framework.
"""
import json
import logging
import multiprocessing
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from .agent_process_runner import run_agent_process
from .server_proxy import ServerProxyManager
@dataclass
class Item:
    """One task handed to an agent rollout.

    ``prompt`` is the full text given to the agent. ``metadata`` holds the
    task's bookkeeping fields (as built by SmolagentsEnv: task id, level,
    reference answer, file name, ...). ``id`` is optional and defaults to None.
    """

    prompt: str
    metadata: Dict[str, Any]
    id: Optional[str] = None
# Module logger: everything from DEBUG upward goes to a dedicated console handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Keep records away from the root logger so they are not printed twice.
logger.propagate = False
# Logger objects are shared process-wide by name; attach the handler only once.
if not logger.handlers:
    formatter = logging.Formatter("%(levelname)s - %(message)s")
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(logging.DEBUG)
    logger.addHandler(console_handler)
class SmolagentsEnvConfig(BaseEnvConfig):
    """Configuration for SmolagentsEnv.

    Extends ``BaseEnvConfig`` with the GAIA dataset location, agent execution
    settings (step budget, verbosity, subprocess concurrency and timeout) and
    scoring/output options. Field order is kept as declared, since pydantic
    uses declaration order for the schema.
    """

    # Root of the GAIA checkout. setup() reads <dataset_path>/2023/<split>/metadata.jsonl
    # and attempts an automatic download when the dataset files are missing.
    dataset_path: str = Field(default="data/gaia", description="Path to GAIA dataset")
    # Name of the split directory under <dataset_path>/2023/ to load.
    split: str = Field(
        default="validation", description="Dataset split to use (validation, test)"
    )
    # Forwarded to the agent subprocess as agent_config["use_chat_completion"].
    use_chat_completion: bool = Field(
        default=True, description="Use chat completion API"
    )
    # Forwarded to the agent subprocess; also the step budget for the efficiency score.
    max_steps: int = Field(default=12, description="Maximum number of agent steps")
    # Forwarded to the agent subprocess as agent_config["verbosity"].
    agent_verbosity: int = Field(default=2, description="Agent verbosity level (0-3)")
    # NOTE(review): copied to SmolagentsEnv.scoring_strategy, but nothing in this
    # file reads it back; trajectory scoring always uses the weighted combined rubric.
    scoring_strategy: str = Field(
        default="combined",
        description="Scoring strategy: basic, correctness, or combined",
    )
    # Scales the penalty subtracted for final responses longer than 2000
    # characters (the penalty is capped at 0.3).
    length_penalty_weight: float = Field(
        default=0.1, description="Weight for length penalty in scoring (0.0 to disable)"
    )
    # When True, the agent memory is attached to the scored data. Scoring only
    # receives the memory in that case, so disabling this drops the format,
    # final-answer, execution and efficiency components to 0.
    save_full_traces: bool = Field(
        default=True, description="Save full agent execution traces in the output"
    )
    # Output path configured in __init__
    # (SmolagentsEnv.__init__ substitutes smolagents_output_<timestamp>.jsonl when None).
    data_path_to_save_groups: Optional[str] = Field(
        default=None,
        description="Path to save JSONL output (defaults to timestamped file if None)",
    )
    # Process-based settings
    # Passed as max_workers to the ServerProxyManager created in setup().
    max_concurrent_processes: int = Field(
        default=5,
        description="Maximum number of concurrent processes for agent execution",
    )
    # Used both as the wait timeout for agent processes and as each server proxy's timeout.
    process_timeout: int = Field(
        default=240,  # 4 minutes by default
        description="Timeout for agent processes in seconds",
    )
    # Debugging options
    # When True, _score_trajectory logs a per-component score breakdown.
    debug_scoring: bool = Field(
        default=False, description="Enable detailed score calculation logging"
    )
class SmolagentsEnv(BaseEnv):
    """
    Environment for generating agent trajectories with the smolagents framework.

    This environment:
    1. Loads tasks from the GAIA benchmark (``metadata.jsonl`` of the configured
       split), attempting a download first if the dataset is missing.
    2. Runs each agent rollout in its own child process (``run_agent_process``),
       which reaches the inference server through a ``ServerProxyManager`` proxy.
    3. Scores each trajectory with a weighted rubric: answer correctness, format
       adherence, ``final_answer`` usage, execution errors, step efficiency and
       an optional length penalty.
    4. Packs the scored trajectories into ``ScoredDataGroup``s for the trainer.
    """

    name = "smolagents"
    env_config_cls = SmolagentsEnvConfig

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default ``(env_config, server_configs)`` pair used by the CLI."""
        env_config = SmolagentsEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=32,
            steps_per_eval=100,
            max_token_length=4096,
            wandb_name="smolagents",
            include_messages=True,
            # Process-based settings
            max_concurrent_processes=8,
            process_timeout=240,
            # Common settings
            dataset_path="data/gaia",
            split="validation",  # GAIA only supports 'validation' and 'test' splits
            use_chat_completion=True,
            # Debugging options
            debug_scoring=False,  # Set to True to enable detailed score logging
            # data_path_to_save_groups is left unset, so __init__ picks a timestamped path.
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=32,
            ),
        ]
        return env_config, server_configs

    def __init__(
        self,
        config: SmolagentsEnvConfig,
        server_configs: Union[List[APIServerConfig], APIServerConfig],
        slurm=False,
        testing=False,
    ):
        """
        Build the environment.

        Args:
            config: Environment configuration. When ``data_path_to_save_groups``
                is None it is overwritten in place with a timestamped filename
                before being handed to ``BaseEnv``.
            server_configs: Inference server configuration(s), forwarded to ``BaseEnv``.
            slurm: Forwarded to ``BaseEnv``.
            testing: Forwarded to ``BaseEnv``.
        """
        if config.data_path_to_save_groups is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            config.data_path_to_save_groups = f"smolagents_output_{timestamp}.jsonl"
            print(
                f"Using auto-generated output path: {config.data_path_to_save_groups}"
            )

        super().__init__(config, server_configs, slurm, testing)

        # Dataset state. `iter` drives get_next_item and is saved in checkpoints.
        self.examples = []
        self.current_index = 0
        self.iter = 0

        # Created in setup().
        self.server_proxy_manager = None

        # Shortcuts to frequently used config values.
        self.max_steps = config.max_steps
        self.verbosity = config.agent_verbosity
        # NOTE(review): stored but not consulted by _score_trajectory, which always
        # applies the combined weighted rubric.
        self.scoring_strategy = config.scoring_strategy
        self.debug_scoring = config.debug_scoring

        # Metric buffers; wandb_log() reports and clears them.
        self.agent_execution_times = []
        self.percent_correct_buffer = []
        self.eval_metrics = []

    # ------------------------------------------------------------------ setup

    async def setup(self):
        """
        Start the server proxy manager and load the GAIA examples.

        After this call, ``self.server_proxy_manager`` is running and
        ``self.examples`` holds one dict per task of the configured split (see
        ``_load_gaia_examples``). If the dataset is missing and cannot be
        downloaded, or loading fails, the error is logged and ``self.examples``
        is left empty; no exception is raised.
        """
        logger.info("Setting up SmolagentsEnv...")
        logger.info(f"Using dataset split: {self.config.split}")

        # Agents run in child processes; this manager gives each one a proxy to self.server.
        logger.info("Setting up process-based isolation for agent execution")
        self.server_proxy_manager = ServerProxyManager(
            server=self.server, max_workers=self.config.max_concurrent_processes
        )
        self.server_proxy_manager.start()
        logger.info(
            f"Started server proxy manager with max_workers={self.config.max_concurrent_processes}"
        )

        self.examples = []
        try:
            if not self._ensure_gaia_dataset():
                return
            self.examples = self._load_gaia_examples()
            logger.info(
                f"Loaded {len(self.examples)} examples from GAIA {self.config.split} set"
            )
        except Exception as e:
            logger.error(
                f"Error loading GAIA dataset: {type(e).__name__}: {e}", exc_info=True
            )
            logger.error(
                "Please run: python -m environments.smolagents_integration.download_gaia"
            )
            self.examples = []

        logger.info("SmolagentsEnv setup complete")

    def _gaia_split_dir(self) -> str:
        """Return ``<dataset_path>/2023/<split>``, the directory holding the split's files."""
        import os

        return os.path.join(self.config.dataset_path, "2023", self.config.split)

    def _ensure_gaia_dataset(self) -> bool:
        """
        Make sure the configured GAIA split is on disk, downloading it if needed.

        Returns:
            True if the split's ``metadata.jsonl`` exists (possibly after a
            download). False if the download failed or the metadata file is
            still missing; the reason is logged in either case.
        """
        import os

        dataset_path = self.config.dataset_path
        metadata_path = os.path.join(self._gaia_split_dir(), "metadata.jsonl")
        gaia_py_path = os.path.join(dataset_path, "GAIA.py")

        # Check the split that will actually be loaded, not only "validation".
        if not (os.path.exists(metadata_path) and os.path.exists(gaia_py_path)):
            logger.info(
                f"GAIA dataset not found at {dataset_path}, attempting to download..."
            )
            from .download_gaia import download_gaia_dataset

            if not download_gaia_dataset(dataset_path):
                logger.error("Failed to download GAIA dataset automatically.")
                logger.error(
                    "Please run: python -m environments.smolagents_integration.download_gaia"
                )
                return False
            logger.info(f"GAIA dataset downloaded successfully to {dataset_path}")

        if not os.path.exists(metadata_path):
            logger.error(f"Metadata file not found: {metadata_path}")
            return False
        return True

    def _load_gaia_examples(self) -> List[Dict[str, Any]]:
        """
        Parse the split's ``metadata.jsonl`` into example dicts.

        Each example has the keys ``question``, ``true_answer``, ``task`` (the
        record's ``Level``), ``task_id`` (``task_<line>`` when absent) and
        ``file_name`` (joined under the split directory, or "" when the record
        names no file). Blank lines are skipped; malformed lines are logged and
        skipped.
        """
        import os

        split_dir = self._gaia_split_dir()
        metadata_path = os.path.join(split_dir, "metadata.jsonl")
        logger.info(f"Loading GAIA dataset directly from {self.config.dataset_path}")
        logger.info(f"Reading metadata from: {metadata_path}")

        examples = []
        # Decode explicitly as UTF-8 instead of relying on the platform default.
        with open(metadata_path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                    examples.append(
                        {
                            "question": record["Question"],
                            "true_answer": record["Final answer"],
                            "task": record["Level"],
                            "task_id": record.get("task_id", f"task_{i}"),
                            "file_name": (
                                os.path.join(split_dir, record["file_name"])
                                if record.get("file_name")
                                else ""
                            ),
                        }
                    )
                except (ValueError, KeyError, TypeError, AttributeError) as parse_error:
                    logger.error(
                        f"Error parsing line {i} of metadata file: {parse_error}"
                    )
        return examples

    # -------------------------------------------------------------- data feed

    async def get_next_item(self) -> Optional[Item]:
        """
        Return the next GAIA task as an ``Item``, cycling through the dataset.

        Advances ``self.iter`` and sets ``self.current_index`` to the position
        of the *following* example. The item's ``metadata["dataset_idx"]`` is
        the index of the example actually used.

        Returns:
            The next ``Item``, or None if no examples are loaded.
        """
        if not self.examples:
            logger.warning("No examples loaded in dataset")
            return None

        dataset_idx = self.iter % len(self.examples)
        example = self.examples[dataset_idx]
        self.iter += 1
        self.current_index = self.iter % len(self.examples)

        prompt = example["question"]
        if example.get("file_name"):
            prompt += f"\n\nTo solve this task, you can use the file at: {example['file_name']}"

        return Item(
            prompt=prompt,
            metadata={
                "task_id": example["task_id"],
                "task": example["task"],
                "true_answer": example["true_answer"],
                "file_name": example.get("file_name", ""),
                "dataset_idx": dataset_idx,
            },
        )

    # ------------------------------------------------------------- rollouts

    async def collect_trajectories(self, items: Union[Item, List[Item]]) -> Tuple[
        Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any]],
        List[Item],
    ]:
        """
        Run agents on ``items`` and return ``(scored_groups, backlog)``.

        A single ``Item`` is repeated ``group_size`` times, so the same task is
        rolled out ``group_size`` times. A list is used as given.
        """
        if not isinstance(items, list):
            items = [items] * self.config.group_size
        return await self._collect_trajectories_process_based(items)

    async def _collect_trajectories_process_based(
        self, items: List[Item]
    ) -> Tuple[List[Any], List[Item]]:
        """
        Run one agent subprocess per item and score the successful runs.

        Each process gets its own server proxy and writes its result dict to a
        shared manager queue. All processes share one ``process_timeout``
        budget, and processes still running when it expires are terminated.
        Processes, proxies and the manager process are always released, even
        if spawning fails part-way.

        Args:
            items: Tasks to run. Duplicates are allowed (one process per entry).

        Returns:
            ``(scored_groups, backlog)``: one ``ScoredDataGroup`` per successful
            run whose ``task_id`` matches an input item (failed runs are logged
            and dropped), and an always-empty backlog.
        """
        import asyncio

        logger.info(
            f"Collecting trajectories for {len(items)} items using process-based parallelism"
        )
        agent_config = {
            "max_steps": self.max_steps,
            "verbosity": self.verbosity,
            "use_chat_completion": self.config.use_chat_completion,
            "model_name": getattr(self.server, "model_name", "unknown-model"),
        }

        manager = multiprocessing.Manager()
        processes = []
        proxy_ids = []
        try:
            result_queue = manager.Queue()
            # NOTE(review): every item gets a process immediately; the only
            # concurrency limit visible here is max_workers on the proxy manager.
            for item in items:
                server_proxy, proxy_id = self.server_proxy_manager.create_server_proxy(
                    model_name=agent_config["model_name"],
                    timeout=self.config.process_timeout,
                )
                proxy_ids.append(proxy_id)
                process = multiprocessing.Process(
                    target=run_agent_process,
                    args=(
                        item.prompt,
                        item.metadata,
                        server_proxy,
                        agent_config,
                        result_queue,
                    ),
                )
                process.start()
                processes.append(process)
            logger.info(f"Started {len(processes)} agent processes")

            # Process.join blocks. Wait in a worker thread so the event loop can
            # keep running other coroutines while the agents execute.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._join_agent_processes, processes)

            results = self._drain_result_queue(result_queue)
        finally:
            try:
                for process in processes:
                    if process.is_alive():
                        process.terminate()
                        process.join()
                for proxy_id in proxy_ids:
                    self.server_proxy_manager.remove_proxy(proxy_id)
            finally:
                # Manager() starts a dedicated server process; shut it down on every call.
                manager.shutdown()

        logger.info(f"Collected {len(results)} results from processes")

        to_postprocess = []
        for result in results:
            scored_group = self._score_agent_result(result, items)
            if scored_group is not None:
                to_postprocess.append(scored_group)

        logger.info(f"Final to_postprocess: len={len(to_postprocess)}")
        return to_postprocess, []

    def _join_agent_processes(self, processes: List[multiprocessing.Process]) -> None:
        """
        Wait for ``processes`` under one shared ``process_timeout`` deadline.

        The processes run concurrently, so the timeout is a single wall-clock
        budget for the batch. A per-join timeout would let the k-th process run
        for up to k * timeout. Processes still alive at the deadline are
        terminated.
        """
        import time

        deadline = time.monotonic() + self.config.process_timeout
        for process in processes:
            process.join(timeout=max(0.0, deadline - time.monotonic()))
            if process.is_alive():
                logger.warning(f"Process {process.pid} timed out, terminating")
                process.terminate()
                process.join()

    @staticmethod
    def _drain_result_queue(result_queue: Any) -> List[Dict[str, Any]]:
        """Read every result currently in ``result_queue`` without blocking.

        Stops at the first read error (which is logged) and returns what was read so far.
        """
        results = []
        while not result_queue.empty():
            try:
                results.append(result_queue.get(block=False))
            except Exception as e:
                logger.error(f"Error getting result from queue: {e}")
                break
        return results

    def _score_agent_result(
        self, result: Dict[str, Any], items: List[Item]
    ) -> Optional[ScoredDataGroup]:
        """
        Turn one agent-process result dict into a ``ScoredDataGroup``.

        Successful runs record their execution time in
        ``self.agent_execution_times``.

        Returns:
            The scored group, or None (after logging a warning) when the run
            failed or no item in ``items`` has the result's ``task_id``.
        """
        if result.get("status") != "success":
            # Failed runs are left out of the training batch.
            error_message = result.get("error_message", "Unknown error")
            task_id = result.get("task_id", "Unknown task")
            logger.warning(
                f"Omitting failed task {task_id} from training batch: {error_message}"
            )
            return None

        task_metadata = result["task_metadata"]
        scored_data = {
            "prompt": task_metadata.get("prompt", ""),
            "response": result["response"],
            "task_id": result["task_id"],
            "task": task_metadata.get("task", ""),
            "true_answer": task_metadata.get("true_answer", ""),
            "execution_time": result["execution_time"],
        }
        if self.config.save_full_traces and "agent_memory" in result:
            scored_data["agent_memory"] = result["agent_memory"]

        # Feeds the execution-time metrics reported by evaluate() and wandb_log().
        if isinstance(result["execution_time"], (int, float)):
            self.agent_execution_times.append(result["execution_time"])

        scored_data["score"] = self._score_trajectory(
            scored_data["prompt"],
            scored_data["response"],
            scored_data["true_answer"],
            scored_data.get("agent_memory"),
            scored_data["execution_time"],
        )

        source_item = next(
            (i for i in items if i.metadata.get("task_id") == result["task_id"]),
            None,
        )
        if source_item is None:
            logger.warning(
                f"Could not find original item for task_id {result['task_id']}"
            )
            return None
        return self._create_scored_data_group(source_item, scored_data)

    async def postprocess_histories(
        self, histories: Union[ScoredDataGroup, List[ScoredDataGroup]]
    ) -> Optional[ScoredDataGroup]:
        """
        Merge scored groups into a single ``ScoredDataGroup``.

        A single group is returned as-is, with ``group_overrides`` defaulted to
        ``{}``. A list is concatenated field by field: tokens, masks, scores and,
        when ``include_messages`` is on, messages. ``None`` entries and entries
        without messages are skipped for the relevant fields. Other keys of the
        inputs are not carried into the merged group.

        Returns:
            The merged group, or None when ``histories`` itself is None.
        """
        logger.debug(
            f"postprocess_histories called with: type={type(histories)}, is_none={histories is None}"
        )
        if histories is None:
            return None

        if not isinstance(histories, list):
            logger.debug(f" Single history, returning directly: {type(histories)}")
            if histories.get("group_overrides") is None:
                histories["group_overrides"] = {}
            return histories

        logger.info(f" Merging {len(histories)} histories")
        merged = ScoredDataGroup(
            tokens=[],
            masks=[],
            scores=[],
            advantages=None,
            ref_logprobs=None,
            messages=[] if self.config.include_messages else None,
            group_overrides={},
            overrides=None,
        )
        for i, history in enumerate(histories):
            if history is None:
                logger.debug(f" Skipping empty history {i}")
                continue
            merged["tokens"].extend(history["tokens"])
            merged["masks"].extend(history["masks"])
            merged["scores"].extend(history["scores"])
            # Groups built without include_messages carry messages=None.
            history_messages = history.get("messages")
            if merged["messages"] is not None and history_messages:
                merged["messages"].extend(history_messages)

        logger.info(
            f" Final merged data: tokens={len(merged['tokens'])}, scores={len(merged['scores'])}"
        )
        return merged

    # ---------------------------------------------------------------- scoring

    @staticmethod
    def _step_content(step: Any) -> Optional[str]:
        """Return the non-empty text content of one agent-memory step, else None.

        Accepts objects exposing ``.content`` (e.g. chat messages) and dicts with
        a ``content`` or ``model_output`` key.
        """
        if hasattr(step, "content"):
            content = step.content
        elif isinstance(step, dict):
            content = step.get("content") or step.get("model_output")
        else:
            content = None
        return content if isinstance(content, str) and content else None

    def _score_trajectory(
        self,
        prompt: str,
        agent_response: str,
        true_answer: str,
        agent_memory: Optional[List[Dict]] = None,
        execution_time: float = 0,  # Kept for backward compatibility; not used in scoring
    ) -> float:
        """
        Score one agent trajectory with a weighted rubric.

        Components (weight):
        - correctness (0.50): 1.0 if the GAIA scorer accepts the answer,
          0.5 for a near miss, else 0.0;
        - format adherence (0.20): mean per-step format score;
        - final answer tool (0.10): 1.0 if any step used ``final_answer``;
        - execution (0.10): ``calculate_execution_score`` over the memory;
        - efficiency (0.10): steps used relative to ``max_steps``.
        The memory-based components are 0 when ``agent_memory`` is empty. When
        ``length_penalty_weight`` > 0, responses longer than 2000 characters
        lose ``length_penalty_weight * (len - 2000) / 1000`` (capped at 0.3),
        and the result is clamped at 0.0.

        Args:
            prompt: The original task prompt (not used in scoring).
            agent_response: The agent's final answer. Non-string values are
                converted first; sets are sorted and comma-joined so the result
                does not depend on set iteration order.
            true_answer: The ground-truth answer.
            agent_memory: The agent's step trace (dicts or message objects).
            execution_time: Run time (not used in scoring).

        Returns:
            The combined score; nominally in [0.0, 1.0], assuming each
            component scorer returns values in [0, 1].
        """
        from .evaluations.smolagent_integrations.rubrics.gaia_scorer import (
            check_close_call,
            question_scorer,
        )
        from .evaluations.smolagent_integrations.smolagents_scorer import (
            calculate_efficiency_score,
            calculate_execution_score,
            check_final_answer_usage,
            check_format_adherence,
        )

        # Component weights (sum to 1.0).
        correctness_weight = 0.50
        format_weight = 0.20
        final_answer_weight = 0.10
        execution_weight = 0.10
        efficiency_weight = 0.10

        correctness_score = 0.0
        format_score = 0.0
        final_answer_score = 0.0
        execution_score = 0.0
        efficiency_score = 0.0
        format_scores: List[float] = []
        final_answer_used = False

        # The GAIA scorer expects a string answer.
        if not isinstance(agent_response, str):
            logger.warning(
                f"agent_response is not a string, it's a {type(agent_response)}: {agent_response}"
            )
            try:
                if isinstance(agent_response, (set, frozenset)):
                    # Sorted so the joined text (and thus the score) is reproducible;
                    # set iteration order of strings varies between interpreter runs.
                    agent_response = ", ".join(sorted(str(x) for x in agent_response))
                else:
                    agent_response = str(agent_response)
            except Exception as e:
                logger.error(f"Failed to convert agent_response to string: {e}")
                agent_response = ""

        # 1. Correctness
        is_correct = question_scorer(agent_response, true_answer)
        is_near_correct = check_close_call(agent_response, true_answer, is_correct)
        if is_correct:
            correctness_score = 1.0
        elif is_near_correct:
            correctness_score = 0.5

        # 2-5. Memory-based components
        if agent_memory:
            contents = [
                c for c in (self._step_content(s) for s in agent_memory) if c is not None
            ]
            format_scores = [check_format_adherence(c) for c in contents]
            format_score = (
                sum(format_scores) / len(format_scores) if format_scores else 0.0
            )
            final_answer_used = any(check_final_answer_usage(c) for c in contents)
            final_answer_score = 1.0 if final_answer_used else 0.0
            execution_score = calculate_execution_score(agent_memory)
            efficiency_score = calculate_efficiency_score(
                steps_count=len(agent_memory), max_steps=self.max_steps
            )

        combined_score = (
            correctness_score * correctness_weight
            + format_score * format_weight
            + final_answer_score * final_answer_weight
            + execution_score * execution_weight
            + efficiency_score * efficiency_weight
        )

        # 6. Length penalty (agent_response is a str at this point)
        length_penalty = 0.0
        if self.config.length_penalty_weight > 0:
            response_length = len(agent_response)
            if response_length > 2000:
                length_penalty = min(
                    0.3,
                    self.config.length_penalty_weight * (response_length - 2000) / 1000,
                )
            combined_score = max(0.0, combined_score - length_penalty)

        if self.debug_scoring:
            steps_count = len(agent_memory) if agent_memory else 0
            for line in (
                "=== SCORE CALCULATION (detailed) ===",
                "1. Correctness component:",
                f" - True answer: '{true_answer}'",
                f" - Agent answer: '{agent_response}'",
                f" - Is correct: {is_correct}",
                f" - Is near correct: {is_near_correct}",
                f" - Raw score: {correctness_score:.3f}",
                f" - Weight: {correctness_weight}",
                f" - Weighted score: {correctness_score * correctness_weight:.3f}",
                "2. Format adherence component:",
                f" - Format scores: {format_scores}",
                f" - Raw score: {format_score:.3f}",
                f" - Weight: {format_weight}",
                f" - Weighted score: {format_score * format_weight:.3f}",
                "3. Final answer tool usage:",
                f" - Final answer tool used: {final_answer_used}",
                f" - Raw score: {final_answer_score:.3f}",
                f" - Weight: {final_answer_weight}",
                f" - Weighted score: {final_answer_score * final_answer_weight:.3f}",
                "4. Execution component:",
                f" - Steps count: {steps_count}",
                f" - Raw score: {execution_score:.3f}",
                f" - Weight: {execution_weight}",
                f" - Weighted score: {execution_score * execution_weight:.3f}",
                "5. Efficiency component:",
                f" - Steps count: {steps_count}",
                f" - Max steps: {self.max_steps}",
                f" - Raw score: {efficiency_score:.3f}",
                f" - Weight: {efficiency_weight}",
                f" - Weighted score: {efficiency_score * efficiency_weight:.3f}",
                "6. Length penalty:",
                f" - Response length: {len(agent_response)}",
                f" - Penalty: {length_penalty:.3f}",
                "7. Final score calculation:",
                f" - Correctness: {correctness_score * correctness_weight:.3f}",
                f" - Format adherence: {format_score * format_weight:.3f}",
                f" - Final answer tool: {final_answer_score * final_answer_weight:.3f}",
                f" - Execution: {execution_score * execution_weight:.3f}",
                f" - Efficiency: {efficiency_score * efficiency_weight:.3f}",
                f" - Length penalty: -{length_penalty:.3f}",
                f" - FINAL SCORE: {combined_score:.3f}",
            ):
                logger.info(line)

        return combined_score

    # ------------------------------------------------------ trainer packaging

    def _create_scored_data_group(
        self, item: Item, scored_data: Dict
    ) -> ScoredDataGroup:
        """
        Convert one scored trajectory into a single-entry ``ScoredDataGroup``.

        With ``include_messages`` on, the full conversation (system prompt, task
        prompt, then the agent memory or just the final response) is serialized
        with ``json.dumps`` and tokenized with every token unmasked. ``messages``
        then holds a markdown rendering, and ``_original_messages`` holds the raw
        message dicts. Otherwise the system/user/assistant conversation is
        tokenized with ``tokenize_for_trainer`` and ``messages`` is None.

        Args:
            item: The task the agent ran on.
            scored_data: Needs ``response`` and ``score``. ``agent_memory`` is used
                when present and ``save_full_traces`` is on.
        """
        system_prompt = "You are an AI assistant solving a task with reasoning and problem-solving skills."

        if not self.config.include_messages:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": item.prompt},
                {"role": "assistant", "content": scored_data["response"]},
            ]
            tokenized = tokenize_for_trainer(
                self.tokenizer,
                messages,
                train_on_all_assistant_turns=True,  # Train on all assistant turns if present
            )
            return ScoredDataGroup(
                tokens=[tokenized["tokens"]],
                masks=[tokenized["masks"]],
                scores=[scored_data["score"]],
                messages=None,
            )

        messages = self._build_trace_messages(item, scored_data, system_prompt)
        markdown = self._render_trace_markdown(item, messages, scored_data["score"])
        # Tokenize the JSON-serialized conversation once; masks cover every token.
        token_ids = self.tokenizer.encode(json.dumps(messages))
        return ScoredDataGroup(
            tokens=[token_ids],
            masks=[[1] * len(token_ids)],
            scores=[scored_data["score"]],
            messages=[markdown],  # Single markdown document for HTML viewing
            _original_messages=[messages],  # Raw messages kept for the trainer API
        )

    def _build_trace_messages(
        self, item: Item, scored_data: Dict, system_prompt: str
    ) -> List[Dict[str, Any]]:
        """
        Return the system/user prefix followed by the agent's trace.

        If ``save_full_traces`` is on and ``scored_data`` has a non-None
        ``agent_memory``, each entry is appended. Objects with ``role`` and
        ``content`` become dicts (roles exposing ``.value`` are unwrapped),
        dicts are appended as-is, and anything else becomes an assistant
        message of its ``str()``. Otherwise only the final response is appended.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": item.prompt},
        ]
        memory = (
            scored_data.get("agent_memory") if self.config.save_full_traces else None
        )
        if memory is None:
            messages.append({"role": "assistant", "content": scored_data["response"]})
            return messages

        for entry in memory:
            if hasattr(entry, "role") and hasattr(entry, "content"):
                role = (
                    entry.role.value if hasattr(entry.role, "value") else str(entry.role)
                )
                messages.append({"role": role, "content": entry.content})
            elif isinstance(entry, dict):
                messages.append(entry)
            else:
                messages.append({"role": "assistant", "content": str(entry)})
        return messages

    def _render_trace_markdown(
        self, item: Item, messages: List[Dict[str, Any]], score: float
    ) -> str:
        """Render the conversation and its score as one markdown document for HTML viewing."""
        task_type = item.metadata.get("task", "Unknown task")
        task_id = item.metadata.get("task_id", "Unknown ID")
        sections = [f"# GAIA Task: {task_type} (ID: {task_id})"]

        for index, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            if not content:
                continue

            role_label = role.upper() if isinstance(role, str) else str(role).upper()
            sections.append(f"## {role_label}")

            if isinstance(content, list):
                # Multi-part content: keep the text of {"text": ...} parts, str() the rest.
                content = "\n".join(
                    part["text"] if isinstance(part, dict) and "text" in part else str(part)
                    for part in content
                )
            elif not isinstance(content, str):
                content = str(content)
            sections.append(content)

            # Assistant messages after the system/user prefix may carry reasoning and tool details.
            if role == "assistant" and index > 1 and self.config.save_full_traces:
                if "<thinking>" in content:
                    sections.append("### Thinking Process")
                if "tool_usage" in msg:
                    tool_usage = msg["tool_usage"] or {}
                    tool_name = tool_usage.get("name", "unknown tool")
                    sections.append(f"### 🛠️ Tool Used: {tool_name}")
                    tool_args = tool_usage.get("args", {})
                    if tool_args:
                        sections.extend(
                            ["```json", json.dumps(tool_args, indent=2), "```"]
                        )
                    tool_result = tool_usage.get("result", None)
                    if tool_result:
                        sections.extend(
                            ["### Tool Result", "```", str(tool_result), "```"]
                        )

        sections.append(f"\n## Score: {score:.4f}")
        return "\n\n".join(sections)

    # ------------------------------------------------------------- evaluation

    async def evaluate(self, **kwargs):
        """
        Evaluate the agent on a small, fixed slice of the GAIA dataset.

        Uses ``min(10, len(examples) // 10)`` examples starting at the middle of
        the dataset and runs one agent per example. A trajectory scoring above
        0.5 counts as a success. Logs success rate, average score and average
        execution time, and sends them to wandb when ``use_wandb`` is set. If
        the slice is empty, it only logs a warning.
        """
        logger.info("Starting evaluation on GAIA benchmark")

        # Start at the middle of the dataset, away from the training cursor's
        # starting point (index 0).
        eval_start = len(self.examples) // 2
        eval_count = min(10, len(self.examples) // 10)  # 10% of dataset, at most 10
        eval_examples = self.examples[eval_start : eval_start + eval_count]
        if not eval_examples:
            logger.warning("No examples available for evaluation, skipping")
            return

        eval_items = []
        for example in eval_examples:
            prompt = example["question"]
            # Same file hint get_next_item adds, so eval prompts match training prompts.
            if example.get("file_name"):
                prompt += f"\n\nTo solve this task, you can use the file at: {example['file_name']}"
            eval_items.append(
                Item(
                    prompt=prompt,
                    metadata={
                        "task_id": example["task_id"],
                        "task": example["task"],
                        "true_answer": example["true_answer"],
                        "file_name": example.get("file_name", ""),
                    },
                )
            )

        scored_groups, _ = await self.collect_trajectories(eval_items)

        results = []
        for scored_group in scored_groups:
            # ScoredDataGroup is used as a plain dict throughout this file. If it is
            # a TypedDict, as that usage suggests, isinstance(x, ScoredDataGroup)
            # raises TypeError, so check for dict instead.
            if not isinstance(scored_group, dict):
                continue
            scores = scored_group.get("scores") or []
            overrides = scored_group.get("group_overrides") or {}
            results.append(
                {
                    "task_id": overrides.get("task_id") or "unknown",
                    "score": scores[0] if scores else 0,
                }
            )

        if not results:
            return

        # NOTE(review): with the rubric weights in _score_trajectory, a near-miss
        # answer plus strong non-correctness components can exceed 0.5, so this is
        # not strictly answer accuracy.
        correct_count = sum(1 for r in results if r["score"] > 0.5)
        success_rate = correct_count / len(results)
        avg_score = sum(r["score"] for r in results) / len(results)
        times = self.agent_execution_times
        avg_time = sum(times) / len(times) if times else 0

        logger.info(f"Evaluation complete on {len(results)} examples:")
        logger.info(f" Success rate: {success_rate:.2f}")
        logger.info(f" Average score: {avg_score:.2f}")
        logger.info(f" Average execution time: {avg_time:.2f}s")

        if self.config.use_wandb:
            metrics = {
                "eval/success_rate": success_rate,
                "eval/avg_score": avg_score,
                "eval/num_examples": len(results),
                "eval/avg_execution_time": avg_time,
            }
            await self.wandb_log(metrics)

    # ------------------------------------------------------ state & logging

    def save_checkpoint(self, step, data=None):
        """Save the dataset cursor (``iter``, ``current_index``) together with the base checkpoint."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        data["current_index"] = self.current_index
        super().save_checkpoint(step, data)

    def load_checkpoint(self):
        """Restore the base checkpoint, then the dataset cursor if it was saved.

        Reads ``self.checkpoint_data``. Presumably ``BaseEnv.load_checkpoint``
        sets it; this is not visible here, so its absence is tolerated.
        """
        super().load_checkpoint()
        if hasattr(self, "checkpoint_data"):
            if "iter" in self.checkpoint_data:
                self.iter = self.checkpoint_data["iter"]
            if "current_index" in self.checkpoint_data:
                self.current_index = self.checkpoint_data["current_index"]

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """
        Add this environment's buffered metrics and forward everything to ``BaseEnv.wandb_log``.

        Adds ``train/percent_correct`` (mean of ``percent_correct_buffer``,
        omitted when empty), execution-time stats (when any were recorded), the
        dataset cursor, and the ``(name, value)`` pairs in ``eval_metrics``, then
        clears those buffers. A passed-in ``wandb_metrics`` dict is updated in place.
        """
        if wandb_metrics is None:
            wandb_metrics = dict()

        # NOTE(review): nothing in this file appends to percent_correct_buffer.
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)

        if self.agent_execution_times:
            times = self.agent_execution_times
            wandb_metrics["agent/avg_execution_time"] = sum(times) / len(times)
            wandb_metrics["agent/max_execution_time"] = max(times)
            wandb_metrics["agent/min_execution_time"] = min(times)
            self.agent_execution_times = []

        wandb_metrics["train/dataset_iterations"] = self.iter
        wandb_metrics["train/current_dataset_index"] = self.current_index

        for entry in self.eval_metrics:
            wandb_metrics[entry[0]] = entry[1]

        self.percent_correct_buffer = []
        self.eval_metrics = []

        await super().wandb_log(wandb_metrics)

    async def cleanup(self):
        """Stop the server proxy manager (if setup() created one), then run BaseEnv cleanup.

        ``BaseEnv.cleanup`` runs even if stopping the proxy manager raises.
        """
        logger.info("Cleaning up SmolagentsEnv resources")
        try:
            if self.server_proxy_manager:
                self.server_proxy_manager.stop()
                logger.info("Stopped server proxy manager")
        finally:
            # Drop the reference so a repeated cleanup() does not stop it twice.
            self.server_proxy_manager = None
            await super().cleanup()
if __name__ == "__main__":
    # Command-line entry point. `cli` is not defined in this file's class body,
    # so it is inherited (presumably from BaseEnv).
    SmolagentsEnv.cli()