init-commit

This commit is contained in:
lilinyang 2025-05-23 15:27:15 +08:00
commit 18a552597a
3461 changed files with 1150579 additions and 0 deletions

View file

@ -0,0 +1,19 @@
# Copyright (c) InternLM. All rights reserved.
from .base_judger import (
BaseJudger,
register_judger,
registered_judgers,
)
from .math_judger import MathJudger
from .router import InputData, ParallelRouter
from .bootcamp_judger import bootcampJudger
# Public API of the judger package, re-exported from the submodules above.
__all__ = [
    "register_judger",
    "registered_judgers",
    "BaseJudger",
    "MathJudger",
    "InputData",
    "ParallelRouter",
    "bootcampJudger",
]

View file

@ -0,0 +1,61 @@
# Copyright (c) InternLM. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
Dict,
Generic,
List,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
# Generic payload type carried by JudgeStatus.handle.
T = TypeVar("T")
# One chat turn: {"role": ..., "content": ...}.
MessageItem = TypedDict("MessageItem", {"role": str, "content": str})
# A judger verdict: scalar score, per-sample scores, or None when unavailable.
Reward = Union[float, List[float], None]
# Minimal metadata contract: each sample names its originating data source.
MetaData = TypedDict("MetaData", {"data_source": str})
@dataclass
class JudgeStatus(Generic[T]):
    """Outcome of BaseJudger.on_data_received, consumed later by on_reward_required."""

    ok: bool = True  # whether the sample was accepted for judging
    reason: Optional[str] = None  # optional explanation (e.g. when ok is False)
    handle: Optional[T] = None  # judger-specific payload forwarded to on_reward_required
class BaseJudger(ABC):
    """Abstract two-phase judger: first receive a sample, then emit a reward."""

    def __init__(self):
        pass

    @abstractmethod
    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        """Ingest one sample; the returned status's ``handle`` carries whatever
        ``on_reward_required`` will need later."""
        raise NotImplementedError()

    @abstractmethod
    def on_reward_required(
        self,
        status: JudgeStatus,
        timeout: Optional[float] = None,
    ) -> Reward:
        """Convert a previously returned JudgeStatus into a Reward."""
        raise NotImplementedError()
# Global registry mapping judger type names to their classes. The annotation
# is kept as a string so it is not evaluated at assignment time.
registered_judgers: "Dict[str, Type[BaseJudger]]" = {}


def register_judger(name: str):
    """Return a class decorator that registers a judger class under ``name``.

    Raises:
        KeyError: if ``name`` is already registered. A real ``raise`` is used
            instead of ``assert`` so the duplicate check survives ``python -O``.
    """

    def wrapper(cls):
        if name in registered_judgers:
            raise KeyError(f"{name} already in {registered_judgers}")
        registered_judgers[name] = cls
        return cls

    return wrapper

View file

@ -0,0 +1,81 @@
# Copyright (c) InternLM. All rights reserved.
import random
import re
import time
from typing import List, Optional, Tuple
import asyncio
from concurrent.futures import ThreadPoolExecutor
import requests
import internbootcamp
from .base_judger import BaseJudger, JudgeStatus, MessageItem, Reward, register_judger
@register_judger("bootcamp_judger")
class bootcampJudger(BaseJudger):
    """Judger scoring responses with ``internbootcamp`` verifier classes.

    The verifier class is resolved at reward time from
    ``metadata["data_source"]`` (e.g. ``"foo"`` -> ``internbootcamp.Foobootcamp``).
    """

    def __init__(
        self,
        stop_word="<|im_end|>",
        format_score=0,
        format_penalty=True,
        short_penalty=True,
        short_threshold=128,
    ):
        super().__init__()
        self.stop_word = stop_word  # a complete completion must end with this token
        self.format_score = format_score
        self.format_penalty = format_penalty
        self.short_penalty = short_penalty
        self.short_threshold = short_threshold

    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,  # dataset-specific fields; callers may store anything here
    ) -> JudgeStatus:
        """Stash the sample into the status handle; truncated responses are
        marked incorrect immediately (verify_label=False)."""
        question = prompt_messages[-1]["content"]
        response = completion_messages[-1]["content"]
        identity = metadata["ground_truth"]
        data_source = metadata["data_source"]
        verify_label = None
        if not response.strip().endswith(self.stop_word):
            # The response does not end with the stop word, so it is not a
            # complete response; treat it as incorrect without verifying.
            verify_label = False
        return JudgeStatus(
            ok=True,
            handle={
                "data_source": data_source,
                "question": question,
                "response": response,
                "identity": identity,
                "verify_label": verify_label,
            },
        )

    def on_reward_required(
        self, status: JudgeStatus, timeout: Optional[float] = None
    ) -> Reward:
        """Convert the judger's verdict into a reward score."""
        if status.handle["verify_label"] is False:
            # Already marked incomplete in on_data_received.
            return 0.0
        data_source = status.handle["data_source"]
        response = status.handle["response"]
        identity = status.handle["identity"]
        bootcamp_cls = getattr(
            internbootcamp, data_source[0].upper() + data_source[1:] + "bootcamp"
        )
        try:
            score = bootcamp_cls.verify_score(
                response,
                identity,
                format_score=self.format_score,
                format_penalty=self.format_penalty,
                short_penalty=self.short_penalty,
                short_threshold=self.short_threshold,
            )
        except Exception:
            # Deliberate best-effort fallback: some bootcamp classes only
            # accept the basic signature. (Was a bare `except:`; narrowed so
            # KeyboardInterrupt/SystemExit are no longer swallowed.)
            score = bootcamp_cls.verify_score(
                response, identity, format_score=self.format_score
            )
        return score

