mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
bc8fde8771
commit
47c68f06f2
7 changed files with 78 additions and 49 deletions
|
|
@ -56,7 +56,10 @@ def question_scorer(
|
|||
try:
|
||||
model_answer = str(model_answer)
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to convert model_answer to string: {e}. Type: {type(model_answer)}", UserWarning)
|
||||
warnings.warn(
|
||||
f"Failed to convert model_answer to string: {e}. Type: {type(model_answer)}",
|
||||
UserWarning,
|
||||
)
|
||||
return False
|
||||
|
||||
# if gt is a number
|
||||
|
|
@ -73,7 +76,9 @@ def question_scorer(
|
|||
|
||||
# check length is the same
|
||||
if len(gt_elems) != len(ma_elems):
|
||||
warnings.warn("Answer lists have different lengths, returning False.", UserWarning)
|
||||
warnings.warn(
|
||||
"Answer lists have different lengths, returning False.", UserWarning
|
||||
)
|
||||
return False
|
||||
|
||||
# compare each element as float or str
|
||||
|
|
@ -85,7 +90,8 @@ def question_scorer(
|
|||
else:
|
||||
# we do not remove punct since comparisons can include punct
|
||||
comparisons.append(
|
||||
normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)
|
||||
normalize_str(ma_elem, remove_punct=False)
|
||||
== normalize_str(gt_elem, remove_punct=False)
|
||||
)
|
||||
return all(comparisons)
|
||||
|
||||
|
|
@ -116,8 +122,12 @@ def check_close_call(prediction, true_answer, is_correct):
|
|||
return is_correct
|
||||
else:
|
||||
if (
|
||||
check_prediction_contains_answer_letters_in_order(str(prediction), str(true_answer))
|
||||
and len(str(true_answer)) * 0.5 <= len(str(prediction)) <= len(str(true_answer)) * 2
|
||||
check_prediction_contains_answer_letters_in_order(
|
||||
str(prediction), str(true_answer)
|
||||
)
|
||||
and len(str(true_answer)) * 0.5
|
||||
<= len(str(prediction))
|
||||
<= len(str(true_answer)) * 2
|
||||
):
|
||||
# Remove print statement that causes duplicated output
|
||||
return True
|
||||
|
|
@ -142,7 +152,10 @@ def normalize_str(input_str, remove_punct=True) -> str:
|
|||
try:
|
||||
input_str = str(input_str)
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to convert input to string: {e}. Type: {type(input_str)}", UserWarning)
|
||||
warnings.warn(
|
||||
f"Failed to convert input to string: {e}. Type: {type(input_str)}",
|
||||
UserWarning,
|
||||
)
|
||||
return ""
|
||||
|
||||
# Remove all white spaces. Required e.g for seagull vs. sea gull
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ execution error detection, and efficiency metrics.
|
|||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
|
@ -74,7 +75,11 @@ def extract_execution_errors(agent_memory: List[Dict]) -> List[Dict]:
|
|||
|
||||
for step in agent_memory:
|
||||
# In SmolaGents ActionStep, observations field contains execution output
|
||||
if isinstance(step, dict) and "observations" in step and isinstance(step["observations"], str):
|
||||
if (
|
||||
isinstance(step, dict)
|
||||
and "observations" in step
|
||||
and isinstance(step["observations"], str)
|
||||
):
|
||||
observation = step["observations"]
|
||||
|
||||
# Look for error patterns
|
||||
|
|
@ -83,17 +88,16 @@ def extract_execution_errors(agent_memory: List[Dict]) -> List[Dict]:
|
|||
r"Exception: .*",
|
||||
r"Traceback \(most recent call last\).*",
|
||||
r".*Error: .*",
|
||||
r".*Exception: .*"
|
||||
r".*Exception: .*",
|
||||
]
|
||||
|
||||
for pattern in error_patterns:
|
||||
matches = re.findall(pattern, observation, re.DOTALL)
|
||||
if matches:
|
||||
# Record step number and error
|
||||
execution_errors.append({
|
||||
"step": step.get("step_number", 0),
|
||||
"error": matches[0]
|
||||
})
|
||||
execution_errors.append(
|
||||
{"step": step.get("step_number", 0), "error": matches[0]}
|
||||
)
|
||||
|
||||
return execution_errors
|
||||
|
||||
|
|
@ -102,7 +106,9 @@ def calculate_efficiency_score(
|
|||
steps_count: int,
|
||||
max_steps: int,
|
||||
execution_time: float = None, # Parameter kept for backward compatibility but not used
|
||||
execution_times_history: Optional[List[float]] = None # Parameter kept for backward compatibility but not used
|
||||
execution_times_history: Optional[
|
||||
List[float]
|
||||
] = None, # Parameter kept for backward compatibility but not used
|
||||
) -> float:
|
||||
"""
|
||||
Calculate efficiency score based on steps used only.
|
||||
|
|
@ -123,7 +129,9 @@ def calculate_efficiency_score(
|
|||
# Penalty for excessive steps (above 75% of max)
|
||||
step_penalty = 1.0
|
||||
if steps_count > (max_steps * 0.75):
|
||||
step_penalty = max(0.5, 1.0 - ((steps_count - max_steps * 0.75) / (max_steps * 0.25)))
|
||||
step_penalty = max(
|
||||
0.5, 1.0 - ((steps_count - max_steps * 0.75) / (max_steps * 0.25))
|
||||
)
|
||||
efficiency_score *= step_penalty
|
||||
|
||||
# Note: Execution time penalty has been removed
|
||||
|
|
|
|||
|
|
@ -777,7 +777,11 @@ class SmolagentsEnv(BaseEnv):
|
|||
# Handle both dict and ChatMessage objects
|
||||
if hasattr(message, "role") and hasattr(message, "content"):
|
||||
# Convert ChatMessage to dict
|
||||
role = message.role.value if hasattr(message.role, "value") else str(message.role)
|
||||
role = (
|
||||
message.role.value
|
||||
if hasattr(message.role, "value")
|
||||
else str(message.role)
|
||||
)
|
||||
messages.append({"role": role, "content": message.content})
|
||||
elif isinstance(message, dict):
|
||||
messages.append(message)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
|
||||
from smolagents import tool
|
||||
|
||||
|
||||
@tool
|
||||
def read_file(file_path: str) -> str:
|
||||
"""
|
||||
|
|
@ -19,6 +21,7 @@ def read_file(file_path: str) -> str:
|
|||
print(content)
|
||||
return content
|
||||
|
||||
|
||||
@tool
|
||||
def write_file(file_path: str, content: str) -> str:
|
||||
"""
|
||||
|
|
@ -36,6 +39,7 @@ def write_file(file_path: str, content: str) -> str:
|
|||
f.write(content)
|
||||
return f"Content written to {file_path}"
|
||||
|
||||
|
||||
@tool
|
||||
def append_to_file(file_path: str, content: str) -> str:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue