This commit is contained in:
lipeiji 2025-06-12 12:45:31 +08:00
parent e5d5e53728
commit b379c541bf
11 changed files with 239 additions and 223 deletions

View file

@ -1,5 +1,6 @@
import re
import json
import random
class Basebootcamp:
@ -46,7 +47,7 @@ class Basebootcamp:
@classmethod
def verify_score(cls, model_output, identity: dict, format_score=0, short_penalty=True, short_threshold=100, format_penalty=True) -> float:
def verify_score(cls, model_output, identity: dict, format_score=0, short_penalty=False, short_threshold=100, format_penalty=False) -> float:
"""
Verify the output against the ground truth.
@ -62,7 +63,10 @@ class Basebootcamp:
if short_penalty and len(model_output) < short_threshold:
# if the output is too short, consider it incorrect
return score
if format_penalty and "</think>" not in model_output:
if format_penalty and ("<think>" not in model_output or "</think>" not in model_output):
return score
if format_penalty and (model_output.count("<think>") > 1 or model_output.count("</think>") > 1 or model_output.count("<think>") != model_output.count("</think>") or not model_output.startswith("<think>") or model_output.endswith("</think>")):
# should not end with </think>
return score
try:
extract_solution = cls.extract_output(model_output)
@ -80,6 +84,13 @@ class Basebootcamp:
except Exception as e:
# print("Error in verify_score:", e)
pass
if random.randint(1,1024) == 1:
print("=============DEBUG=============")
print("model_output:\n", model_output)
print("identity:\n", identity)
print("extract_solution:\n", extract_solution)
print("score:", score)
print("===============================")
return score