View file

@ -0,0 +1,198 @@
# Copyright (c) InternLM. All rights reserved.
import random
import re
import time
from typing import List, Optional, Tuple
import requests
from .base_judger import BaseJudger, JudgeStatus, MessageItem, Reward, register_judger
from .utils import extract_answer, math_equal
@register_judger("math_judger")
class MathJudger(BaseJudger):
    """Judger for math answers: rule-based verification (extract + math_equal)
    first, falling back to an LLM grader served behind OpenAI-compatible
    ``hosts`` when the rule-based check does not confirm correctness."""

    verify_prompt = """You are a helpful assistant who evaluates the correctness and quality of models' outputs.
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.
Here are some evaluation criteria:
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.
5. If the prediction is given with \\boxed{{}}, please ignore the \\boxed{{}} and only judge whether the candidate's answer is consistent with the standard answer.
Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
Just return the letters \"A\" or \"B\", with no text around it.
Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
<Original Question Begin>:
{question}
<Original Question End>
<Gold Target Begin>:
{gold_answer}
<Gold Target End>
<Predicted Answer Begin>:
{answer}
<Predicted End>
Judging the correctness of candidates' answers:"""

    def __init__(
        self,
        hosts: List[str],
        max_retries: int = 1,
        retry_delay: float = 1.0,
        stop_word="<|im_end|>",
        thinking_finish_words=None,
    ):
        """Args:
            hosts: "host:port" list of OpenAI-compatible ``/v1`` servers.
            max_retries: attempts per LLM evaluation before giving up.
            retry_delay: seconds to sleep between retries.
            stop_word: token a complete response must end with.
            thinking_finish_words: markers separating reasoning from the final
                answer; only text after the last marker found is judged.
                Defaults to ["<conclude>", "**Final Answer**", "</think>"].
        """
        super().__init__()
        if thinking_finish_words is None:
            # Build a fresh list per instance to avoid the shared
            # mutable-default-argument pitfall.
            thinking_finish_words = ["<conclude>", "**Final Answer**", "</think>"]
        self.hosts = hosts
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.stop_word = stop_word
        self.thinking_finish_words = thinking_finish_words
        # Round-robin cursor over hosts, randomly initialized to spread load.
        self.host_ip_idx = random.randint(0, len(hosts) - 1)
        # Query one server for the model name (first model it lists).
        self.model_name = requests.get(
            f"http://{self.hosts[self.host_ip_idx]}/v1/models",
            headers={"Authorization": "Bearer "},
        ).json()["data"][0]["id"]

    @staticmethod
    def _make_status(question, question_type, response, gold_answer, verify_label):
        # Single place to build the (always ok=True) status handle.
        return JudgeStatus(
            ok=True,
            handle={
                "question": question,
                "question_type": question_type,
                "response": response,
                "gold_answer": gold_answer,
                "verify_label": verify_label,
            },
        )

    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        """Verify one sample; the verdict is stored in handle["verify_label"]
        (True/False, or None when the LLM grader failed)."""
        question = prompt_messages[-1]["content"]
        response = completion_messages[-1]["content"]
        question_type = metadata.get("question_type", None)
        gold_answer = metadata["gold_answer"]
        if not response.strip().endswith(self.stop_word):
            # Incomplete response (no stop word): mark incorrect outright.
            return self._make_status(
                question, question_type, response, gold_answer, False
            )
        for thinking_finish_word in self.thinking_finish_words:
            if thinking_finish_word in response:
                # Keep only the text after the reasoning marker.
                response = response.split(thinking_finish_word)[-1]
        response = response.replace(self.stop_word, "")
        # First try to extract and verify with rule; if correct, return.
        _, verify_label = self._extract_and_verify_with_logic(response, gold_answer)
        if verify_label is True:
            return self._make_status(
                question, question_type, response, gold_answer, verify_label
            )
        # Then try to evaluate with the grader model (may yield None on failure).
        _, verify_label = self._evaluate_answer_with_llm(
            question, question_type, response, gold_answer
        )
        return self._make_status(
            question, question_type, response, gold_answer, verify_label
        )

    def on_reward_required(
        self, status: JudgeStatus, timeout: Optional[float] = None
    ) -> Reward:
        """Map the stored verdict to a reward: 1.0 correct, -1.0 incorrect,
        None when no verdict was obtained."""
        if status.handle is None:
            return None
        if status.handle["verify_label"] is not None:
            return 1.0 if status.handle["verify_label"] else -1.0
        return None

    def _evaluate_answer_with_llm(
        self, question: str, question_type: str, answer: str, gold_answer: str
    ) -> Tuple[str, bool]:
        """Ask the grader model; returns (raw reply, verdict), or (None, None)
        after exhausting retries."""
        for _ in range(self.max_retries):
            # Round-robin over the configured hosts.
            host = self.hosts[self.host_ip_idx]
            self.host_ip_idx = (self.host_ip_idx + 1) % len(self.hosts)
            # (Removed two dead positional "" arguments: the template has no
            # positional placeholders, so str.format silently ignored them.)
            prompt = self.verify_prompt.format(
                question=question, answer=answer, gold_answer=gold_answer
            )
            try:
                res = requests.post(
                    f"http://{host}/v1/chat/completions",
                    json={
                        "model": self.model_name,
                        "messages": [
                            {
                                "role": "user",
                                "content": prompt,
                            }
                        ],
                        "temperature": 0.0,
                        "top_p": 0.8,
                        "top_k": 20,
                        "repetition_penalty": 1.05,
                        "max_tokens": 100,
                        "stop": ["<|im_end|>", "<|endoftext|>"],
                    },
                )
                res_string = res.json()["choices"][0]["message"]["content"]
                print(f"Evaluate result: {res_string}")
                verify_label = self._verify_from_string(res_string)
                if verify_label is None:
                    raise ValueError(
                        f"Evaluate result is None, judger prediction: {res_string}"
                    )
                return res_string, verify_label
            except Exception as e:
                print(f"Error verifying answer: {e}")
                time.sleep(self.retry_delay)
                continue
        print(f"Failed to verify answer after {self.max_retries} retries.")
        return None, None

    def _verify_from_string(self, verification: str):
        """Interpret the grader reply: 'A' -> True, 'B' -> False,
        ambiguous/neither -> None."""
        if "A" in verification and "B" not in verification:
            label = True
        elif "B" in verification and "A" not in verification:
            label = False
        else:  # judger model failed to predict A or B
            label = None
        return label

    def _extract_and_verify_with_logic(
        self, response: str, gold_answer: str
    ) -> Tuple[str, bool]:
        """Rule-based check: extract the answer and compare with math_equal."""
        extracted_answer = extract_answer(response)
        verify_label = math_equal(extracted_answer, gold_answer)
        return extracted_answer, verify_label

