mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix imports and style issues
This commit is contained in:
parent
7a653044a4
commit
013090579d
4 changed files with 176 additions and 146 deletions
|
|
@ -78,7 +78,8 @@ python -m environments.smolagents_integration.smolagents_env process \
|
|||
--env.max_concurrent_processes 8 \
|
||||
--env.use_chat_completion true \
|
||||
--openai.model_name "gpt-4o" \
|
||||
--openai.base_url "https://api.openai.com/v1"
|
||||
--openai.base_url "https://api.openai.com/v1" \
|
||||
--openai.api_key x
|
||||
```
|
||||
|
||||
```bash
|
||||
|
|
@ -91,7 +92,8 @@ python -m environments.smolagents_integration.smolagents_env process \
|
|||
--env.max_concurrent_processes 8 \
|
||||
--env.use_chat_completion true \
|
||||
--openai.model_name "gpt-4o" \
|
||||
--openai.base_url "https://api.openai.com/v1"
|
||||
--openai.base_url "https://api.openai.com/v1" \
|
||||
--openai.api_key x
|
||||
```
|
||||
|
||||
Note: The command syntax uses dots (`.`) to separate namespaces. Also, the OpenAI API key should be set in your environment variables as `OPENAI_API_KEY` or in a `.env` file in the project root.
|
||||
|
|
@ -118,8 +120,8 @@ python -m environments.smolagents_integration.smolagents_env serve \
|
|||
--env.use_chat_completion true \
|
||||
--env.max_concurrent_processes 5 \
|
||||
--env.group_size 8 \
|
||||
--openai.model_name "gpt-4o" \
|
||||
--openai.base_url "https://api.openai.com/v1"
|
||||
--openai.model_name "your-model-name" \
|
||||
--openai.base_url "http://localhost:8000/v1"
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
|
@ -204,4 +206,4 @@ Each task includes:
|
|||
- **Web tool errors**: If Tavily tools aren't working, make sure you have set the `TAVILY_API_KEY` environment variable and have installed the `tavily-python` package.
|
||||
- **Tool import errors**: If you see errors about missing tool modules, ensure your working directory allows proper imports of the tools folder.
|
||||
- **Permission errors with file tools**: Ensure your process has the correct permissions to read/write files in the directories being accessed.
|
||||
- **Memory issues**: If you encounter memory usage problems, try lowering the `max_concurrent_processes` parameter.
|
||||
- **Memory issues**: If you encounter memory usage problems, try lowering the `max_concurrent_processes` parameter.
|
||||
|
|
|
|||
|
|
@ -12,22 +12,19 @@ from typing import Any, Dict
|
|||
from smolagents import CodeAgent
|
||||
|
||||
# Import tools directly
|
||||
from environments.smolagents_integration.server_proxy import ServerProxy
|
||||
from environments.smolagents_integration.smolagents_model import ProcessSafeAtroposServerModel
|
||||
from environments.smolagents_integration.tools.file_tools import (
|
||||
append_to_file, read_file, write_file
|
||||
)
|
||||
from .server_proxy import ServerProxy
|
||||
from .smolagents_model import ProcessSafeAtroposServerModel
|
||||
from .tools.file_tools import append_to_file, read_file, write_file
|
||||
|
||||
# Conditionally import Tavily tools if API key is available
|
||||
tavily_tools = []
|
||||
if os.environ.get("TAVILY_API_KEY"):
|
||||
try:
|
||||
from environments.smolagents_integration.tools.tavily_tools import (
|
||||
TavilyExtractTool, TavilySearchTool
|
||||
)
|
||||
from .tools.tavily_tools import TavilyExtractTool, TavilySearchTool
|
||||
|
||||
tavily_tools = [
|
||||
TavilySearchTool(api_key=os.environ.get("TAVILY_API_KEY")),
|
||||
TavilyExtractTool(api_key=os.environ.get("TAVILY_API_KEY"))
|
||||
TavilyExtractTool(api_key=os.environ.get("TAVILY_API_KEY")),
|
||||
]
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
@ -64,7 +61,9 @@ def run_agent_process(
|
|||
try:
|
||||
start_time = time.time()
|
||||
process_id = os.getpid()
|
||||
logger.info(f"Process {process_id} starting for task {task_metadata.get('task_id', 'unknown')}")
|
||||
logger.info(
|
||||
f"Process {process_id} starting for task {task_metadata.get('task_id', 'unknown')}"
|
||||
)
|
||||
|
||||
# Create a model using the server proxy
|
||||
model = ProcessSafeAtroposServerModel(
|
||||
|
|
@ -85,14 +84,18 @@ def run_agent_process(
|
|||
verbosity_level=agent_config.get("verbosity", 2),
|
||||
)
|
||||
|
||||
logger.info(f"Process {process_id}: Running agent on prompt with {len(prompt)} chars")
|
||||
logger.info(
|
||||
f"Process {process_id}: Running agent on prompt with {len(prompt)} chars"
|
||||
)
|
||||
|
||||
# Run the agent and get response
|
||||
agent_response = agent.run(prompt)
|
||||
|
||||
|
||||
# Ensure the response is properly formatted (convert sets, etc. to strings)
|
||||
if not isinstance(agent_response, str):
|
||||
logger.info(f"Converting non-string response of type {type(agent_response)} to string")
|
||||
logger.info(
|
||||
f"Converting non-string response of type {type(agent_response)} to string"
|
||||
)
|
||||
try:
|
||||
if isinstance(agent_response, set):
|
||||
# Convert sets to comma-separated strings
|
||||
|
|
@ -103,7 +106,7 @@ def run_agent_process(
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to convert agent_response to string: {e}")
|
||||
agent_response = str(agent_response)
|
||||
|
||||
|
||||
# Extract agent memory
|
||||
agent_memory = getattr(agent, "memory", None)
|
||||
if hasattr(agent, "write_memory_to_messages"):
|
||||
|
|
@ -113,14 +116,16 @@ def run_agent_process(
|
|||
execution_time = time.time() - start_time
|
||||
|
||||
# Prepare and send result
|
||||
result_queue.put({
|
||||
"status": "success",
|
||||
"response": agent_response,
|
||||
"task_id": task_metadata.get("task_id"),
|
||||
"execution_time": execution_time,
|
||||
"agent_memory": agent_memory,
|
||||
"task_metadata": task_metadata,
|
||||
})
|
||||
result_queue.put(
|
||||
{
|
||||
"status": "success",
|
||||
"response": agent_response,
|
||||
"task_id": task_metadata.get("task_id"),
|
||||
"execution_time": execution_time,
|
||||
"agent_memory": agent_memory,
|
||||
"task_metadata": task_metadata,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Process {process_id}: Agent completed in {execution_time:.2f}s")
|
||||
|
||||
|
|
@ -128,14 +133,16 @@ def run_agent_process(
|
|||
# Log the exception and put error result in queue
|
||||
logger.error(f"Process {os.getpid()}: Error in agent execution: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
result_queue.put({
|
||||
"status": "error",
|
||||
"error_message": str(e),
|
||||
"error_traceback": traceback.format_exc(),
|
||||
"task_id": task_metadata.get("task_id"),
|
||||
"task_metadata": task_metadata,
|
||||
})
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"status": "error",
|
||||
"error_message": str(e),
|
||||
"error_traceback": traceback.format_exc(),
|
||||
"task_id": task_metadata.get("task_id"),
|
||||
"task_metadata": task_metadata,
|
||||
}
|
||||
)
|
||||
|
||||
finally:
|
||||
logger.info(f"Process {os.getpid()}: Cleanup complete")
|
||||
|
|
|
|||
|
|
@ -3,35 +3,21 @@ SmolagentsEnv - Environment for creating high-quality agent trajectories
|
|||
for training language models using the SmolaGents agent framework.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
from smolagents import CodeAgent, Tool
|
||||
from smolagents.tools import tool # Import the tool decorator
|
||||
from pydantic import Field
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
from environments.smolagents_integration.agent_process_runner import run_agent_process
|
||||
from environments.smolagents_integration.server_proxy import ServerProxyManager
|
||||
from environments.smolagents_integration.tools.file_tools import (
|
||||
append_to_file,
|
||||
read_file,
|
||||
write_file,
|
||||
)
|
||||
|
||||
from .agent_process_runner import run_agent_process
|
||||
from .server_proxy import ServerProxyManager
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -73,7 +59,6 @@ class SmolagentsEnvConfig(BaseEnvConfig):
|
|||
default="combined",
|
||||
description="Scoring strategy: basic, correctness, or combined",
|
||||
)
|
||||
# Removed max_concurrent_agents as we only use process-based execution
|
||||
length_penalty_weight: float = Field(
|
||||
default=0.1, description="Weight for length penalty in scoring (0.0 to disable)"
|
||||
)
|
||||
|
|
@ -240,8 +225,6 @@ class SmolagentsEnv(BaseEnv):
|
|||
|
||||
logger.info("SmolagentsEnv setup complete")
|
||||
|
||||
# _create_tools method removed - tools are created directly in agent_process_runner.py
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""Get the next item from the GAIA dataset."""
|
||||
if not self.examples:
|
||||
|
|
@ -414,9 +397,9 @@ class SmolagentsEnv(BaseEnv):
|
|||
)
|
||||
else:
|
||||
# Just log the error and omit this example from training
|
||||
error_message = result.get('error_message', 'Unknown error')
|
||||
task_id = result.get('task_id', 'Unknown task')
|
||||
|
||||
error_message = result.get("error_message", "Unknown error")
|
||||
task_id = result.get("task_id", "Unknown task")
|
||||
|
||||
logger.warning(
|
||||
f"Omitting failed task {task_id} from training batch: {error_message}"
|
||||
)
|
||||
|
|
@ -493,7 +476,7 @@ class SmolagentsEnv(BaseEnv):
|
|||
"""
|
||||
Score the agent trajectory based on multiple criteria:
|
||||
- Answer correctness using GAIA scoring
|
||||
- Message format adherence
|
||||
- Message format adherence
|
||||
- Final answer tool usage
|
||||
- Execution success (detection of errors)
|
||||
- Efficiency (steps only)
|
||||
|
|
@ -504,30 +487,35 @@ class SmolagentsEnv(BaseEnv):
|
|||
true_answer: The ground truth answer
|
||||
agent_memory: The memory trace of the agent's steps
|
||||
execution_time: Time taken for execution (not used in scoring)
|
||||
|
||||
|
||||
Returns:
|
||||
float: A score between 0.0 and 1.0
|
||||
"""
|
||||
# Import all scoring functions upfront
|
||||
from environments.smolagents_integration.evaluations.smolagent_integrations.rubrics.gaia_scorer import (
|
||||
question_scorer, check_close_call
|
||||
from .evaluations.smolagent_integrations.rubrics.gaia_scorer import (
|
||||
check_close_call,
|
||||
question_scorer,
|
||||
)
|
||||
from environments.smolagents_integration.evaluations.smolagent_integrations.smolagents_scorer import (
|
||||
check_format_adherence, check_final_answer_usage,
|
||||
calculate_execution_score, calculate_efficiency_score
|
||||
from .evaluations.smolagent_integrations.smolagents_scorer import (
|
||||
calculate_efficiency_score,
|
||||
calculate_execution_score,
|
||||
check_final_answer_usage,
|
||||
check_format_adherence,
|
||||
)
|
||||
|
||||
|
||||
# Initialize component scores
|
||||
correctness_score = 0.0
|
||||
format_score = 0.0
|
||||
final_answer_score = 0.0
|
||||
execution_score = 0.0
|
||||
efficiency_score = 0.0
|
||||
|
||||
|
||||
# 1. Calculate correctness score using GAIA scorer
|
||||
# Ensure agent_response is a string before scoring
|
||||
if not isinstance(agent_response, str):
|
||||
logger.warning(f"agent_response is not a string, it's a {type(agent_response)}: {agent_response}")
|
||||
logger.warning(
|
||||
f"agent_response is not a string, it's a {type(agent_response)}: {agent_response}"
|
||||
)
|
||||
try:
|
||||
if isinstance(agent_response, set):
|
||||
# Convert sets to comma-separated strings
|
||||
|
|
@ -538,15 +526,15 @@ class SmolagentsEnv(BaseEnv):
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to convert agent_response to string: {e}")
|
||||
agent_response = ""
|
||||
|
||||
|
||||
is_correct = question_scorer(agent_response, true_answer)
|
||||
is_near_correct = check_close_call(agent_response, true_answer, is_correct)
|
||||
|
||||
|
||||
if is_correct:
|
||||
correctness_score = 1.0
|
||||
elif is_near_correct:
|
||||
correctness_score = 0.5
|
||||
|
||||
|
||||
# 2. Check format adherence
|
||||
if agent_memory:
|
||||
format_scores = []
|
||||
|
|
@ -555,10 +543,12 @@ class SmolagentsEnv(BaseEnv):
|
|||
format_scores.append(check_format_adherence(step["content"]))
|
||||
elif "model_output" in step and isinstance(step["model_output"], str):
|
||||
format_scores.append(check_format_adherence(step["model_output"]))
|
||||
|
||||
|
||||
# Average the format scores across all steps
|
||||
format_score = sum(format_scores) / len(format_scores) if format_scores else 0.0
|
||||
|
||||
format_score = (
|
||||
sum(format_scores) / len(format_scores) if format_scores else 0.0
|
||||
)
|
||||
|
||||
# 3. Check for final_answer tool usage
|
||||
final_answer_used = False
|
||||
for step in agent_memory:
|
||||
|
|
@ -570,41 +560,44 @@ class SmolagentsEnv(BaseEnv):
|
|||
if check_final_answer_usage(step["model_output"]):
|
||||
final_answer_used = True
|
||||
break
|
||||
|
||||
|
||||
final_answer_score = 1.0 if final_answer_used else 0.0
|
||||
|
||||
|
||||
# 4. Check for execution errors and calculate execution score
|
||||
execution_score = calculate_execution_score(agent_memory)
|
||||
|
||||
|
||||
# 5. Calculate efficiency score
|
||||
steps_count = len(agent_memory)
|
||||
efficiency_score = calculate_efficiency_score(
|
||||
steps_count=steps_count,
|
||||
max_steps=self.max_steps
|
||||
steps_count=steps_count, max_steps=self.max_steps
|
||||
)
|
||||
|
||||
|
||||
# Component weights - can be adjusted to emphasize different aspects
|
||||
correctness_weight = 0.50 # 50% of score - correctness matters most
|
||||
format_weight = 0.20 # 20% - format adherence is important
|
||||
format_weight = 0.20 # 20% - format adherence is important
|
||||
final_answer_weight = 0.10 # 10% - using final_answer tool properly
|
||||
execution_weight = 0.10 # 10% - avoiding errors
|
||||
efficiency_weight = 0.10 # 10% - being efficient
|
||||
|
||||
execution_weight = 0.10 # 10% - avoiding errors
|
||||
efficiency_weight = 0.10 # 10% - being efficient
|
||||
|
||||
# Calculate combined score with weights
|
||||
combined_score = (
|
||||
correctness_score * correctness_weight +
|
||||
format_score * format_weight +
|
||||
final_answer_score * final_answer_weight +
|
||||
execution_score * execution_weight +
|
||||
efficiency_score * efficiency_weight
|
||||
correctness_score * correctness_weight
|
||||
+ format_score * format_weight
|
||||
+ final_answer_score * final_answer_weight
|
||||
+ execution_score * execution_weight
|
||||
+ efficiency_score * efficiency_weight
|
||||
)
|
||||
|
||||
|
||||
# Apply length penalty if configured
|
||||
length_penalty = 0.0
|
||||
if self.config.length_penalty_weight > 0:
|
||||
# agent_response should already be a string from the previous conversion
|
||||
# but double-check just to be sure
|
||||
response_to_measure = agent_response if isinstance(agent_response, str) else str(agent_response)
|
||||
response_to_measure = (
|
||||
agent_response
|
||||
if isinstance(agent_response, str)
|
||||
else str(agent_response)
|
||||
)
|
||||
response_length = len(response_to_measure)
|
||||
# Penalize very long responses
|
||||
if response_length > 2000:
|
||||
|
|
@ -613,57 +606,75 @@ class SmolagentsEnv(BaseEnv):
|
|||
self.config.length_penalty_weight * (response_length - 2000) / 1000,
|
||||
)
|
||||
combined_score = max(0.0, combined_score - length_penalty)
|
||||
|
||||
|
||||
# Debug logging for score calculation
|
||||
if self.debug_scoring:
|
||||
logger.info("=== SCORE CALCULATION (detailed) ===")
|
||||
logger.info(f"1. Correctness component:")
|
||||
logger.info("1. Correctness component:")
|
||||
logger.info(f" - True answer: '{true_answer}'")
|
||||
logger.info(f" - Agent answer: '{agent_response}'")
|
||||
logger.info(f" - Is correct: {is_correct}")
|
||||
logger.info(f" - Is near correct: {is_near_correct}")
|
||||
logger.info(f" - Raw score: {correctness_score:.3f}")
|
||||
logger.info(f" - Weight: {correctness_weight}")
|
||||
logger.info(f" - Weighted score: {correctness_score * correctness_weight:.3f}")
|
||||
|
||||
logger.info(f"2. Format adherence component:")
|
||||
logger.info(f" - Format scores: {format_scores if 'format_scores' in locals() else []}")
|
||||
logger.info(
|
||||
f" - Weighted score: {correctness_score * correctness_weight:.3f}"
|
||||
)
|
||||
|
||||
logger.info("2. Format adherence component:")
|
||||
logger.info(
|
||||
f" - Format scores: {format_scores if 'format_scores' in locals() else []}"
|
||||
)
|
||||
logger.info(f" - Raw score: {format_score:.3f}")
|
||||
logger.info(f" - Weight: {format_weight}")
|
||||
logger.info(f" - Weighted score: {format_score * format_weight:.3f}")
|
||||
|
||||
logger.info(f"3. Final answer tool usage:")
|
||||
logger.info(f" - Final answer tool used: {final_answer_used if 'final_answer_used' in locals() else False}")
|
||||
|
||||
logger.info("3. Final answer tool usage:")
|
||||
logger.info(
|
||||
f" - Final answer tool used: {final_answer_used if 'final_answer_used' in locals() else False}"
|
||||
)
|
||||
logger.info(f" - Raw score: {final_answer_score:.3f}")
|
||||
logger.info(f" - Weight: {final_answer_weight}")
|
||||
logger.info(f" - Weighted score: {final_answer_score * final_answer_weight:.3f}")
|
||||
|
||||
logger.info(f"4. Execution component:")
|
||||
logger.info(
|
||||
f" - Weighted score: {final_answer_score * final_answer_weight:.3f}"
|
||||
)
|
||||
|
||||
logger.info("4. Execution component:")
|
||||
logger.info(f" - Steps count: {len(agent_memory) if agent_memory else 0}")
|
||||
logger.info(f" - Raw score: {execution_score:.3f}")
|
||||
logger.info(f" - Weight: {execution_weight}")
|
||||
logger.info(f" - Weighted score: {execution_score * execution_weight:.3f}")
|
||||
|
||||
logger.info(f"5. Efficiency component:")
|
||||
logger.info(
|
||||
f" - Weighted score: {execution_score * execution_weight:.3f}"
|
||||
)
|
||||
|
||||
logger.info("5. Efficiency component:")
|
||||
logger.info(f" - Steps count: {len(agent_memory) if agent_memory else 0}")
|
||||
logger.info(f" - Max steps: {self.max_steps}")
|
||||
logger.info(f" - Raw score: {efficiency_score:.3f}")
|
||||
logger.info(f" - Weight: {efficiency_weight}")
|
||||
logger.info(f" - Weighted score: {efficiency_score * efficiency_weight:.3f}")
|
||||
|
||||
logger.info(f"6. Length penalty:")
|
||||
logger.info(f" - Response length: {len(agent_response) if isinstance(agent_response, str) else 'N/A'}")
|
||||
logger.info(
|
||||
f" - Weighted score: {efficiency_score * efficiency_weight:.3f}"
|
||||
)
|
||||
|
||||
logger.info("6. Length penalty:")
|
||||
logger.info(
|
||||
f" - Response length: {len(agent_response) if isinstance(agent_response, str) else 'N/A'}"
|
||||
)
|
||||
logger.info(f" - Penalty: {length_penalty:.3f}")
|
||||
|
||||
logger.info(f"7. Final score calculation:")
|
||||
logger.info(f" - Correctness: {correctness_score * correctness_weight:.3f}")
|
||||
|
||||
logger.info("7. Final score calculation:")
|
||||
logger.info(
|
||||
f" - Correctness: {correctness_score * correctness_weight:.3f}"
|
||||
)
|
||||
logger.info(f" - Format adherence: {format_score * format_weight:.3f}")
|
||||
logger.info(f" - Final answer tool: {final_answer_score * final_answer_weight:.3f}")
|
||||
logger.info(
|
||||
f" - Final answer tool: {final_answer_score * final_answer_weight:.3f}"
|
||||
)
|
||||
logger.info(f" - Execution: {execution_score * execution_weight:.3f}")
|
||||
logger.info(f" - Efficiency: {efficiency_score * efficiency_weight:.3f}")
|
||||
logger.info(f" - Length penalty: -{length_penalty:.3f}")
|
||||
logger.info(f" - FINAL SCORE: {combined_score:.3f}")
|
||||
|
||||
|
||||
return combined_score
|
||||
|
||||
def _create_scored_data_group(
|
||||
|
|
@ -704,76 +715,82 @@ class SmolagentsEnv(BaseEnv):
|
|||
# Create a comprehensive markdown document for HTML compatibility
|
||||
# Format similar to the Wikipedia environment
|
||||
complete_conversation = []
|
||||
|
||||
|
||||
# Add task information at the top
|
||||
task_type = item.metadata.get("task", "Unknown task")
|
||||
task_id = item.metadata.get("task_id", "Unknown ID")
|
||||
complete_conversation.append(f"# GAIA Task: {task_type} (ID: {task_id})")
|
||||
|
||||
|
||||
# Process each message in the conversation
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
|
||||
|
||||
# Skip empty messages
|
||||
if not content:
|
||||
continue
|
||||
|
||||
|
||||
# Add a header for each role
|
||||
complete_conversation.append(f"## {role.upper()}")
|
||||
|
||||
|
||||
# Handle different content formats
|
||||
if isinstance(content, list):
|
||||
content_text = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and 'text' in item:
|
||||
content_text.append(item['text'])
|
||||
if isinstance(item, dict) and "text" in item:
|
||||
content_text.append(item["text"])
|
||||
else:
|
||||
content_text.append(str(item))
|
||||
content = '\n'.join(content_text)
|
||||
content = "\n".join(content_text)
|
||||
elif not isinstance(content, str):
|
||||
# Handle non-string content
|
||||
content = str(content)
|
||||
|
||||
|
||||
# Add the message content
|
||||
complete_conversation.append(content)
|
||||
|
||||
|
||||
# If this message has agent memory, show it
|
||||
if role == "assistant" and i > 1 and self.config.save_full_traces:
|
||||
# Agent memory with thinking is often present in assistant messages
|
||||
if "<thinking>" in content:
|
||||
complete_conversation.append("### Thinking Process")
|
||||
# The thinking is already in the content
|
||||
|
||||
|
||||
# Add tool usage information if present
|
||||
if "tool_usage" in msg:
|
||||
tool_name = msg.get("tool_usage", {}).get("name", "unknown tool")
|
||||
tool_name = msg.get("tool_usage", {}).get(
|
||||
"name", "unknown tool"
|
||||
)
|
||||
complete_conversation.append(f"### 🛠️ Tool Used: {tool_name}")
|
||||
tool_args = msg.get("tool_usage", {}).get("args", {})
|
||||
if tool_args:
|
||||
complete_conversation.append("```json")
|
||||
complete_conversation.append(json.dumps(tool_args, indent=2))
|
||||
complete_conversation.append(
|
||||
json.dumps(tool_args, indent=2)
|
||||
)
|
||||
complete_conversation.append("```")
|
||||
|
||||
|
||||
tool_result = msg.get("tool_usage", {}).get("result", None)
|
||||
if tool_result:
|
||||
complete_conversation.append("### Tool Result")
|
||||
complete_conversation.append("```")
|
||||
complete_conversation.append(str(tool_result))
|
||||
complete_conversation.append("```")
|
||||
|
||||
|
||||
# Add score information at the end
|
||||
complete_conversation.append(f"\n## Score: {scored_data['score']:.4f}")
|
||||
|
||||
|
||||
# Join everything into a single string with double newlines between sections
|
||||
full_conversation_markdown = "\n\n".join(complete_conversation)
|
||||
|
||||
|
||||
# Create the ScoredDataGroup with a single comprehensive markdown document
|
||||
scored_group = ScoredDataGroup(
|
||||
tokens=[self.tokenizer.encode(json.dumps(messages))],
|
||||
masks=[[1] * len(self.tokenizer.encode(json.dumps(messages)))],
|
||||
scores=[scored_data["score"]],
|
||||
messages=[full_conversation_markdown], # Use single markdown document for HTML
|
||||
messages=[
|
||||
full_conversation_markdown
|
||||
], # Use single markdown document for HTML
|
||||
_original_messages=[messages], # Keep original for trainer API
|
||||
)
|
||||
|
||||
|
|
@ -787,16 +804,16 @@ class SmolagentsEnv(BaseEnv):
|
|||
{"role": "user", "content": item.prompt},
|
||||
{"role": "assistant", "content": scored_data["response"]},
|
||||
]
|
||||
|
||||
|
||||
# Use the standard tokenize_for_trainer utility
|
||||
|
||||
|
||||
# Tokenize using the standard utility (only trains on assistant messages)
|
||||
tokenized = tokenize_for_trainer(
|
||||
self.tokenizer,
|
||||
self.tokenizer,
|
||||
messages,
|
||||
train_on_all_assistant_turns=True # Train on all assistant turns if present
|
||||
train_on_all_assistant_turns=True, # Train on all assistant turns if present
|
||||
)
|
||||
|
||||
|
||||
# Create the ScoredDataGroup
|
||||
scored_group = ScoredDataGroup(
|
||||
tokens=[tokenized["tokens"]],
|
||||
|
|
@ -807,7 +824,6 @@ class SmolagentsEnv(BaseEnv):
|
|||
|
||||
return scored_group
|
||||
|
||||
|
||||
async def evaluate(self, **kwargs):
|
||||
"""
|
||||
Evaluate the agent on a subset of the GAIA benchmark.
|
||||
|
|
|
|||
|
|
@ -4,11 +4,10 @@ Process-safe implementation of the AtroposServerModel for SmolaGents.
|
|||
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from smolagents.models import ChatMessage, MessageRole, Model
|
||||
|
||||
from environments.smolagents_integration.server_proxy import ServerProxy
|
||||
from .server_proxy import ServerProxy
|
||||
|
||||
# Configure logger for the model class
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -45,7 +44,8 @@ class ProcessSafeAtroposServerModel(Model):
|
|||
|
||||
# Log the configuration
|
||||
logger.info(
|
||||
f"Initializing ProcessSafeAtroposServerModel with model_id={model_id}, use_chat_completion={self.use_chat_completion}"
|
||||
f"Initializing ProcessSafeAtroposServerModel with model_id={model_id}, "
|
||||
f"use_chat_completion={self.use_chat_completion}"
|
||||
)
|
||||
|
||||
super().__init__(model_id=model_id, **kwargs)
|
||||
|
|
@ -117,7 +117,12 @@ class ProcessSafeAtroposServerModel(Model):
|
|||
openai_role = "user"
|
||||
elif role_str == "assistant":
|
||||
openai_role = "assistant"
|
||||
elif role_str in ("tool_call", "tool_response", "function_call", "function_response"):
|
||||
elif role_str in (
|
||||
"tool_call",
|
||||
"tool_response",
|
||||
"function_call",
|
||||
"function_response",
|
||||
):
|
||||
# Silently map tool and function calls/responses to user roles
|
||||
openai_role = "user"
|
||||
else:
|
||||
|
|
@ -172,8 +177,8 @@ class ProcessSafeAtroposServerModel(Model):
|
|||
messages=messages, stop_sequences=stop_sequences, **kwargs
|
||||
)
|
||||
|
||||
# Extract timeout from kwargs or use default
|
||||
timeout = kwargs.pop("timeout", 120) # Default 2 minutes
|
||||
# Extract timeout from kwargs or use default (but not used in this method)
|
||||
kwargs.pop("timeout", 120) # Default 2 minutes
|
||||
|
||||
try:
|
||||
# Use chat_completion if configured
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue