"""
|
|
SmolagentsEnv - Environment for creating high-quality agent trajectories
|
|
for training language models using the SmolaGents agent framework.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import multiprocessing
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
from pydantic import Field
|
|
|
|
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
|
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
from .agent_process_runner import run_agent_process
|
|
from .server_proxy import ServerProxyManager
|
|
|
|
|
|


@dataclass
class Item:
    prompt: str
    metadata: Dict[str, Any]
    id: Optional[str] = None
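
# A minimal illustration of how tasks are represented (the values below are
# purely hypothetical; real Items are built from GAIA metadata in
# ``SmolagentsEnv.get_next_item``):
#
#   Item(
#       prompt="What is the capital of France?",
#       metadata={"task_id": "task_0", "task": "1", "true_answer": "Paris", "file_name": ""},
#   )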

# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Prevent propagation to root logger to avoid duplicate logging
logger.propagate = False

# Only add handler if not already present
if not logger.handlers:
    # Add a console handler to make sure logs are visible
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
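
# With the formatter above, emitted records look like, for example:
#   INFO - Setting up SmolagentsEnv...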


class SmolagentsEnvConfig(BaseEnvConfig):
    """Configuration for SmolagentsEnv."""

    dataset_path: str = Field(default="data/gaia", description="Path to GAIA dataset")
    split: str = Field(
        default="validation", description="Dataset split to use (validation, test)"
    )
    use_chat_completion: bool = Field(
        default=True, description="Use chat completion API"
    )
    max_steps: int = Field(default=12, description="Maximum number of agent steps")
    agent_verbosity: int = Field(default=2, description="Agent verbosity level (0-3)")
    scoring_strategy: str = Field(
        default="combined",
        description="Scoring strategy: basic, correctness, or combined",
    )
    length_penalty_weight: float = Field(
        default=0.1, description="Weight for length penalty in scoring (0.0 to disable)"
    )
    save_full_traces: bool = Field(
        default=True, description="Save full agent execution traces in the output"
    )
    # Output path configured in __init__
    data_path_to_save_groups: Optional[str] = Field(
        default=None,
        description="Path to save JSONL output (defaults to timestamped file if None)",
    )
    # Process-based settings
    max_concurrent_processes: int = Field(
        default=5,
        description="Maximum number of concurrent processes for agent execution",
    )
    process_timeout: int = Field(
        default=240,  # 4 minutes by default
        description="Timeout for agent processes in seconds",
    )
    # Debugging options
    debug_scoring: bool = Field(
        default=False, description="Enable detailed score calculation logging"
    )


class SmolagentsEnv(BaseEnv):
    """
    Environment for generating high-quality agent trajectories using the SmolaGents framework.

    This environment:
    1. Loads tasks from the GAIA benchmark dataset
    2. Uses SmolaGents CodeAgent with appropriate tools
    3. Scores trajectories based on correctness and reasoning quality
    4. Integrates with the Atropos SFT generation pipeline
    """

    name = "smolagents"
    env_config_cls = SmolagentsEnvConfig

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Initialize the config for CLI use."""
        env_config = SmolagentsEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=32,
            steps_per_eval=100,
            max_token_length=4096,
            wandb_name="smolagents",
            include_messages=True,
            # Process-based settings
            max_concurrent_processes=8,
            process_timeout=240,
            # Common settings
            dataset_path="data/gaia",
            split="validation",  # GAIA only supports 'validation' and 'test' splits
            use_chat_completion=True,
            # Debugging options
            debug_scoring=False,  # Set to True to enable detailed score logging
            # Using default timestamped output path from the config definition
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=32,
            ),
        ]
        return env_config, server_configs

    def __init__(
        self,
        config: SmolagentsEnvConfig,
        server_configs: Union[List[APIServerConfig], APIServerConfig],
        slurm=False,
        testing=False,
    ):
        # Set a timestamped output file path if not provided
        if config.data_path_to_save_groups is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            config.data_path_to_save_groups = f"smolagents_output_{timestamp}.jsonl"
            print(
                f"Using auto-generated output path: {config.data_path_to_save_groups}"
            )

        # Initialize the base class
        super().__init__(config, server_configs, slurm, testing)

        # Initialize dataset variables
        self.examples = []
        self.current_index = 0
        self.iter = 0  # Iteration counter, also used for checkpoint tracking

        # Initialize the server proxy manager for process-based execution
        self.server_proxy_manager = None  # Will be initialized in setup()

        # Save config values for easier access
        self.max_steps = config.max_steps
        self.verbosity = config.agent_verbosity
        self.scoring_strategy = config.scoring_strategy
        self.debug_scoring = config.debug_scoring

        # Track agent execution times and metrics
        self.agent_execution_times = []
        self.percent_correct_buffer = []
        self.eval_metrics = []

    async def setup(self):
        """Set up the environment, load dataset, and initialize server components."""
        logger.info("Setting up SmolagentsEnv...")
        logger.info(f"Using dataset split: {self.config.split}")

        # Initialize the server proxy manager
        logger.info("Setting up process-based isolation for agent execution")
        self.server_proxy_manager = ServerProxyManager(
            server=self.server, max_workers=self.config.max_concurrent_processes
        )
        self.server_proxy_manager.start()
        logger.info(
            f"Started server proxy manager with max_workers={self.config.max_concurrent_processes}"
        )

        # Load the GAIA dataset
        try:
            import os

            # Check if dataset exists
            dataset_path = self.config.dataset_path
            validation_path = os.path.join(
                dataset_path, "2023", "validation", "metadata.jsonl"
            )
            gaia_py_path = os.path.join(dataset_path, "GAIA.py")

            # If dataset files are missing, try to download them
            if not os.path.exists(validation_path) or not os.path.exists(gaia_py_path):
                logger.info(
                    f"GAIA dataset not found at {dataset_path}, attempting to download..."
                )
                from .download_gaia import download_gaia_dataset

                download_success = download_gaia_dataset(dataset_path)
                if not download_success:
                    logger.error("Failed to download GAIA dataset automatically.")
                    logger.error(
                        "Please run: python -m environments.smolagents_integration.download_gaia"
                    )
                    self.examples = []
                    return
                else:
                    logger.info(
                        f"GAIA dataset downloaded successfully to {dataset_path}"
                    )

            logger.info(
                f"Loading GAIA dataset directly from {self.config.dataset_path}"
            )

            # Load the metadata.jsonl file directly instead of using the datasets library
            metadata_path = os.path.join(
                self.config.dataset_path, "2023", self.config.split, "metadata.jsonl"
            )

            logger.info(f"Reading metadata from: {metadata_path}")

            # Check if the file exists
            if not os.path.exists(metadata_path):
                logger.error(f"Metadata file not found: {metadata_path}")
                self.examples = []
                return

            # Read the metadata file directly
            self.examples = []
            with open(metadata_path, "r") as f:
                for i, line in enumerate(f):
                    try:
                        example = json.loads(line)
                        self.examples.append(
                            {
                                "question": example["Question"],
                                "true_answer": example["Final answer"],
                                "task": example["Level"],
                                "task_id": (
                                    example["task_id"]
                                    if "task_id" in example
                                    else f"task_{i}"
                                ),
                                "file_name": (
                                    os.path.join(
                                        self.config.dataset_path,
                                        "2023",
                                        self.config.split,
                                        example["file_name"],
                                    )
                                    if example.get("file_name")
                                    else ""
                                ),
                            }
                        )
                    except Exception as parse_error:
                        logger.error(
                            f"Error parsing line {i} of metadata file: {parse_error}"
                        )
                        continue

            logger.info(
                f"Loaded {len(self.examples)} examples from GAIA {self.config.split} set"
            )
        except Exception as e:
            import traceback

            logger.error(f"Error loading GAIA dataset: {type(e).__name__}: {e}")
            logger.error(f"Detailed traceback: {traceback.format_exc()}")
            logger.error(
                "Please run: python -m environments.smolagents_integration.download_gaia"
            )
            # Create empty list if dataset loading fails
            self.examples = []

        logger.info("SmolagentsEnv setup complete")
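
    # For reference, each line of the GAIA ``metadata.jsonl`` parsed above is
    # expected to carry at least the keys used in ``setup`` (the values shown
    # here are purely illustrative):
    #
    #   {"task_id": "task_0", "Question": "...", "Level": "1",
    #    "Final answer": "...", "file_name": ""}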

    async def get_next_item(self) -> Optional[Item]:
        """Get the next item from the GAIA dataset."""
        if not self.examples:
            logger.warning("No examples loaded in dataset")
            return None

        # Use iter to track position and support checkpointing
        example = self.examples[self.iter % len(self.examples)]
        self.iter += 1
        self.current_index = self.iter % len(self.examples)

        # Construct the prompt
        prompt = example["question"]

        # Add file information if available
        if example.get("file_name"):
            prompt += f"\n\nTo solve this task, you can use the file at: {example['file_name']}"

        # Create an Item object
        item = Item(
            prompt=prompt,
            metadata={
                "task_id": example["task_id"],
                "task": example["task"],
                "true_answer": example["true_answer"],
                "file_name": example.get("file_name", ""),
                "dataset_idx": self.current_index,
            },
        )

        return item

    async def collect_trajectories(self, items: Union[Item, List[Item]]) -> Tuple[
        Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any]],
        List[Item],
    ]:
        """
        Collect trajectories for multiple items using process-based parallelism.
        """
        # Handle both a single item and a list of items: a single item is fanned
        # out to group_size copies so each group member gets its own rollout
        if not isinstance(items, list):
            items = [items] * self.config.group_size

        return await self._collect_trajectories_process_based(items)

    async def _collect_trajectories_process_based(
        self, items: List[Item]
    ) -> Tuple[List[Any], List[Item]]:
        """
        Collect trajectories for multiple items using process-based parallelism.
        """
        logger.info(
            f"Collecting trajectories for {len(items)} items using process-based parallelism"
        )

        # Create a manager for shared objects
        manager = multiprocessing.Manager()
        result_queue = manager.Queue()

        # Create agent config dictionary
        agent_config = {
            "max_steps": self.max_steps,
            "verbosity": self.verbosity,
            "use_chat_completion": self.config.use_chat_completion,
            "model_name": getattr(self.server, "model_name", "unknown-model"),
        }

        # Start processes for each item
        processes = []
        proxies = []

        for item in items:
            # Create a server proxy for this process
            server_proxy, proxy_id = self.server_proxy_manager.create_server_proxy(
                model_name=agent_config["model_name"],
                timeout=self.config.process_timeout,
            )
            proxies.append(proxy_id)

            # Start a process for this item
            process = multiprocessing.Process(
                target=run_agent_process,
                args=(
                    item.prompt,
                    item.metadata,
                    server_proxy,
                    agent_config,
                    result_queue,
                ),
            )
            process.start()
            processes.append(process)
logger.info(f"Started {len(processes)} agent processes")
|
|
|
|
# Wait for all processes to complete or timeout
|
|
for process in processes:
|
|
process.join(timeout=self.config.process_timeout)
|
|
|
|
# Check if process is still alive (timeout)
|
|
if process.is_alive():
|
|
logger.warning(f"Process {process.pid} timed out, terminating")
|
|
process.terminate()
|
|
process.join()
|
|
|
|
# Clean up proxies
|
|
for proxy_id in proxies:
|
|
self.server_proxy_manager.remove_proxy(proxy_id)
|
|
|
|
# Get all results from the queue
|
|
results = []
|
|
while not result_queue.empty():
|
|
try:
|
|
result = result_queue.get(block=False)
|
|
results.append(result)
|
|
except Exception as e:
|
|
logger.error(f"Error getting result from queue: {e}")
|
|
break
|
|
|
|
logger.info(f"Collected {len(results)} results from processes")
|
|
|
|
# Process results
|
|
backlog = []
|
|
to_postprocess = []
|
|
|
|

        for result in results:
            if result["status"] == "success":
                # Create scored data from successful result
                scored_data = {
                    "prompt": result["task_metadata"].get("prompt", ""),
                    "response": result["response"],
                    "task_id": result["task_id"],
                    "task": result["task_metadata"].get("task", ""),
                    "true_answer": result["task_metadata"].get("true_answer", ""),
                    "execution_time": result["execution_time"],
                }

                # Add agent memory if configured
                if self.config.save_full_traces and "agent_memory" in result:
                    scored_data["agent_memory"] = result["agent_memory"]

                # Score the trajectory
                score = self._score_trajectory(
                    scored_data["prompt"],
                    scored_data["response"],
                    scored_data["true_answer"],
                    scored_data.get("agent_memory"),
                    scored_data["execution_time"],
                )

                scored_data["score"] = score

                # Create ScoredDataGroup
                item_for_scoring = next(
                    (
                        i
                        for i in items
                        if i.metadata.get("task_id") == result["task_id"]
                    ),
                    None,
                )
                if item_for_scoring:
                    scored_group = self._create_scored_data_group(
                        item_for_scoring, scored_data
                    )
                    to_postprocess.append(scored_group)
                else:
                    logger.warning(
                        f"Could not find original item for task_id {result['task_id']}"
                    )
            else:
                # Just log the error and omit this example from training
                error_message = result.get("error_message", "Unknown error")
                task_id = result.get("task_id", "Unknown task")

                logger.warning(
                    f"Omitting failed task {task_id} from training batch: {error_message}"
                )

        # Return processed results
        logger.info(f"Final to_postprocess: len={len(to_postprocess)}")
        return to_postprocess, backlog

    async def postprocess_histories(
        self, histories: Union[ScoredDataGroup, List[ScoredDataGroup]]
    ) -> ScoredDataGroup:
        """
        Post-process the agent histories.

        We need to merge multiple ScoredDataGroups into a single ScoredDataGroup.
        """
        logger.info(
            f"postprocess_histories called with: type={type(histories)}, is_none={histories is None}"
        )
        if isinstance(histories, list):
            logger.info(f" List length: {len(histories)}")

        if not isinstance(histories, list):
            # If it's already a single ScoredDataGroup, return it with group_overrides
            logger.info(f" Single history, returning directly: {type(histories)}")
            if (
                "group_overrides" not in histories
                or histories["group_overrides"] is None
            ):
                histories["group_overrides"] = {}
            return histories

        # If we have multiple ScoredDataGroups, merge them
        logger.info(f" Merging {len(histories)} histories")
        merged = ScoredDataGroup(
            tokens=[],
            masks=[],
            scores=[],
            advantages=None,
            ref_logprobs=None,
            messages=[] if self.config.include_messages else None,
            group_overrides={},
            overrides=None,
        )

        # Merge all the fields
        for i, history in enumerate(histories):
            logger.info(
                f" Processing history {i}: type={type(history)}, is_none={history is None}"
            )
            if history is not None:
                logger.info(f" History {i} tokens: {len(history['tokens'])}")
                merged["tokens"].extend(history["tokens"])
                merged["masks"].extend(history["masks"])
                merged["scores"].extend(history["scores"])

                if merged["messages"] is not None and "messages" in history:
                    logger.info(f" History {i} messages: {len(history['messages'])}")
                    merged["messages"].extend(history["messages"])

        logger.info(
            f" Final merged data: tokens={len(merged['tokens'])}, scores={len(merged['scores'])}"
        )
        return merged

    def _score_trajectory(
        self,
        prompt: str,
        agent_response: str,
        true_answer: str,
        agent_memory: Optional[List[Dict]] = None,
        execution_time: float = 0,  # Parameter kept for backward compatibility but not used in scoring
    ) -> float:
        """
        Score the agent trajectory based on multiple criteria:
        - Answer correctness using GAIA scoring
        - Message format adherence
        - Final answer tool usage
        - Execution success (detection of errors)
        - Efficiency (steps only)

        Args:
            prompt: The original task prompt
            agent_response: The final response from the agent
            true_answer: The ground truth answer
            agent_memory: The memory trace of the agent's steps
            execution_time: Time taken for execution (not used in scoring)

        Returns:
            float: A score between 0.0 and 1.0
        """
        # Import all scoring functions upfront
        from .evaluations.smolagent_integrations.rubrics.gaia_scorer import (
            check_close_call,
            question_scorer,
        )
        from .evaluations.smolagent_integrations.smolagents_scorer import (
            calculate_efficiency_score,
            calculate_execution_score,
            check_final_answer_usage,
            check_format_adherence,
        )

        # Initialize component scores
        correctness_score = 0.0
        format_score = 0.0
        final_answer_score = 0.0
        execution_score = 0.0
        efficiency_score = 0.0

        # 1. Calculate correctness score using GAIA scorer
        # Ensure agent_response is a string before scoring
        if not isinstance(agent_response, str):
            logger.warning(
                f"agent_response is not a string, it's a {type(agent_response)}: {agent_response}"
            )
            try:
                if isinstance(agent_response, set):
                    # Convert sets to comma-separated strings
                    agent_response = ", ".join(str(item) for item in agent_response)
                else:
                    # Try to convert other types to string
                    agent_response = str(agent_response)
            except Exception as e:
                logger.error(f"Failed to convert agent_response to string: {e}")
                agent_response = ""

        is_correct = question_scorer(agent_response, true_answer)
        is_near_correct = check_close_call(agent_response, true_answer, is_correct)

        if is_correct:
            correctness_score = 1.0
        elif is_near_correct:
            correctness_score = 0.5

        # 2. Check format adherence (the memory-based components below only
        # apply when an agent memory trace is available)
        if agent_memory:
            format_scores = []
            for step in agent_memory:
                # Handle both dict and ChatMessage objects
                content = None
                if hasattr(step, "content"):
                    content = step.content
                elif isinstance(step, dict):
                    content = step.get("content") or step.get("model_output")

                if content and isinstance(content, str):
                    format_scores.append(check_format_adherence(content))

            # Average the format scores across all steps
            format_score = (
                sum(format_scores) / len(format_scores) if format_scores else 0.0
            )

            # 3. Check for final_answer tool usage
            final_answer_used = False
            for step in agent_memory:
                # Handle both dict and ChatMessage objects
                content = None
                if hasattr(step, "content"):
                    content = step.content
                elif isinstance(step, dict):
                    content = step.get("content") or step.get("model_output")

                if content and isinstance(content, str):
                    if check_final_answer_usage(content):
                        final_answer_used = True
                        break

            final_answer_score = 1.0 if final_answer_used else 0.0

            # 4. Check for execution errors and calculate execution score
            execution_score = calculate_execution_score(agent_memory)

            # 5. Calculate efficiency score
            steps_count = len(agent_memory)
            efficiency_score = calculate_efficiency_score(
                steps_count=steps_count, max_steps=self.max_steps
            )

        # Component weights - can be adjusted to emphasize different aspects
        correctness_weight = 0.50  # 50% of score - correctness matters most
        format_weight = 0.20  # 20% - format adherence is important
        final_answer_weight = 0.10  # 10% - using final_answer tool properly
        execution_weight = 0.10  # 10% - avoiding errors
        efficiency_weight = 0.10  # 10% - being efficient

        # Calculate combined score with weights
        combined_score = (
            correctness_score * correctness_weight
            + format_score * format_weight
            + final_answer_score * final_answer_weight
            + execution_score * execution_weight
            + efficiency_score * efficiency_weight
        )

        # Apply length penalty if configured
        length_penalty = 0.0
        if self.config.length_penalty_weight > 0:
            # agent_response should already be a string from the previous conversion
            # but double-check just to be sure
            response_to_measure = (
                agent_response
                if isinstance(agent_response, str)
                else str(agent_response)
            )
            response_length = len(response_to_measure)
            # Penalize very long responses
            if response_length > 2000:
                length_penalty = min(
                    0.3,
                    self.config.length_penalty_weight * (response_length - 2000) / 1000,
                )
        combined_score = max(0.0, combined_score - length_penalty)

        # Debug logging for score calculation
        if self.debug_scoring:
            logger.info("=== SCORE CALCULATION (detailed) ===")
            logger.info("1. Correctness component:")
            logger.info(f" - True answer: '{true_answer}'")
            logger.info(f" - Agent answer: '{agent_response}'")
            logger.info(f" - Is correct: {is_correct}")
            logger.info(f" - Is near correct: {is_near_correct}")
            logger.info(f" - Raw score: {correctness_score:.3f}")
            logger.info(f" - Weight: {correctness_weight}")
            logger.info(
                f" - Weighted score: {correctness_score * correctness_weight:.3f}"
            )

            logger.info("2. Format adherence component:")
            logger.info(
                f" - Format scores: {format_scores if 'format_scores' in locals() else []}"
            )
            logger.info(f" - Raw score: {format_score:.3f}")
            logger.info(f" - Weight: {format_weight}")
            logger.info(f" - Weighted score: {format_score * format_weight:.3f}")

            logger.info("3. Final answer tool usage:")
            logger.info(
                f" - Final answer tool used: {final_answer_used if 'final_answer_used' in locals() else False}"
            )
            logger.info(f" - Raw score: {final_answer_score:.3f}")
            logger.info(f" - Weight: {final_answer_weight}")
            logger.info(
                f" - Weighted score: {final_answer_score * final_answer_weight:.3f}"
            )

            logger.info("4. Execution component:")
            logger.info(f" - Steps count: {len(agent_memory) if agent_memory else 0}")
            logger.info(f" - Raw score: {execution_score:.3f}")
            logger.info(f" - Weight: {execution_weight}")
            logger.info(
                f" - Weighted score: {execution_score * execution_weight:.3f}"
            )

            logger.info("5. Efficiency component:")
            logger.info(f" - Steps count: {len(agent_memory) if agent_memory else 0}")
            logger.info(f" - Max steps: {self.max_steps}")
            logger.info(f" - Raw score: {efficiency_score:.3f}")
            logger.info(f" - Weight: {efficiency_weight}")
            logger.info(
                f" - Weighted score: {efficiency_score * efficiency_weight:.3f}"
            )

            logger.info("6. Length penalty:")
            logger.info(
                f" - Response length: {len(agent_response) if isinstance(agent_response, str) else 'N/A'}"
            )
            logger.info(f" - Penalty: {length_penalty:.3f}")

            logger.info("7. Final score calculation:")
            logger.info(
                f" - Correctness: {correctness_score * correctness_weight:.3f}"
            )
            logger.info(f" - Format adherence: {format_score * format_weight:.3f}")
            logger.info(
                f" - Final answer tool: {final_answer_score * final_answer_weight:.3f}"
            )
            logger.info(f" - Execution: {execution_score * execution_weight:.3f}")
            logger.info(f" - Efficiency: {efficiency_score * efficiency_weight:.3f}")
            logger.info(f" - Length penalty: -{length_penalty:.3f}")
            logger.info(f" - FINAL SCORE: {combined_score:.3f}")

        return combined_score
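
    # A worked example of the weighted combination above (hypothetical component
    # values, using the weights hard-coded in _score_trajectory):
    #
    #   correctness=1.0, format=0.8, final_answer=1.0, execution=1.0, efficiency=0.5
    #   score = 1.0*0.50 + 0.8*0.20 + 1.0*0.10 + 1.0*0.10 + 0.5*0.10 = 0.91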

    def _create_scored_data_group(
        self, item: Item, scored_data: Dict
    ) -> ScoredDataGroup:
        """
        Create a ScoredDataGroup for the trainer API.

        Converts the agent trajectory into tokenized format for the trainer.
        """
        # Prepare the data in message format or token format
        if self.config.include_messages:
            # Create message format with agent memory if available
            messages = []

            # Add system message with task description
            messages.append(
                {
                    "role": "system",
                    "content": "You are an AI assistant solving a task with reasoning and problem-solving skills.",
                }
            )

            # Add user message with the prompt
            messages.append({"role": "user", "content": item.prompt})

            # For message format, extract agent memory if available
            if self.config.save_full_traces and "agent_memory" in scored_data:
                # Add intermediate reasoning steps
                for message in scored_data["agent_memory"]:
                    # Handle both dict and ChatMessage objects
                    if hasattr(message, "role") and hasattr(message, "content"):
                        # Convert ChatMessage to dict
                        role = (
                            message.role.value
                            if hasattr(message.role, "value")
                            else str(message.role)
                        )
                        messages.append({"role": role, "content": message.content})
                    elif isinstance(message, dict):
                        messages.append(message)
                    else:
                        # Unknown format, try to convert to dict
                        messages.append({"role": "assistant", "content": str(message)})
            else:
                # Just add the final response
                messages.append(
                    {"role": "assistant", "content": scored_data["response"]}
                )

            # Create a comprehensive markdown document for HTML compatibility
            # Format similar to the Wikipedia environment
            complete_conversation = []

            # Add task information at the top
            task_type = item.metadata.get("task", "Unknown task")
            task_id = item.metadata.get("task_id", "Unknown ID")
            complete_conversation.append(f"# GAIA Task: {task_type} (ID: {task_id})")

            # Process each message in the conversation
            for i, msg in enumerate(messages):
                role = msg.get("role", "unknown")
                content = msg.get("content", "")

                # Skip empty messages
                if not content:
                    continue

                # Add a header for each role
                complete_conversation.append(f"## {role.upper()}")

                # Handle different content formats
                if isinstance(content, list):
                    content_text = []
                    # Use a distinct loop variable so the outer ``item`` argument is not shadowed
                    for part in content:
                        if isinstance(part, dict) and "text" in part:
                            content_text.append(part["text"])
                        else:
                            content_text.append(str(part))
                    content = "\n".join(content_text)
                elif not isinstance(content, str):
                    # Handle non-string content
                    content = str(content)

                # Add the message content
                complete_conversation.append(content)

                # If this message has agent memory, show it
                if role == "assistant" and i > 1 and self.config.save_full_traces:
                    # Agent memory with thinking is often present in assistant messages
                    if "<thinking>" in content:
                        complete_conversation.append("### Thinking Process")
                        # The thinking is already in the content

                # Add tool usage information if present
                if "tool_usage" in msg:
                    tool_name = msg.get("tool_usage", {}).get(
                        "name", "unknown tool"
                    )
                    complete_conversation.append(f"### 🛠️ Tool Used: {tool_name}")
                    tool_args = msg.get("tool_usage", {}).get("args", {})
                    if tool_args:
                        complete_conversation.append("```json")
                        complete_conversation.append(
                            json.dumps(tool_args, indent=2)
                        )
                        complete_conversation.append("```")

                    tool_result = msg.get("tool_usage", {}).get("result", None)
                    if tool_result:
                        complete_conversation.append("### Tool Result")
                        complete_conversation.append("```")
                        complete_conversation.append(str(tool_result))
                        complete_conversation.append("```")

            # Add score information at the end
            complete_conversation.append(f"\n## Score: {scored_data['score']:.4f}")

            # Join everything into a single string with double newlines between sections
            full_conversation_markdown = "\n\n".join(complete_conversation)

            # Create the ScoredDataGroup with a single comprehensive markdown document
            scored_group = ScoredDataGroup(
                tokens=[self.tokenizer.encode(json.dumps(messages))],
                masks=[[1] * len(self.tokenizer.encode(json.dumps(messages)))],
                scores=[scored_data["score"]],
                messages=[
                    full_conversation_markdown
                ],  # Use single markdown document for HTML
                _original_messages=[messages],  # Keep original for trainer API
            )

        else:
            # Create a proper conversation with role-based messages
            messages = [
                {
                    "role": "system",
                    "content": "You are an AI assistant solving a task with reasoning and problem-solving skills.",
                },
                {"role": "user", "content": item.prompt},
                {"role": "assistant", "content": scored_data["response"]},
            ]

            # Tokenize using the standard tokenize_for_trainer utility
            # (only trains on assistant messages)
            tokenized = tokenize_for_trainer(
                self.tokenizer,
                messages,
                train_on_all_assistant_turns=True,  # Train on all assistant turns if present
            )

            # Create the ScoredDataGroup
            scored_group = ScoredDataGroup(
                tokens=[tokenized["tokens"]],
                masks=[tokenized["masks"]],
                scores=[scored_data["score"]],
                messages=None,
            )

        return scored_group

    async def evaluate(self, **kwargs):
        """
        Evaluate the agent on a subset of the GAIA benchmark.

        Provides metrics on:
        - Success rate
        - Average score
        - Execution time
        - Step efficiency
        """
        logger.info("Starting evaluation on GAIA benchmark")

        # Use a fixed subset of examples for evaluation
        # Start from a different point than training to avoid overlap
        eval_start = len(self.examples) // 2
        eval_count = min(
            10, len(self.examples) // 10
        )  # at most 10 examples, and no more than 10% of the dataset

        eval_examples = self.examples[eval_start : eval_start + eval_count]

        results = []
        correct_count = 0

        # Create items for evaluation
        eval_items = []
        for example in eval_examples:
            # Create an Item for this example
            item = Item(
                prompt=example["question"],
                metadata={
                    "task_id": example["task_id"],
                    "task": example["task"],
                    "true_answer": example["true_answer"],
                    "file_name": example.get("file_name", ""),
                },
            )
            eval_items.append(item)

        # Use the existing process-based trajectory collection
        scored_groups, _ = await self.collect_trajectories(eval_items)

        # Process the scored groups
        for scored_group in scored_groups:
            if isinstance(scored_group, ScoredDataGroup):
                score = (
                    scored_group["scores"][0]
                    if "scores" in scored_group and scored_group["scores"]
                    else 0
                )

                # Try to extract the task_id from metadata
                task_id = None
                if (
                    "group_overrides" in scored_group
                    and scored_group["group_overrides"]
                ):
                    task_id = scored_group["group_overrides"].get("task_id")

                results.append(
                    {
                        "task_id": task_id or "unknown",
                        "score": score,
                    }
                )

                if score > 0.5:  # Consider it correct if score > 0.5
                    correct_count += 1

        # Since we're using the process-based approach, the execution time
        # is stored in the server metrics which are already tracked

        # Calculate metrics
        if results:
            success_rate = correct_count / len(results)
            avg_score = sum(r["score"] for r in results) / len(results)

            # Calculate average time from agent_execution_times if available
            avg_time = 0
            if self.agent_execution_times:
                avg_time = sum(self.agent_execution_times) / len(
                    self.agent_execution_times
                )

            logger.info(f"Evaluation complete on {len(results)} examples:")
            logger.info(f" Success rate: {success_rate:.2f}")
            logger.info(f" Average score: {avg_score:.2f}")
            logger.info(f" Average execution time: {avg_time:.2f}s")

            # Update wandb metrics
            if self.config.use_wandb:
                metrics = {
                    "eval/success_rate": success_rate,
                    "eval/avg_score": avg_score,
                    "eval/num_examples": len(results),
                    "eval/avg_execution_time": avg_time,
                }

                await self.wandb_log(metrics)

    def save_checkpoint(self, step, data=None):
        """Save environment state for checkpointing."""
        if data is None:
            data = {}
        # Save the iteration counter
        data["iter"] = self.iter
        # Save the current index in the dataset
        data["current_index"] = self.current_index
        # Call the parent class save_checkpoint
        super().save_checkpoint(step, data)

    def load_checkpoint(self):
        """Load environment state from checkpoint."""
        # Call parent method first
        super().load_checkpoint()
        # Check if we loaded iter and current_index
        if hasattr(self, "checkpoint_data"):
            if "iter" in self.checkpoint_data:
                self.iter = self.checkpoint_data["iter"]
            if "current_index" in self.checkpoint_data:
                self.current_index = self.checkpoint_data["current_index"]

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """
        Log to wandb with comprehensive metrics.
        """
        if wandb_metrics is None:
            wandb_metrics = dict()

        # Try to calculate percent_correct, skip if there's a division by zero
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            # Skip if buffer is empty
            pass

        # Log agent performance metrics
        if self.agent_execution_times and len(self.agent_execution_times) > 0:
            wandb_metrics["agent/avg_execution_time"] = sum(
                self.agent_execution_times
            ) / len(self.agent_execution_times)
            wandb_metrics["agent/max_execution_time"] = max(self.agent_execution_times)
            wandb_metrics["agent/min_execution_time"] = min(self.agent_execution_times)
            # Reset the buffer
            self.agent_execution_times = []

        # Add dataset iteration tracking
        wandb_metrics["train/dataset_iterations"] = self.iter
        wandb_metrics["train/current_dataset_index"] = self.current_index

        # Add custom evaluation metrics
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]

        # Clear buffers after logging
        self.percent_correct_buffer = []
        self.eval_metrics = []

        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def cleanup(self):
        """Clean up resources when the environment is closed."""
        logger.info("Cleaning up SmolagentsEnv resources")

        # Clean up the server proxy manager
        if self.server_proxy_manager:
            self.server_proxy_manager.stop()
            logger.info("Stopped server proxy manager")

        # Let the parent class do its cleanup
        await super().cleanup()
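

# A minimal programmatic sketch of how this environment can be wired up outside
# the CLI (hypothetical; the values mirror the defaults in ``config_init`` and
# assume a local OpenAI-compatible inference server):
#
#   env_config, server_configs = SmolagentsEnv.config_init()
#   env = SmolagentsEnv(env_config, server_configs)
#
# In normal use the module is driven through the Atropos CLI entry point below.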
if __name__ == "__main__":
|
|
SmolagentsEnv.cli()
|