mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-22 16:49:04 +00:00
init-commit
This commit is contained in:
commit
18a552597a
3461 changed files with 1150579 additions and 0 deletions
107
internbootcamp/bootcamp/base.py
Executable file
107
internbootcamp/bootcamp/base.py
Executable file
|
|
@ -0,0 +1,107 @@
|
|||
import re
|
||||
import json
|
||||
|
||||
|
||||
class Basebootcamp:
    """
    Base class for bootcamp implementations.

    A bootcamp is a class that contains the logic to verify the solution of a task.
    The three hooks below raise ``NotImplementedError`` here; ``verify_score``
    combines ``extract_output`` and ``_verify_correction`` into a single score.
    """

    @staticmethod
    def prompt_func(question_ori) -> str:
        """
        Process the input_data and return the processed prompt.

        Args:
            question_ori: The question to be processed.

        Returns:
            str: The processed prompt.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def extract_output(output):
        """
        Extract the output from the solution.

        Args:
            output: Model output to be processed.

        Returns:
            The processed output. ``verify_score`` treats a ``None`` return
            value as "no solution found".

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def _verify_correction(cls, solution, identity) -> bool:
        """
        Verify the correction of the solution.

        Args:
            solution: The value returned by ``extract_output``.
            identity: Some rules or parameters to be used in the verification.

        Returns:
            The verdict. ``verify_score`` accepts a ``bool`` (``True`` scores
            1.0) or an ``int``/``float`` that is used directly as the score.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def verify_score(cls, model_output, identity: dict, format_score=0, short_penalty=True, short_threshold=100, format_penalty=True) -> float:
        """
        Verify the output against the ground truth.

        Args:
            model_output: The model output to be verified.
            identity: Some rules or parameters to be used in the verification.
            format_score: Score returned when a solution could be extracted but
                was not judged correct (verdict ``False``, verdict of an
                unsupported type, or the verifier raised).
            short_penalty: If true, outputs shorter than ``short_threshold``
                score 0.
            short_threshold: Minimum ``len(model_output)`` required when
                ``short_penalty`` is enabled.
            format_penalty: If true, outputs that do not contain ``"</think>"``
                score 0.

        Returns:
            float: The score of the output. 0.0 when the output is rejected up
            front or no solution can be extracted, 1.0 for a ``True`` verdict,
            ``float(verdict)`` for a numeric verdict, and ``format_score``
            (returned as passed in) otherwise.
        """
        score = 0.
        if model_output is None:
            # nothing to grade; avoids a TypeError from len() / `in` below
            return score
        if short_penalty and len(model_output) < short_threshold:
            # if the output is too short, consider it incorrect
            return score
        if format_penalty and "</think>" not in model_output:
            return score
        try:
            extract_solution = cls.extract_output(model_output)
            if extract_solution is None:
                return score
            # format_score must be assigned here! Otherwise, if
            # _verify_correction raises, format_score would never be awarded.
            score = format_score
            judge = cls._verify_correction(extract_solution, identity)
            # bool is tested first because bool is a subclass of int.
            if isinstance(judge, bool):
                if judge:
                    score = 1.
                # a False verdict keeps format_score; this no longer depends on
                # an assert, which would be stripped under `python -O`
            elif isinstance(judge, (int, float)):
                score = float(judge)
            # any other verdict type keeps format_score
        except Exception:
            # Deliberate best-effort: a failing extractor or verifier must not
            # crash scoring; keep whatever score was reached so far.
            pass
        return score
|
||||
|
||||
|
||||
|
||||
|
||||
# class BaseV2bootcamp(Basebootcamp):
|
||||
# @classmethod
|
||||
# def verify_score(cls, model_output, identity: str, format_score=0.1, short_penalty=True, short_threshold=100) -> float:
|
||||
# """
|
||||
# Verify the output against the ground truth.
|
||||
|
||||
# float: The score of the output.
|
||||
# """
|
||||
# score = 0.
|
||||
# if short_penalty and len(model_output) < short_threshold:
|
||||
# # if the output is too short, consider it incorrect
|
||||
# return score
|
||||
# identity = json.loads(identity)
|
||||
# try:
|
||||
# extract_solution = cls.extract_output(model_output)
|
||||
# except Exception as e:
|
||||
# # print("Error in verify_score:", e)
|
||||
# pass
|
||||
# return score
|
||||
# return score
|
||||
Loading…
Add table
Add a link
Reference in a new issue