View file

@ -0,0 +1,473 @@
# Copyright (c) InternLM. All rights reserved.
import atexit
import functools
import os
import queue
import time
import traceback
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, connection
from multiprocessing.synchronize import Event as EventClass
from typing import (
Callable,
Dict,
Generic,
List,
Optional,
Tuple,
TypedDict,
TypeVar,
cast,
)
from uuid import uuid4
import loguru
from typing_extensions import NotRequired
from .base_judger import (
JudgeStatus,
MessageItem,
MetaData,
Reward,
registered_judgers,
)
class InputData(TypedDict):
    """One sample submitted to the router."""

    prompt_messages: List[MessageItem]  # conversation up to the model turn
    completion_messages: List[MessageItem]  # the model's completion turns
    metadata: NotRequired[MetaData]  # must include "data_source" to be routed
# Payload type carried by a task through the router queues.
T = TypeVar("T")


@dataclass
class GenericTask(Generic[T]):
    """A unit of work routed between the main process and judger workers."""

    token: str  # batch token returned by submit()
    index: int  # position of the sample within its batch
    judger: str  # name of the judger that should handle it
    content: T  # stage-dependent payload: InputData -> JudgeStatus -> Reward
@dataclass
class SubprocessConfig:
    """Options applied inside each judger worker process."""

    loguru_handlers: Optional[List[dict]] = None  # re-added with enqueue=True in workers
    worker_init_func: Optional[Callable] = None  # called once at worker start-up
