mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
bc8fde8771
commit
47c68f06f2
7 changed files with 78 additions and 49 deletions
|
|
@ -1,3 +1,3 @@
|
||||||
"""
|
"""
|
||||||
SmolaGents evaluation utilities for Atropos integrations.
|
SmolaGents evaluation utilities for Atropos integrations.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
"""
|
"""
|
||||||
Scoring rubrics for SmolaGents integrations.
|
Scoring rubrics for SmolaGents integrations.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,10 @@ def question_scorer(
|
||||||
try:
|
try:
|
||||||
model_answer = str(model_answer)
|
model_answer = str(model_answer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
warnings.warn(f"Failed to convert model_answer to string: {e}. Type: {type(model_answer)}", UserWarning)
|
warnings.warn(
|
||||||
|
f"Failed to convert model_answer to string: {e}. Type: {type(model_answer)}",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# if gt is a number
|
# if gt is a number
|
||||||
|
|
@ -73,7 +76,9 @@ def question_scorer(
|
||||||
|
|
||||||
# check length is the same
|
# check length is the same
|
||||||
if len(gt_elems) != len(ma_elems):
|
if len(gt_elems) != len(ma_elems):
|
||||||
warnings.warn("Answer lists have different lengths, returning False.", UserWarning)
|
warnings.warn(
|
||||||
|
"Answer lists have different lengths, returning False.", UserWarning
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# compare each element as float or str
|
# compare each element as float or str
|
||||||
|
|
@ -85,7 +90,8 @@ def question_scorer(
|
||||||
else:
|
else:
|
||||||
# we do not remove punct since comparisons can include punct
|
# we do not remove punct since comparisons can include punct
|
||||||
comparisons.append(
|
comparisons.append(
|
||||||
normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)
|
normalize_str(ma_elem, remove_punct=False)
|
||||||
|
== normalize_str(gt_elem, remove_punct=False)
|
||||||
)
|
)
|
||||||
return all(comparisons)
|
return all(comparisons)
|
||||||
|
|
||||||
|
|
@ -116,8 +122,12 @@ def check_close_call(prediction, true_answer, is_correct):
|
||||||
return is_correct
|
return is_correct
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
check_prediction_contains_answer_letters_in_order(str(prediction), str(true_answer))
|
check_prediction_contains_answer_letters_in_order(
|
||||||
and len(str(true_answer)) * 0.5 <= len(str(prediction)) <= len(str(true_answer)) * 2
|
str(prediction), str(true_answer)
|
||||||
|
)
|
||||||
|
and len(str(true_answer)) * 0.5
|
||||||
|
<= len(str(prediction))
|
||||||
|
<= len(str(true_answer)) * 2
|
||||||
):
|
):
|
||||||
# Remove print statement that causes duplicated output
|
# Remove print statement that causes duplicated output
|
||||||
return True
|
return True
|
||||||
|
|
@ -142,9 +152,12 @@ def normalize_str(input_str, remove_punct=True) -> str:
|
||||||
try:
|
try:
|
||||||
input_str = str(input_str)
|
input_str = str(input_str)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
warnings.warn(f"Failed to convert input to string: {e}. Type: {type(input_str)}", UserWarning)
|
warnings.warn(
|
||||||
|
f"Failed to convert input to string: {e}. Type: {type(input_str)}",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Remove all white spaces. Required e.g for seagull vs. sea gull
|
# Remove all white spaces. Required e.g for seagull vs. sea gull
|
||||||
no_spaces = re.sub(r"\s", "", input_str)
|
no_spaces = re.sub(r"\s", "", input_str)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,32 +7,33 @@ execution error detection, and efficiency metrics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Optional, Any
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def check_format_adherence(memory_content: str) -> float:
|
def check_format_adherence(memory_content: str) -> float:
|
||||||
"""
|
"""
|
||||||
Check if memory content follows the required CodeAgent format.
|
Check if memory content follows the required CodeAgent format.
|
||||||
|
|
||||||
The expected format includes:
|
The expected format includes:
|
||||||
- "Thought:" section with reasoning
|
- "Thought:" section with reasoning
|
||||||
- "Code:" section with a Python code block
|
- "Code:" section with a Python code block
|
||||||
- Code blocks with triple backticks and "<end_code>" marker
|
- Code blocks with triple backticks and "<end_code>" marker
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_content: The content of a memory step to check
|
memory_content: The content of a memory step to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
float: Score between 0.0 and 1.0 indicating format compliance
|
float: Score between 0.0 and 1.0 indicating format compliance
|
||||||
"""
|
"""
|
||||||
thought_pattern = r"Thought: .+"
|
thought_pattern = r"Thought: .+"
|
||||||
code_pattern = r"Code:\s*```py\s*[\s\S]*?```<end_code>"
|
code_pattern = r"Code:\s*```py\s*[\s\S]*?```<end_code>"
|
||||||
|
|
||||||
# Check if both patterns exist in the content
|
# Check if both patterns exist in the content
|
||||||
has_thought = bool(re.search(thought_pattern, memory_content))
|
has_thought = bool(re.search(thought_pattern, memory_content))
|
||||||
has_code = bool(re.search(code_pattern, memory_content))
|
has_code = bool(re.search(code_pattern, memory_content))
|
||||||
|
|
||||||
if has_thought and has_code:
|
if has_thought and has_code:
|
||||||
return 1.0
|
return 1.0
|
||||||
elif has_thought or has_code:
|
elif has_thought or has_code:
|
||||||
|
|
@ -44,10 +45,10 @@ def check_format_adherence(memory_content: str) -> float:
|
||||||
def check_final_answer_usage(memory_content: str) -> bool:
|
def check_final_answer_usage(memory_content: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the final_answer tool was used appropriately.
|
Check if the final_answer tool was used appropriately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_content: The content of a memory step to check
|
memory_content: The content of a memory step to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if final_answer tool was used, False otherwise
|
bool: True if final_answer tool was used, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
@ -58,101 +59,108 @@ def check_final_answer_usage(memory_content: str) -> bool:
|
||||||
def extract_execution_errors(agent_memory: List[Dict]) -> List[Dict]:
|
def extract_execution_errors(agent_memory: List[Dict]) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Extract execution errors from agent memory.
|
Extract execution errors from agent memory.
|
||||||
|
|
||||||
Looks for error patterns in the observations field of each memory step.
|
Looks for error patterns in the observations field of each memory step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_memory: List of memory steps from agent execution
|
agent_memory: List of memory steps from agent execution
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict]: List of errors with step number and error message
|
List[Dict]: List of errors with step number and error message
|
||||||
"""
|
"""
|
||||||
execution_errors = []
|
execution_errors = []
|
||||||
|
|
||||||
if not agent_memory:
|
if not agent_memory:
|
||||||
return execution_errors
|
return execution_errors
|
||||||
|
|
||||||
for step in agent_memory:
|
for step in agent_memory:
|
||||||
# In SmolaGents ActionStep, observations field contains execution output
|
# In SmolaGents ActionStep, observations field contains execution output
|
||||||
if isinstance(step, dict) and "observations" in step and isinstance(step["observations"], str):
|
if (
|
||||||
|
isinstance(step, dict)
|
||||||
|
and "observations" in step
|
||||||
|
and isinstance(step["observations"], str)
|
||||||
|
):
|
||||||
observation = step["observations"]
|
observation = step["observations"]
|
||||||
|
|
||||||
# Look for error patterns
|
# Look for error patterns
|
||||||
error_patterns = [
|
error_patterns = [
|
||||||
r"Error: .*",
|
r"Error: .*",
|
||||||
r"Exception: .*",
|
r"Exception: .*",
|
||||||
r"Traceback \(most recent call last\).*",
|
r"Traceback \(most recent call last\).*",
|
||||||
r".*Error: .*",
|
r".*Error: .*",
|
||||||
r".*Exception: .*"
|
r".*Exception: .*",
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern in error_patterns:
|
for pattern in error_patterns:
|
||||||
matches = re.findall(pattern, observation, re.DOTALL)
|
matches = re.findall(pattern, observation, re.DOTALL)
|
||||||
if matches:
|
if matches:
|
||||||
# Record step number and error
|
# Record step number and error
|
||||||
execution_errors.append({
|
execution_errors.append(
|
||||||
"step": step.get("step_number", 0),
|
{"step": step.get("step_number", 0), "error": matches[0]}
|
||||||
"error": matches[0]
|
)
|
||||||
})
|
|
||||||
|
|
||||||
return execution_errors
|
return execution_errors
|
||||||
|
|
||||||
|
|
||||||
def calculate_efficiency_score(
|
def calculate_efficiency_score(
|
||||||
steps_count: int,
|
steps_count: int,
|
||||||
max_steps: int,
|
max_steps: int,
|
||||||
execution_time: float = None, # Parameter kept for backward compatibility but not used
|
execution_time: float = None, # Parameter kept for backward compatibility but not used
|
||||||
execution_times_history: Optional[List[float]] = None # Parameter kept for backward compatibility but not used
|
execution_times_history: Optional[
|
||||||
|
List[float]
|
||||||
|
] = None, # Parameter kept for backward compatibility but not used
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Calculate efficiency score based on steps used only.
|
Calculate efficiency score based on steps used only.
|
||||||
Execution time is no longer considered in the score calculation.
|
Execution time is no longer considered in the score calculation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
steps_count: Number of steps taken by the agent
|
steps_count: Number of steps taken by the agent
|
||||||
max_steps: Maximum allowed steps
|
max_steps: Maximum allowed steps
|
||||||
execution_time: Not used, kept for backward compatibility
|
execution_time: Not used, kept for backward compatibility
|
||||||
execution_times_history: Not used, kept for backward compatibility
|
execution_times_history: Not used, kept for backward compatibility
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
float: Efficiency score between 0.0 and 1.0
|
float: Efficiency score between 0.0 and 1.0
|
||||||
"""
|
"""
|
||||||
# Start with full efficiency score
|
# Start with full efficiency score
|
||||||
efficiency_score = 1.0
|
efficiency_score = 1.0
|
||||||
|
|
||||||
# Penalty for excessive steps (above 75% of max)
|
# Penalty for excessive steps (above 75% of max)
|
||||||
step_penalty = 1.0
|
step_penalty = 1.0
|
||||||
if steps_count > (max_steps * 0.75):
|
if steps_count > (max_steps * 0.75):
|
||||||
step_penalty = max(0.5, 1.0 - ((steps_count - max_steps * 0.75) / (max_steps * 0.25)))
|
step_penalty = max(
|
||||||
|
0.5, 1.0 - ((steps_count - max_steps * 0.75) / (max_steps * 0.25))
|
||||||
|
)
|
||||||
efficiency_score *= step_penalty
|
efficiency_score *= step_penalty
|
||||||
|
|
||||||
# Note: Execution time penalty has been removed
|
# Note: Execution time penalty has been removed
|
||||||
|
|
||||||
return efficiency_score
|
return efficiency_score
|
||||||
|
|
||||||
|
|
||||||
def calculate_execution_score(agent_memory: List[Dict]) -> float:
|
def calculate_execution_score(agent_memory: List[Dict]) -> float:
|
||||||
"""
|
"""
|
||||||
Calculate execution success score by detecting errors in agent memory.
|
Calculate execution success score by detecting errors in agent memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_memory: List of memory steps from agent execution
|
agent_memory: List of memory steps from agent execution
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
float: Execution score between 0.0 and 1.0
|
float: Execution score between 0.0 and 1.0
|
||||||
"""
|
"""
|
||||||
execution_errors = extract_execution_errors(agent_memory)
|
execution_errors = extract_execution_errors(agent_memory)
|
||||||
|
|
||||||
if not agent_memory:
|
if not agent_memory:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
total_steps = len(agent_memory)
|
total_steps = len(agent_memory)
|
||||||
error_steps = len(execution_errors)
|
error_steps = len(execution_errors)
|
||||||
|
|
||||||
if total_steps > 0:
|
if total_steps > 0:
|
||||||
# Penalize proportionally to the number of steps with errors
|
# Penalize proportionally to the number of steps with errors
|
||||||
execution_score = max(0, 1.0 - (error_steps / total_steps))
|
execution_score = max(0, 1.0 - (error_steps / total_steps))
|
||||||
else:
|
else:
|
||||||
execution_score = 0.0
|
execution_score = 0.0
|
||||||
|
|
||||||
return execution_score
|
return execution_score
|
||||||
|
|
|
||||||
|
|
@ -777,7 +777,11 @@ class SmolagentsEnv(BaseEnv):
|
||||||
# Handle both dict and ChatMessage objects
|
# Handle both dict and ChatMessage objects
|
||||||
if hasattr(message, "role") and hasattr(message, "content"):
|
if hasattr(message, "role") and hasattr(message, "content"):
|
||||||
# Convert ChatMessage to dict
|
# Convert ChatMessage to dict
|
||||||
role = message.role.value if hasattr(message.role, "value") else str(message.role)
|
role = (
|
||||||
|
message.role.value
|
||||||
|
if hasattr(message.role, "value")
|
||||||
|
else str(message.role)
|
||||||
|
)
|
||||||
messages.append({"role": role, "content": message.content})
|
messages.append({"role": role, "content": message.content})
|
||||||
elif isinstance(message, dict):
|
elif isinstance(message, dict):
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from smolagents import tool
|
from smolagents import tool
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def read_file(file_path: str) -> str:
|
def read_file(file_path: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -19,6 +21,7 @@ def read_file(file_path: str) -> str:
|
||||||
print(content)
|
print(content)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def write_file(file_path: str, content: str) -> str:
|
def write_file(file_path: str, content: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -36,6 +39,7 @@ def write_file(file_path: str, content: str) -> str:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
return f"Content written to {file_path}"
|
return f"Content written to {file_path}"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def append_to_file(file_path: str, content: str) -> str:
|
def append_to_file(file_path: str, content: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -84,4 +84,4 @@ else
|
||||||
echo ""
|
echo ""
|
||||||
echo "ERROR: Test failed!"
|
echo "ERROR: Test failed!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue