# Mirror of https://github.com/InternLM/InternBootcamp.git
# Synced 2026-04-19 12:58:04 +00:00
import re
import json
import random
class Basebootcamp:
    """
    Base class for bootcamp implementations.

    A bootcamp is a class that contains the logic to verify the solution of a
    task. Subclasses implement `prompt_func`, `extract_output`, and
    `_verify_correction`; `verify_score` combines them into a scalar reward.
    """

    @staticmethod
    def prompt_func(question_ori) -> str:
        """
        Process the input data and return the processed prompt.

        Args:
            question_ori: The question to be processed.

        Returns:
            str: The processed prompt.
        """
        raise NotImplementedError

    @staticmethod
    def extract_output(output):
        """
        Extract the answer from the model output.

        Args:
            output: Model output to be processed.

        Returns:
            The processed output (the extracted answer), or None when no
            answer could be found.
        """
        raise NotImplementedError

    @classmethod
    def _verify_correction(cls, solution, identity) -> bool:
        """
        Verify the correctness of an extracted solution.

        Args:
            solution: The answer extracted by `extract_output`.
            identity: Rules or parameters used in the verification.

        Returns:
            bool pass/fail, or a numeric partial score.
        """
        raise NotImplementedError

    @classmethod
    def verify_score(cls, model_output, identity: dict, format_score=0, short_penalty=False, short_threshold=256, ans_threshold=128, format_penalty=False) -> float:
        """
        Verify the output against the ground truth and return a score.

        Args:
            model_output: The model output to be verified.
            identity: Rules or parameters to be used in the verification.
            format_score: Score granted when an answer was extracted but
                turned out to be wrong (or verification errored).
            short_penalty: If True, scale down the score of outputs shorter
                than the thresholds below.
            short_threshold: Minimum acceptable length of the full output.
            ans_threshold: Minimum acceptable length of the text after the
                closing </think> tag.
            format_penalty: If True, require exactly one well-formed
                <think>...</think> pair that opens the output and is followed
                by an answer section (the output must not end at </think>).

        Returns:
            float: The score of the output (0.0 on any format failure).
        """
        score = 0.
        if format_penalty and ("<think>" not in model_output or "</think>" not in model_output):
            return score
        if format_penalty and (
            model_output.count("<think>") > 1
            or model_output.count("</think>") > 1
            or model_output.count("<think>") != model_output.count("</think>")
            or not model_output.startswith("<think>")
            or model_output.endswith("</think>")
        ):
            # should not end with </think> — the answer must follow the tag
            return score

        # Bind up-front so the sampled debug block below never hits a
        # NameError when extract_output raises before assignment.
        extract_solution = None
        try:
            extract_solution = cls.extract_output(model_output)
            if extract_solution is None:
                return score
            # Grant format_score as soon as extraction succeeds: if
            # _verify_correction raises afterwards, the format score sticks.
            score = format_score
            judge = cls._verify_correction(extract_solution, identity)
            if isinstance(judge, bool):
                # A bare True means full credit; False keeps the format score
                # (matches the historical behavior, where False tripped the
                # type assert and the swallowed exception left format_score).
                if judge:
                    score = 1.
            elif isinstance(judge, (int, float)):
                score = float(judge)
            # Any other return type keeps format_score. Previously this was
            # enforced with `assert`, which is stripped under `python -O`.
        except Exception:
            # Best-effort: extraction/verification failures count as
            # incorrect rather than crashing the scoring loop.
            pass

        # Everything after the (last) closing </think> tag is the answer part.
        ans_output = model_output.rsplit("</think>", 1)[1] if "</think>" in model_output else ""

        if (short_penalty and len(model_output) < short_threshold) or (short_penalty and len(ans_output) < ans_threshold):
            # if the output is too short, scale the score down proportionally
            return min(score * len(model_output) / short_threshold, score * len(ans_output) / ans_threshold)

        # Sampled debug logging for training runs (~1 in 1024 calls).
        if random.randint(1, 1024) == 1:
            print("=============DEBUG=============")
            print("model_output:\n", model_output)
            print("identity:\n", identity)
            print("extract_solution:\n", extract_solution)
            print("score:", score)
            print("===============================")

        return score
# class BaseV2bootcamp(Basebootcamp):
|
||
# @classmethod
|
||
# def verify_score(cls, model_output, identity: str, format_score=0.1, short_penalty=True, short_threshold=100) -> float:
|
||
# """
|
||
# Verify the output against the ground truth.
|
||
|
||
# float: The score of the output.
|
||
# """
|
||
# score = 0.
|
||
# if short_penalty and len(model_output) < short_threshold:
|
||
# # if the output is too short, consider it incorrect
|
||
# return score
|
||
# identity = json.loads(identity)
|
||
# try:
|
||
# extract_solution = cls.extract_output(model_output)
|
||
# except Exception as e:
|
||
# # print("Error in verify_score:", e)
|
||
# pass
|
||
# return score
|
||
# return score
|