class ParallelRouter:
    """Fan samples out to judger worker processes and collect rewards.

    ``data_judger_mapping`` maps a sample's ``metadata["data_source"]`` to the
    judger names that should score it; the reserved name "RM" marks samples to
    be scored locally by the caller. Worker processes are started lazily on
    the first :meth:`submit` call and stopped by :meth:`shutdown` (also
    registered via atexit).
    """

    def __init__(
        self,
        judgers_config: Dict[str, dict],
        data_judger_mapping: Dict[str, Optional[List[str]]],
        logger: Optional["loguru.Logger"] = None,
        subprocess_config: Optional[SubprocessConfig] = None,
    ):
        # Fall back to a no-op Mock logger when none is supplied.
        if logger is not None:
            self.logger = logger
        else:
            import mock

            self.logger = mock.Mock()
        if subprocess_config is not None:
            self.subprocess_config = subprocess_config
        else:
            self.subprocess_config = SubprocessConfig()
        # Validate judgers_config shape: Dict[str, dict].
        if not (
            isinstance(judgers_config, dict)
            and all(
                isinstance(k, str) and isinstance(v, dict)
                for k, v in judgers_config.items()
            )
        ):
            raise TypeError(
                f"Illegal judgers_config: {judgers_config}\n"
                "Should be Dict[str, dict]"
            )
        # "RM" is reserved to denote local (caller-side) scoring.
        if "RM" in judgers_config.keys():
            raise KeyError(
                f"'RM' is a reserved judger keywork for {self.__class__.__name__}, "
                f"please remove it from judgers_config: {judgers_config}"
            )
        self.judgers_config = judgers_config
        data_judger_mapping: Dict[str, List[str]] = {
            k: v or [] for k, v in data_judger_mapping.items()
        }  # change None to empty list []
        # Validate data_judger_mapping shape: Dict[str, List[str]].
        if not (
            isinstance(data_judger_mapping, dict)
            and all(
                isinstance(k, str)
                and isinstance(v, (list, tuple, set))
                and all(isinstance(vv, str) for vv in v)
                for k, v in data_judger_mapping.items()
            )
        ):
            raise TypeError(
                f"Illegal data_judger_mapping: {data_judger_mapping}\n"
                "Should be Dict[str, List[str]]"
            )
        self.data_judger_mapping = data_judger_mapping
        # Cross-check configured judgers against judgers used by the mapping.
        avail_judgers = set(self.judgers_config.keys()) | {"RM"}
        _used_judgers: List[str] = []
        for v in data_judger_mapping.values():
            _used_judgers.extend(v)
        used_judgers: set = set(_used_judgers)
        if unused := avail_judgers - used_judgers:
            self.logger.warning(
                "Following judgers are available but not "
                f"used in data mapping: {unused}\n"
                "Please make sure this is intended"
            )
            # remove unused configs
            for judger_name in unused:
                self.judgers_config.pop(judger_name, None)
        if missing := used_judgers - avail_judgers:
            self.logger.warning(
                "Following judgers are configured to be used "
                f"but not built in data mapping: {missing}\n"
                "Please make sure this is intended"
            )
            # remove missing judgers from mapping, to prevent potential errors
            for source in list(self.data_judger_mapping.keys()):
                before = set(self.data_judger_mapping[source])
                self.data_judger_mapping[source] = list(before - missing)
        # then filter out data_mapping without available judgers
        self.data_judger_mapping = {
            source: judgers
            for source, judgers in self.data_judger_mapping.items()
            if len(judgers) > 0
        }
        # Try build judgers in __init__ so that raise Exceptions earlly
        for judger_name, judger_conf in self.judgers_config.items():
            _ = self._build_judger(judger_name, judger_conf)
        self._processes: List[Process] = []
        self._stop_event = Event()
        atexit.register(self.shutdown)
        # One input queue per judger; shared output and exception queues.
        self._input_queues: Dict[str, Queue[GenericTask[InputData]]] = {
            judger_name: Queue() for judger_name in self.judgers_config.keys()
        }
        self._output_queue: Queue[GenericTask[Reward]] = Queue()
        self._exc_queue: Queue[Tuple[str, Exception]] = Queue()
        self._num_tasks: Dict[str, int] = {}  # for each token
        self._num_indexes: Dict[str, int] = {}  # for each token
        self._results_buffer: Dict[str, List[GenericTask[Reward]]] = defaultdict(
            list
        )  # results buffer grouped by the key "token"

    def submit(self, data_batch: List[InputData]):
        """Enqueue a batch for external judging.

        Returns ``(token, indexes_for_local)``: ``token`` identifies the batch
        for :meth:`query`; ``indexes_for_local`` lists sample indexes the
        caller must score itself (malformed items, unknown/unmapped sources,
        and "RM"-mapped samples).
        """
        indexes_for_ext: List[int] = []
        indexes_for_local: List[int] = []
        tasks_input: List[GenericTask[InputData]] = []
        token = str(uuid4())
        for index, data_item in enumerate(data_batch):
            # Items missing the required keys are scored locally.
            if (
                not isinstance(data_item, dict)
                or "metadata" not in data_item
                or "prompt_messages" not in data_item
                or "completion_messages" not in data_item
            ):
                indexes_for_local.append(index)
                continue
            source = data_item["metadata"].get("data_source", None)
            if source is None or source not in self.data_judger_mapping:
                indexes_for_local.append(index)
                continue
            indexes_for_ext.append(index)
            # One task per mapped judger; "RM" means "score locally".
            for judger in self.data_judger_mapping[source]:
                if judger == "RM":
                    indexes_for_local.append(index)
                else:
                    tasks_input.append(
                        GenericTask(
                            token=token,
                            index=index,
                            judger=judger,
                            content=data_item,
                        )
                    )
        self._num_tasks[token] = len(tasks_input)
        self._num_indexes[token] = len(data_batch)
        for task in tasks_input:
            self._input_queues[task.judger].put(task, block=True, timeout=1)
        # Lazily start worker processes on the first submit.
        if not self._processes:
            self.logger.debug("Starting processes...")
            for judger_name, judger_conf in self.judgers_config.items():
                # NOTE(review): pop() mutates the stored config, so
                # "num_processes" is consumed here on first start-up.
                num_proc = judger_conf.pop("num_processes", 1)
                self._processes.extend(
                    [
                        Process(
                            target=ParallelRouter._safe_process_worker,
                            kwargs={
                                "stop_event": self._stop_event,
                                "judger_name": judger_name,
                                "judger_conf": judger_conf,
                                "input_queue": self._input_queues[judger_name],
                                "output_queue": self._output_queue,
                                "exc_queue": self._exc_queue,
                                "config": self.subprocess_config,
                            },
                            daemon=True,
                        )
                        for _ in range(num_proc)
                    ]
                )
            for p in self._processes:
                p.start()
            self.logger.debug(f"Start processes done, total {len(self._processes)}")
        return token, indexes_for_local

    def query(
        self, token: str, timeout: float = 0
    ) -> Optional[List[Optional[Dict[str, Reward]]]]:
        """Block until all rewards for ``token`` arrive.

        Returns one ``{judger_name: reward}`` dict per submitted sample, with
        ``None`` for samples no external judger handled. ``timeout <= 0``
        waits forever. Raises TimeoutError on expiry and RuntimeError when a
        worker reported an exception.
        """
        start = time.time()
        while True:
            # Surface worker-side exceptions as a RuntimeError here.
            self._try_catch_subprocess_exceptions()
            try:
                result = self._output_queue.get(timeout=0.1)
                self._results_buffer[result.token].append(result)
            except queue.Empty:
                pass
            if len(self._results_buffer[token]) == self._num_tasks[token]:
                results = self._results_buffer.pop(token)
                num_tasks = self._num_tasks.pop(token)
                num_indexes = self._num_indexes.pop(token)
                rewards: List[Dict[str, Reward]] = [{} for _ in range(num_indexes)]
                for result in results:
                    reward = result.content
                    if result.judger in rewards[result.index]:
                        self.logger.warning(
                            f"{result.judger} already exists: {rewards[result.index]}, "
                            f"will replace --> {reward}"
                        )
                    rewards[result.index][result.judger] = reward
                # convert empty dicts to None
                return [r or None for r in rewards]
            if timeout > 0 and (time.time() - start) > timeout:
                raise TimeoutError(
                    f"Timeout after {timeout} seconds, got {len(self._results_buffer[token])} results, expected {self._num_tasks[token]}"
                )

    @staticmethod
    def _safe_process_worker(
        stop_event: EventClass,
        judger_name: str,
        judger_conf: dict,
        input_queue: "Queue[GenericTask[InputData]]",
        output_queue: "Queue[GenericTask[Reward]]",
        exc_queue: "Queue[Tuple[str, Exception]]",
        config: SubprocessConfig,
    ):
        """Process entry point: run _process_worker and forward any crash to
        the exception queue so the parent can report it."""
        try:
            ParallelRouter._process_worker(
                stop_event=stop_event,
                judger_name=judger_name,
                judger_conf=judger_conf,
                input_queue=input_queue,
                output_queue=output_queue,
                exc_queue=exc_queue,
                config=config,
            )
        except Exception as e:
            exc_queue.put((judger_name, e), timeout=1)

    @staticmethod
    def _process_worker(
        stop_event: EventClass,
        judger_name: str,
        judger_conf: dict,
        input_queue: "Queue[GenericTask[InputData]]",
        output_queue: "Queue[GenericTask[Reward]]",
        exc_queue: "Queue[Tuple[str, Exception]]",
        config: SubprocessConfig,
    ):
        """Worker loop: run a two-stage thread pipeline for one judger.

        Stage 1 threads call ``on_data_received``; stage 2 threads call
        ``on_reward_required``. Threads run until ``stop_event`` is set.
        """
        from xtuner._lite import get_logger

        logger = get_logger()
        if config.loguru_handlers is not None:
            for handler in config.loguru_handlers:
                # enqueue=True makes the handler multiprocessing-safe.
                handler["enqueue"] = True
                logger.add(*handler)
        if config.worker_init_func is not None:
            config.worker_init_func()
        # Infer num threads for each stage according to configs
        _num_threads = judger_conf.pop("concurrency_per_proc", (1, 1))
        if isinstance(_num_threads, (tuple, list)) and len(_num_threads) == 2:
            num_threads_s1, num_threads_s2 = _num_threads
        elif isinstance(_num_threads, int):
            # A single int is split roughly evenly between the two stages.
            num_threads_s1 = max(1, _num_threads // 2)
            num_threads_s2 = max(1, _num_threads - num_threads_s1)
        else:
            raise TypeError(
                "`concurrency_per_proc` in judger_conf should be int or "
                f"Tuple[int, int], got {type(_num_threads)}: {_num_threads}"
            )
        # Lazy build judgers in subprocesses to avoid serialization errors
        judger = ParallelRouter._build_judger(judger_name, judger_conf)
        # input_queue = self._input_queues[judger_name]
        # output_queue = self._output_queue
        handle_queue: queue.Queue[GenericTask[JudgeStatus]] = queue.Queue()
        log_prefix = f"[pid={os.getpid()},{judger_name}]"

        def report_exc_wrapper(func):
            # Catch any exception in a stage thread and forward it to the
            # parent via exc_queue instead of dying silently.
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    stack_trace = traceback.format_exc()
                    # NOTE(review): the second f-string is passed as a
                    # positional argument to logger.error — verify it is
                    # rendered as intended by loguru.
                    logger.error(
                        f"{log_prefix} "
                        f"Thread worker of {judger_name} raised "
                        f"{type(e).__name__}: {e}",
                        f"Stack trace: {stack_trace}",
                    )
                    exc_queue.put((judger_name, e), timeout=1)

            return wrapper

        # Stage 1: input_queue -> judger.on_data_received -> handle_queue
        @report_exc_wrapper
        def thread_worker_s1():
            while not stop_event.is_set():
                try:
                    task = input_queue.get(timeout=0.1)
                    logger.debug(f"{log_prefix} dequeue input: {task}")
                except queue.Empty:
                    logger.debug(f"{log_prefix} input queue empty")
                    time.sleep(0.1)
                    continue
                data = task.content
                if "metadata" not in data:
                    raise RuntimeError(
                        f"'metadata' not in data.keys(): {list(data.keys())}"
                    )
                logger.debug(f"{log_prefix} on_data_received")
                handle = judger.on_data_received(
                    data["prompt_messages"],
                    data["completion_messages"],
                    cast(dict, data["metadata"]),
                )
                logger.debug(f"{log_prefix} got handle")
                new_task = GenericTask(
                    token=task.token,
                    index=task.index,
                    judger=task.judger,
                    content=handle,
                )
                # Retry the put until it succeeds or the worker is stopped is
                # not needed here: loop until enqueued.
                while True:
                    try:
                        handle_queue.put(
                            new_task,
                            timeout=0.1,
                        )
                        logger.debug(f"{log_prefix} enqueue handle: {new_task}")
                        break
                    except queue.Full:
                        time.sleep(0.1)

        # Stage 2: handle_queue -> judger.on_reward_required -> output_queue
        @report_exc_wrapper
        def thread_worker_s2():
            while not stop_event.is_set():
                try:
                    task = handle_queue.get(timeout=0.1)
                    logger.debug(f"{log_prefix} dequeue handle: {task}")
                except queue.Empty:
                    logger.debug(f"{log_prefix} handle queue empty")
                    time.sleep(0.1)
                    continue
                logger.debug(f"{log_prefix} on_reward_required")
                reward = judger.on_reward_required(task.content)
                logger.info(f"{log_prefix} got result")
                new_task = GenericTask(
                    token=task.token,
                    index=task.index,
                    judger=task.judger,
                    content=reward,
                )
                while True:
                    try:
                        output_queue.put(
                            new_task,
                            timeout=0.1,
                        )
                        logger.debug(f"{log_prefix} enqueue output: {new_task}")
                        break
                    except queue.Full:
                        time.sleep(0.1)

        from threading import Thread

        threads: List[Thread] = []
        for _ in range(num_threads_s1):
            threads.append(Thread(target=thread_worker_s1, daemon=True))
        for _ in range(num_threads_s2):
            threads.append(Thread(target=thread_worker_s2, daemon=True))
        for t in threads:
            t.start()
        for t in threads:
            t.join()

    @staticmethod
    def _build_judger(judger_name: str, judger_conf: dict):
        """Instantiate a registered judger from its config dict.

        The "type" key selects the registered class (defaults to the config
        entry's own name); router-level keys are stripped first.
        """
        judger_conf = deepcopy(judger_conf)
        judger_conf.pop("num_processes", None)
        judger_conf.pop("concurrency_per_proc", None)
        _type = judger_conf.pop("type", None)
        if _type is None:
            _type = judger_name
        if _type not in registered_judgers:
            raise KeyError(
                f"{judger_name} use unregistered judger type: {_type}. "
                f"Available judgers are: {list(registered_judgers.keys())}"
            )
        cls = registered_judgers[_type]
        return cls(**judger_conf)

    def _try_catch_subprocess_exceptions(self):
        """Drain the exception queue; raise RuntimeError if anything arrived."""
        exc_handles: List[Tuple[str, Exception]] = []
        while True:
            try:
                exc_handle = self._exc_queue.get(timeout=0.001)
                exc_handles.append(exc_handle)
            except queue.Empty:
                break
        if exc_handles:
            error_message = "\n".join(
                [
                    f"- [{judger_name}] {type(exc).__name__}: {exc}"
                    for judger_name, exc in exc_handles
                ]
            )
            raise RuntimeError(
                "Following threads/processes raise exceptions unexpectedly:\n"
                f"{error_message}\n"
                "Program terminated"
            )

    def shutdown(self, timeout: float = 2.0):
        """Signal workers to stop; kill any that outlive ``timeout`` seconds."""
        if not hasattr(self, "_processes") or not self._processes:
            return
        if not self._stop_event.is_set():
            self._stop_event.set()
        connection.wait([p.sentinel for p in self._processes], timeout=timeout)
        for p in self._processes:
            if p.is_alive():
                p.kill()
            p.join()
        self._processes = []

View file

@ -0,0 +1,485 @@
# flake8: noqa
# isort: skip_file
import multiprocessing
import re
from math import isclose
from typing import Optional, Union
from collections import defaultdict, Counter
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
def extract_answer(pred_str: str, execute: bool = False) -> str:
    """Extract the final answer substring from a model prediction.

    NOTE(review): this definition is shadowed by a second ``extract_answer``
    defined later in this module, so this version is effectively dead code.
    """
    # NOTE(review): in these non-raw patterns "\\boxed"/"\\box" reach the regex
    # engine as "\b..." (word boundary), not a literal backslash; the plain
    # "boxed"/"box" alternatives do the actual matching.
    if re.search("\\boxed|boxed|\\box|box", pred_str):
        answer = re.split("\\boxed|boxed|\\box|box", pred_str)[-1]
        if len(answer) == 0:
            return ""
        elif answer[0] == "{":
            # Walk the braces to collect the balanced {...} group.
            stack = 1
            a = ""
            for c in answer[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            # Unbraced answer: take up to the first "$".
            a = answer.split("$")[0].strip()
    elif re.search("[Tt]he (final )?answer is:?", pred_str):
        a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
    else:  # use the last number
        pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        if len(pred) >= 1:
            a = pred[-1]
        else:
            a = ""
    # Multiple-choice forms "A: ..." or "(A) ..." -> keep only the option letter.
    choice = re.findall(r"([A-E]):\s*(.*)", a)
    if len(choice) > 0:
        for option, content in choice:
            a = option
    choice = re.findall(r"\(([A-E])\)\s*(.*)", a)
    if len(choice) > 0:
        for option, content in choice:
            a = option
    # Keep only the right-hand side of "=", "\approx" or "≈".
    a = re.split(r"=|\\approx|≈", a)[-1]
    # multiple lines: take the first non-trivial normalized line
    answer = ""
    preds = re.split("\n", a)
    for pred in preds:
        if "\\begin{align" in pred or pred.endswith(":"):
            continue
        if pred != "" and pred[0] == ":":
            pred = pred[1:]
        if pred != "" and pred[-1] == ".":
            pred = pred[:-1]
        if pred != "" and pred[-1] == "/":
            pred = pred[:-1]
        pred = strip_string(pred)
        # Drop leading enumeration markers like "1) " or "a) ".
        pred = re.sub(r"^[a-zA-Z0-9]+[\)]\s*", "", pred)
        for p in pred.split("{}"):
            if p != "":
                pred = p
                break
        # Unwrap "{A}" or "(A)" option letters.
        pred = re.sub(r"^\{([A-Z])\}|\(([A-Z])\)", r"\1\2", pred)
        if pred != "":
            answer = pred
            break
    return answer
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except Exception:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
if "sqrt" not in a:
a = int(a)
if "sqrt" not in b:
b = int(b)
assert string == f"{a}/{b}"
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except Exception:
return string
def _fix_sqrt(string):
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
return _string
def strip_string(string):
    """Normalize a LaTeX/plain math answer string for comparison.

    Strips formatting noise (spaces, units, dollar signs, percent signs,
    degree markers, quotes, ...), canonicalizes fractions and square roots,
    and drops a short ``lhs =`` prefix.
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")
    # right "."
    string = string.rstrip(".")
    # remove inverse spaces
    string = string.replace("\\!", "")
    string = string.replace("\\ ", "")
    # replace \\ with \
    string = string.replace("\\\\", "\\")
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove unit: miles, dollars if after is not none
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\text", "")
    string = string.replace("x\\in", "")
    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")
    string = string.replace("%", "")
    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # cdot
    string = string.replace("\\cdot", "")
    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\inity", "\\infty")
    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")
    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)
    # quote — BUG FIX: str.replace returns a new string; the results were
    # previously discarded, so quotes were never actually removed.
    string = string.replace("'", "")
    string = string.replace('"', "")
    # i, j
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")
    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string
    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]
    string = _fix_sqrt(string)
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
    string = _fix_fracs(string)
    # X/Y --> \frac{X}{Y} for simple model outputs
    string = _fix_a_slash_b(string)
    return string
def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` (or ``\\fbox{...}``) span, or None."""
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None
    # Scan forward, tracking brace depth until the group closes.
    depth = 0
    end = None
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                end = pos
                break
    if end is None:
        return None
    return string[start : end + 1]
def extract_answer(pred_str: str, execute: bool = False) -> str:
    """Extract and normalize the final answer from a prediction string.

    This definition supersedes the earlier ``extract_answer`` in this module
    (Python keeps the last definition of a name).
    """
    # BUG FIX: the pattern was written as "\boxed|boxed" in a non-raw string,
    # where "\b" is a backspace escape (\x08), so the first alternative could
    # never match. Use a raw string so "\\boxed" matches a literal backslash.
    if re.search(r"\\boxed|boxed", pred_str):
        answer = re.split(r"\\boxed|boxed", pred_str)[-1]
        if len(answer) == 0:
            return ""
        elif answer[0] == "{":
            # Walk the braces to collect the balanced {...} group.
            stack = 1
            a = ""
            for c in answer[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            a = answer.split("$")[0].strip()
    elif re.search("[Tt]he (final )?answer is:?", pred_str):
        a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
    elif pred_str.startswith("```python") and execute:
        # fall back to program execution to obtain the answer value
        from lagent import get_tool

        a = get_tool("IPythonInteractive").exec(pred_str).value or ""
    else:  # use the last number
        pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        if len(pred) >= 1:
            a = pred[-1]
        else:
            a = ""
    # multiple lines: keep the first line, trim punctuation, then normalize
    pred = a.split("\n")[0]
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred)
    return pred
def is_digit(s):
    """Return True when ``s`` parses as a float once commas are stripped."""
    text = str(s).replace(",", "")
    try:
        float(text)
    except ValueError:
        return False
    return True
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    tolerance: float = 1e-4,
    timeout: bool = False,
) -> bool:
    """Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Args:
        include_percentage: also accept ref/100 and ref*100 as matches.
        is_close: compare floats with ``math.isclose`` at ``tolerance``.
        timeout: run the sympy comparison in a subprocess with a time limit.
    """
    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = float(str(prediction).replace(",", ""))
            reference = float(str(reference).replace(",", ""))
            # number questions: optionally accept percentage variants
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, rel_tol=tolerance):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except Exception:
        pass
    # Empty/None predictions (but not 0/False) can never match.
    if not prediction and prediction not in [0, False]:
        return False
    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()
    # deal with [], (), {}: compare with all brackets stripped
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True
    # [a, b] vs. [c, d]: compare element-wise
    if (
        (prediction.startswith("[") and prediction.endswith("]"))
        and (reference.startswith("[") and reference.endswith("]"))
        or (prediction.startswith("(") and prediction.endswith(")"))
        and (reference.startswith("(") and reference.endswith(")"))
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(
                [
                    # FIX: propagate the caller's tolerance to the recursive
                    # element-wise comparison (it was silently dropped before).
                    math_equal(
                        pred_parts[i],
                        ref_parts[i],
                        include_percentage,
                        is_close,
                        tolerance,
                    )
                    for i in range(len(pred_parts))
                ]
            ):
                return True
    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True
    return False
def math_equal_process(param):
    """Pool-friendly wrapper: compare the last two fields of ``param``.

    A ``None`` prediction short-circuits to False. (A duplicate definition
    without this guard previously preceded — and was shadowed by — this one;
    it has been removed.)
    """
    if param[-2] is None:
        return False
    return math_equal(param[-2], param[-1])
def symbolic_equal(a, b):
    """Compare two answers symbolically with sympy.

    Each side is parsed as LaTeX first, then as a plain sympy expression,
    falling back to the raw string. Equality holds when the difference
    simplifies to 0 or the numeric evaluations are close (rel_tol=1e-3).
    """

    def _to_expr(text):
        for parser in (parse_latex, parse_expr):
            try:
                return parser(text)
            except Exception:
                continue
        return text

    lhs = _to_expr(a)
    rhs = _to_expr(b)
    try:
        if simplify(lhs - rhs) == 0:
            return True
    except Exception:
        pass
    try:
        if isclose(N(lhs), N(rhs), rel_tol=1e-3):
            return True
    except Exception:
        pass
    return False
def symbolic_equal_process(a, b, output_queue):
    """Subprocess entry point: run symbolic_equal and ship the verdict back."""
    output_queue.put(symbolic_equal(a, b))
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func(*args, output_queue, **kwargs)`` in a subprocess with a time limit.

    Returns False when the call times out or the worker exits without
    producing a result; otherwise returns whatever the worker put on the queue.
    """
    import queue as _queue  # local import: only needed for the Empty exception

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        process.join()
        return False
    try:
        # BUG FIX: a bare get() would block forever if the worker exited
        # (e.g. crashed) without putting anything on the queue.
        return output_queue.get(timeout=1)
    except _queue.Empty:
        return False
def math_majority_vote(answers: list, majority: Optional[int] = None):
    """Majority vote over answer strings, grouping math-equal answers.

    Non-string or blank answers are ignored. Returns ``(winner, indexes)``
    when the most common group is truthy and has at least ``majority``
    (default 1) votes; otherwise ``(None, [])``.
    """
    counts, positions = Counter(), defaultdict(list)
    for idx, candidate in enumerate(answers):
        if not (isinstance(candidate, str) and candidate.strip()):
            continue
        # Fold the candidate into the first existing group it matches,
        # otherwise open a new group keyed by the candidate itself.
        matched = next((key for key in counts if math_equal(candidate, key)), None)
        bucket = candidate if matched is None else matched
        counts[bucket] += 1
        positions[bucket].append(idx)
    if counts:
        winner, votes = counts.most_common(1)[0]
        if winner and votes >= (majority or 1):
            return winner, positions[winner]
    return None, []