import re
import json
import random
class Basebootcamp:
    """
    Base class for bootcamp implementations.

    A bootcamp is a class that contains the logic to verify the solution of a
    task.  Subclasses implement :meth:`prompt_func`, :meth:`extract_output`
    and :meth:`_verify_correction`; :meth:`verify_score` drives the scoring
    pipeline.
    """

    # NOTE(review): the original file contained empty-string literals where
    # reasoning-delimiter tags clearly used to be (every membership test was
    # vacuously true, count("") made format_penalty always return 0, and
    # rsplit("", 1) raised ValueError on the default path).  Reconstructed as
    # the conventional <think>...</think> markers — confirm against callers.
    THINK_START = "<think>"
    THINK_END = "</think>"

    @staticmethod
    def prompt_func(question_ori) -> str:
        """
        Process the input data and return the processed prompt.

        Args:
            question_ori: The question to be processed.

        Returns:
            str: The processed prompt.
        """
        raise NotImplementedError

    @staticmethod
    def extract_output(output):
        """
        Extract the answer from the raw model output.

        Args:
            output: Model output to be processed.

        Returns:
            The processed output, or None when no answer can be extracted.
        """
        raise NotImplementedError

    @classmethod
    def _verify_correction(cls, solution, identity) -> bool:
        """
        Verify the correctness of the extracted solution.

        Args:
            solution: The extracted answer.
            identity: Rules or parameters used in the verification.

        Returns:
            bool (pass/fail) or a numeric partial-credit score.
        """
        raise NotImplementedError

    @classmethod
    def verify_score(cls, model_output, identity: dict, format_score=0,
                     short_penalty=False, short_threshold=256,
                     think_threshold=128, ans_threshold=128,
                     format_penalty=False) -> float:
        """
        Verify the model output against the ground truth.

        Args:
            model_output: The raw model output to be verified.
            identity: Rules or parameters used in the verification.
            format_score: Partial credit granted when an answer could be
                extracted but verification fails or errors out.
            short_penalty: If True, scale the score down when the whole
                output, the answer part, or the reasoning part is shorter
                than its threshold.
            short_threshold: Minimum total output length before scaling.
            think_threshold: Minimum reasoning-part length before scaling.
            ans_threshold: Minimum answer-part length before scaling.
            format_penalty: If True, return 0 when the <think>...</think>
                structure is missing or malformed.

        Returns:
            float: The score of the output (always >= 0).
        """
        score = 0.
        start, end = cls.THINK_START, cls.THINK_END

        if format_penalty:
            # Both delimiter tags must be present...
            if start not in model_output or end not in model_output:
                return score
            # ...exactly once each, the output must begin with the opening
            # tag, and something (the answer) must follow the closing tag.
            malformed = (
                model_output.count(start) > 1
                or model_output.count(end) > 1
                or model_output.count(start) != model_output.count(end)
                or not model_output.startswith(start)
                or model_output.endswith(end)  # should not end with the closing tag
            )
            if malformed:
                return score

        # Keep the name bound even when extract_output raises, so the debug
        # dump below cannot hit a NameError.
        extract_solution = None
        try:
            extract_solution = cls.extract_output(model_output)
            if extract_solution is None:
                return score
            # Grant the format score *before* verification, so a crash inside
            # _verify_correction still leaves the format credit in place.
            score = format_score
            judge = cls._verify_correction(extract_solution, identity)
            if isinstance(judge, bool):
                # bool must be tested before int: bool is a subclass of int.
                if judge:
                    score = 1.
                # a False bool keeps the format score (original behavior)
            elif isinstance(judge, (int, float)):
                score = float(judge)
            # Any other judge type: best-effort, keep the format score
            # (the original swallowed an AssertionError here).
        except Exception:
            # Best-effort scoring: extraction/verification errors leave the
            # score accumulated so far (0 or format_score).
            pass

        # Split the answer off after the last closing tag; everything before
        # it (tags included) counts as reasoning length.
        ans_output = model_output.rsplit(end, 1)[1] if end in model_output else ""
        think_length = len(model_output) - len(ans_output)
        score = max(0, score)  # ensure score is non-negative

        if short_penalty and (len(model_output) < short_threshold
                              or len(ans_output) < ans_threshold
                              or think_length < think_threshold):
            # Output too short: scale the score down proportionally to the
            # most deficient of the three length ratios.
            return min(score * len(model_output) / short_threshold,
                       score * len(ans_output) / ans_threshold,
                       score * think_length / think_threshold)

        # Occasional debug dump during training (~1/1024 of calls).
        if random.randint(1, 1024) == 1:
            print("=============DEBUG=============")
            print("model_output:\n", model_output)
            print("identity:\n", identity)
            print("extract_solution:\n", extract_solution)
            print("score:", score)
            print("===============================")
        return score
# class BaseV2bootcamp(Basebootcamp):
# @classmethod
# def verify_score(cls, model_output, identity: str, format_score=0.1, short_penalty=True, short_threshold=100) -> float:
# """
# Verify the output against the ground truth.
# float: The score of the output.
# """
# score = 0.
# if short_penalty and len(model_output) < short_threshold:
# # if the output is too short, consider it incorrect
# return score
# identity = json.loads(identity)
# try:
# extract_solution = cls.extract_output(model_output)
# except Exception as e:
# # print("Error in verify_score:", e)
# pass
# return score
# return score