InternBootcamp/internbootcamp/bootcamp/base.py
2025-06-16 10:33:07 +08:00

123 lines
4.2 KiB
Python
Executable file
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import json
import random
class Basebootcamp:
    """
    Base class for bootcamp implementations.

    A bootcamp encapsulates the logic to build a prompt for a task and to
    verify / score a model's solution.  Subclasses must implement
    ``prompt_func``, ``extract_output`` and ``_verify_correction``.
    """

    @staticmethod
    def prompt_func(question_ori) -> str:
        """
        Process the input_data and return the processed prompt.

        Args:
            question_ori: The question to be processed.

        Returns:
            str: The processed prompt.
        """
        raise NotImplementedError

    @staticmethod
    def extract_output(output):
        """
        Extract the candidate solution from the raw model output.

        Args:
            output: Model output to be processed.

        Returns:
            The extracted solution, or None when nothing could be extracted.
        """
        raise NotImplementedError

    @classmethod
    def _verify_correction(cls, solution, identity) -> bool:
        """
        Verify the correction of the solution.

        Implementations may return a bool (fully correct / incorrect) or a
        numeric partial-credit score (``verify_score`` accepts int/float too).
        """
        raise NotImplementedError

    @classmethod
    def verify_score(cls, model_output, identity: dict, format_score=0,
                     short_penalty=False, short_threshold=256,
                     ans_threshold=128, format_penalty=False) -> float:
        """
        Verify the output against the ground truth and return a score.

        Args:
            model_output: Raw model output, possibly containing a
                ``<think>...</think>`` reasoning section.
            identity: Rules or parameters used by ``_verify_correction``.
            format_score: Score granted when a solution was extracted but
                judged incorrect (credit for well-formed output).
            short_penalty: When True, scale the score down if the whole
                output or the post-``</think>`` answer is too short.
            short_threshold: Length below which the full output counts as
                "too short".
            ans_threshold: Length below which the post-``</think>`` answer
                counts as "too short".
            format_penalty: When True, require exactly one matched
                ``<think>``/``</think>`` pair, starting the output, with a
                non-empty answer after it.

        Returns:
            float: The score of the output (0.0 when verification fails).
        """
        score = 0.
        if format_penalty and ("<think>" not in model_output or "</think>" not in model_output):
            return score
        # Exactly one matched <think>/</think> pair; the output must start
        # with <think> and must NOT end with </think> (an answer has to
        # follow the reasoning section).
        if format_penalty and (
            model_output.count("<think>") > 1
            or model_output.count("</think>") > 1
            or model_output.count("<think>") != model_output.count("</think>")
            or not model_output.startswith("<think>")
            or model_output.endswith("</think>")
        ):
            return score
        # Bind before the try so the debug block below can never hit a
        # NameError when extract_output raises.
        extract_solution = None
        try:
            extract_solution = cls.extract_output(model_output)
            if extract_solution is None:
                return score
            # Assign format_score *before* verification so that an exception
            # raised inside _verify_correction still leaves the format
            # credit in place.
            score = format_score
            judge = cls._verify_correction(extract_solution, identity)
            if isinstance(judge, bool):
                if judge:
                    score = 1.
            elif isinstance(judge, (int, float)):
                score = float(judge)
            # Any other judge type keeps the format_score already assigned
            # (matches the original behavior of swallowing the failed
            # type assertion).
        except Exception:
            # Best-effort scoring: any failure in extraction or verification
            # leaves the score accumulated so far.
            pass
        ans_output = model_output.rsplit("</think>", 1)[1] if "</think>" in model_output else ""
        if short_penalty and (len(model_output) < short_threshold or len(ans_output) < ans_threshold):
            # If the output is too short, scale the score down proportionally.
            return min(score * len(model_output) / short_threshold,
                       score * len(ans_output) / ans_threshold)
        # Randomly sampled debug logging for training runs.
        if random.randint(1, 1024) == 1:
            print("=============DEBUG=============")
            print("model_output:\n", model_output)
            print("identity:\n", identity)
            print("extract_solution:\n", extract_solution)
            print("score:", score)
            print("===============================")
        return score
# class BaseV2bootcamp(Basebootcamp):
# @classmethod
# def verify_score(cls, model_output, identity: str, format_score=0.1, short_penalty=True, short_threshold=100) -> float:
# """
# Verify the output against the ground truth.
# float: The score of the output.
# """
# score = 0.
# if short_penalty and len(model_output) < short_threshold:
# # if the output is too short, consider it incorrect
# return score
# identity = json.loads(identity)
# try:
# extract_solution = cls.extract_output(model_output)
# except Exception as e:
# # print("Error in verify_score:", e)
# pass
# return score
# return score