mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
This adds the SmolaGents integration to Atropos, enabling the creation of high-quality agent trajectories for training data.
This commit is contained in:
parent
bef6a0b99a
commit
7710e151cc
13 changed files with 2534 additions and 0 deletions
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Scoring rubrics for SmolaGents integrations.
|
||||
"""
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
# This file is from the smolagents project: https://github.com/huggingface/smolagents
|
||||
#
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import string
|
||||
import warnings
|
||||
|
||||
|
||||
def normalize_number_str(number_str: str) -> float:
    """Parse a human-formatted number string into a float.

    Dollar signs, percent signs and commas are stripped before conversion so
    that inputs such as ``"$1,234.50"`` or ``"75%"`` can be parsed.

    Parameters:
    - number_str: str, the text to convert

    Returns:
    - float, the parsed value, or ``float("inf")`` when the cleaned text is
      not a valid float literal
    """
    cleaned = number_str.replace("$", "").replace("%", "").replace(",", "")
    try:
        value = float(cleaned)
    except ValueError:
        # Unparseable input yields an infinite sentinel instead of raising.
        value = float("inf")
    return value
|
||||
|
||||
|
||||
def split_string(
    s: str,
    char_list: "list[str] | None" = None,
) -> "list[str]":
    """Split ``s`` at every occurrence of any separator character.

    Parameters:
    - s: str, the text to split
    - char_list: list of separator strings; every character they contain is
      treated as a single-character separator. Defaults to ``[",", ";"]``.

    Returns:
    - list of str, the pieces of ``s`` between separators (empty pieces are
      kept). When no separator characters are given, ``[s]`` is returned.
    """
    if char_list is None:
        char_list = [",", ";"]

    separators = "".join(char_list)
    if not separators:
        # An empty character class is not a valid regex; nothing to split on.
        return [s]

    # Escape the separators so characters that are special inside a character
    # class ("^", "-", "]", "\\") are matched literally.
    pattern = f"[{re.escape(separators)}]"
    return re.split(pattern, s)
|
||||
|
||||
|
||||
def is_float(element: object) -> bool:
    """Report whether ``element`` can be converted with ``float()``.

    Parameters:
    - element: any value (string, number, or other object)

    Returns:
    - bool, True when ``float(element)`` succeeds, False otherwise
    """
    try:
        float(element)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported types such as None or a list.
        # ValueError: strings that are not float literals.
        # OverflowError: integers too large to represent as a float.
        return False
    return True
|
||||
|
||||
|
||||
def question_scorer(
    model_answer: str,
    ground_truth: str,
) -> bool:
    """Decide whether ``model_answer`` matches ``ground_truth``.

    The comparison mode is selected from the ground truth:

    - numeric ground truth: the answer is parsed as a number and compared
      for exact float equality;
    - ground truth containing ``,`` or ``;``: both sides are split into
      elements, which must agree in count and match pairwise (numerically
      where the ground-truth element is numeric, otherwise as normalized
      strings with punctuation kept);
    - anything else: both sides are compared as normalized strings.

    Parameters:
    - model_answer: the answer to score; non-string values are converted
      with ``str()`` first
    - ground_truth: the reference answer

    Returns:
    - bool, True when the answer is considered correct
    """
    # Ensure model_answer is a string
    if not isinstance(model_answer, str):
        try:
            model_answer = str(model_answer)
        except Exception as e:
            warnings.warn(f"Failed to convert model_answer to string: {e}. Type: {type(model_answer)}", UserWarning)
            return False

    # Numeric ground truth
    if is_float(ground_truth):
        return normalize_number_str(str(model_answer)) == float(ground_truth)

    # Plain string ground truth (no list separator present)
    if not any(sep in ground_truth for sep in [",", ";"]):
        return normalize_str(model_answer) == normalize_str(ground_truth)

    # List-valued ground truth
    expected_items = split_string(ground_truth)
    answer_items = split_string(model_answer)

    if len(expected_items) != len(answer_items):
        warnings.warn("Answer lists have different lengths, returning False.", UserWarning)
        return False

    def _items_match(answer_item, expected_item):
        # Numeric elements are compared as floats; everything else as
        # normalized text, keeping punctuation because it can be significant.
        if is_float(expected_item):
            return normalize_number_str(answer_item) == float(expected_item)
        return normalize_str(answer_item, remove_punct=False) == normalize_str(expected_item, remove_punct=False)

    outcomes = [_items_match(answer_item, expected_item) for answer_item, expected_item in zip(answer_items, expected_items)]
    return all(outcomes)
|
||||
|
||||
|
||||
def check_prediction_contains_answer_letters_in_order(prediction, true_answer):
    """Check that the characters of ``true_answer`` appear, in order, in ``prediction``.

    Both strings are lowercased first. The characters need not be adjacent in
    ``prediction``, but each character of ``true_answer`` must be matched by
    its own, later character of ``prediction``.

    Parameters:
    - prediction: str, the predicted answer
    - true_answer: str, the reference answer

    Returns:
    - bool, False when ``prediction`` is more than three times as long as
      ``true_answer`` or when some character cannot be matched in order;
      True otherwise
    """
    prediction = prediction.lower()
    true_answer = true_answer.lower()
    if len(prediction) > len(true_answer) * 3:
        return False

    search_from = 0
    for letter in true_answer:
        found_at = prediction.find(letter, search_from)
        if found_at == -1:
            return False
        # Step past the matched character so that a repeated letter in the
        # answer requires a separate occurrence in the prediction.
        search_from = found_at + 1
    return True
|
||||
|
||||
|
||||
def check_close_call(prediction, true_answer, is_correct):
    """Tell whether a prediction is correct or a near miss of ``true_answer``.

    Parameters:
    - prediction: the predicted answer (converted with ``str()`` for the
      near-miss check)
    - true_answer: the reference answer
    - is_correct: the verdict already computed for this prediction

    Returns:
    - True when ``is_correct`` is truthy;
    - ``is_correct`` unchanged when it is falsy and ``true_answer`` is
      numeric (numeric answers get no near-miss check);
    - otherwise True only when the answer's letters appear in order in the
      prediction and the prediction's length is between half and twice the
      answer's length, else False
    """
    if is_correct:
        return True

    if is_float(true_answer):
        return is_correct

    predicted_text = str(prediction)
    expected_text = str(true_answer)

    if not check_prediction_contains_answer_letters_in_order(predicted_text, expected_text):
        return False

    # The lengths must also be comparable for the prediction to count.
    return len(expected_text) * 0.5 <= len(predicted_text) <= len(expected_text) * 2
|
||||
|
||||
|
||||
def normalize_str(input_str, remove_punct=True) -> str:
    """
    Produce a canonical form of a string for lenient comparison.

    The result is lowercased, has every whitespace character removed and,
    unless disabled, has all ASCII punctuation removed.

    Parameters:
    - input_str: the value to normalize; non-string values are converted
      with ``str()`` first
    - remove_punct: bool, whether to strip punctuation (default: True)

    Returns:
    - str, the normalized string, or ``""`` when the input cannot be
      converted to a string
    """
    # Ensure input is a string
    if not isinstance(input_str, str):
        try:
            input_str = str(input_str)
        except Exception as e:
            warnings.warn(f"Failed to convert input to string: {e}. Type: {type(input_str)}", UserWarning)
            return ""

    # Whitespace is dropped entirely so that "sea gull" equals "seagull".
    collapsed = re.sub(r"\s", "", input_str).lower()

    if not remove_punct:
        return collapsed

    punctuation_table = str.maketrans("", "", string.punctuation)
    return collapsed.translate(punctuation_table)
|
||||
Loading…
Add table
Add a link
Reference in a new issue