mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
init-commit
This commit is contained in:
commit
18a552597a
3461 changed files with 1150579 additions and 0 deletions
485
examples/xpuyu_usage/bootcamp_rl/judgers/utils.py
Executable file
485
examples/xpuyu_usage/bootcamp_rl/judgers/utils.py
Executable file
|
|
@ -0,0 +1,485 @@
|
|||
# flake8: noqa
|
||||
# isort: skip_file
|
||||
|
||||
import multiprocessing
|
||||
import re
|
||||
from math import isclose
|
||||
from typing import Optional, Union
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
|
||||
def extract_answer(pred_str: str, execute: bool = False) -> str:
|
||||
if re.search("\\boxed|boxed|\\box|box", pred_str):
|
||||
answer = re.split("\\boxed|boxed|\\box|box", pred_str)[-1]
|
||||
if len(answer) == 0:
|
||||
return ""
|
||||
elif answer[0] == "{":
|
||||
stack = 1
|
||||
a = ""
|
||||
for c in answer[1:]:
|
||||
if c == "{":
|
||||
stack += 1
|
||||
a += c
|
||||
elif c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
break
|
||||
a += c
|
||||
else:
|
||||
a += c
|
||||
else:
|
||||
a = answer.split("$")[0].strip()
|
||||
elif re.search("[Tt]he (final )?answer is:?", pred_str):
|
||||
a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
|
||||
else: # use the last number
|
||||
pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
|
||||
if len(pred) >= 1:
|
||||
a = pred[-1]
|
||||
else:
|
||||
a = ""
|
||||
choice = re.findall(r"([A-E]):\s*(.*)", a)
|
||||
if len(choice) > 0:
|
||||
for option, content in choice:
|
||||
a = option
|
||||
choice = re.findall(r"\(([A-E])\)\s*(.*)", a)
|
||||
if len(choice) > 0:
|
||||
for option, content in choice:
|
||||
a = option
|
||||
|
||||
a = re.split(r"=|\\approx|≈", a)[-1]
|
||||
|
||||
# multiple lines
|
||||
answer = ""
|
||||
preds = re.split("\n", a)
|
||||
for pred in preds:
|
||||
if "\\begin{align" in pred or pred.endswith(":"):
|
||||
continue
|
||||
if pred != "" and pred[0] == ":":
|
||||
pred = pred[1:]
|
||||
if pred != "" and pred[-1] == ".":
|
||||
pred = pred[:-1]
|
||||
if pred != "" and pred[-1] == "/":
|
||||
pred = pred[:-1]
|
||||
pred = strip_string(pred)
|
||||
pred = re.sub(r"^[a-zA-Z0-9]+[\)]\s*", "", pred)
|
||||
for p in pred.split("{}"):
|
||||
if p != "":
|
||||
pred = p
|
||||
break
|
||||
|
||||
pred = re.sub(r"^\{([A-Z])\}|\(([A-Z])\)", r"\1\2", pred)
|
||||
if pred != "":
|
||||
answer = pred
|
||||
break
|
||||
return answer
|
||||
|
||||
|
||||
def _fix_fracs(string):
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) > 0 and substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except Exception:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string):
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
if "sqrt" not in a:
|
||||
a = int(a)
|
||||
if "sqrt" not in b:
|
||||
b = int(b)
|
||||
assert string == f"{a}/{b}"
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except Exception:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string):
|
||||
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
||||
return _string
|
||||
|
||||
|
||||
def strip_string(string):
|
||||
string = str(string).strip()
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# right "."
|
||||
string = string.rstrip(".")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
string = string.replace("\\ ", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove unit: miles, dollars if after is not none
|
||||
_string = re.sub(r"\\text{.*?}$", "", string).strip()
|
||||
if _string != "" and _string != string:
|
||||
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
||||
string = _string
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
string = string.replace("$", "")
|
||||
|
||||
string = string.replace("\\text", "")
|
||||
string = string.replace("x\\in", "")
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace(r"\%", "")
|
||||
string = string.replace("%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
|
||||
# cdot
|
||||
string = string.replace("\\cdot", "")
|
||||
|
||||
# inf
|
||||
string = string.replace("infinity", "\\infty")
|
||||
if "\\infty" not in string:
|
||||
string = string.replace("inf", "\\infty")
|
||||
string = string.replace("+\\inity", "\\infty")
|
||||
|
||||
# and
|
||||
string = string.replace("and", "")
|
||||
string = string.replace("\\mathbf", "")
|
||||
|
||||
# use regex to remove \mbox{...}
|
||||
string = re.sub(r"\\mbox{.*?}", "", string)
|
||||
|
||||
# quote
|
||||
string.replace("'", "")
|
||||
string.replace('"', "")
|
||||
|
||||
# i, j
|
||||
if "j" in string and "i" not in string:
|
||||
string = string.replace("j", "i")
|
||||
|
||||
# replace a.000b where b is not number or b is end, with ab, use regex
|
||||
string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
|
||||
string = re.sub(r"(\d+)\.0+$", r"\1", string)
|
||||
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
if len(string.split("=")) == 2:
|
||||
if len(string.split("=")[0]) <= 2:
|
||||
string = string.split("=")[1]
|
||||
|
||||
string = _fix_sqrt(string)
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def last_boxed_only_string(string):
|
||||
idx = string.rfind("\\boxed")
|
||||
if idx < 0:
|
||||
idx = string.rfind("\\fbox")
|
||||
if idx < 0:
|
||||
return None
|
||||
|
||||
i = idx
|
||||
right_brace_idx = None
|
||||
num_left_braces_open = 0
|
||||
while i < len(string):
|
||||
if string[i] == "{":
|
||||
num_left_braces_open += 1
|
||||
if string[i] == "}":
|
||||
num_left_braces_open -= 1
|
||||
if num_left_braces_open == 0:
|
||||
right_brace_idx = i
|
||||
break
|
||||
i += 1
|
||||
|
||||
if right_brace_idx is None:
|
||||
retval = None
|
||||
else:
|
||||
retval = string[idx : right_brace_idx + 1]
|
||||
|
||||
return retval
|
||||
|
||||
|
||||
def extract_answer(pred_str: str, execute: bool = False) -> str:
|
||||
if re.search("\boxed|boxed", pred_str):
|
||||
answer = re.split("\boxed|boxed", pred_str)[-1]
|
||||
if len(answer) == 0:
|
||||
return ""
|
||||
elif answer[0] == "{":
|
||||
stack = 1
|
||||
a = ""
|
||||
for c in answer[1:]:
|
||||
if c == "{":
|
||||
stack += 1
|
||||
a += c
|
||||
elif c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
break
|
||||
a += c
|
||||
else:
|
||||
a += c
|
||||
else:
|
||||
a = answer.split("$")[0].strip()
|
||||
elif re.search("[Tt]he (final )?answer is:?", pred_str):
|
||||
a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
|
||||
elif pred_str.startswith("```python") and execute:
|
||||
# fall back to program
|
||||
from lagent import get_tool
|
||||
|
||||
a = get_tool("IPythonInteractive").exec(pred_str).value or ""
|
||||
else: # use the last number
|
||||
pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
|
||||
if len(pred) >= 1:
|
||||
a = pred[-1]
|
||||
else:
|
||||
a = ""
|
||||
# multiple lines
|
||||
pred = a.split("\n")[0]
|
||||
if pred != "" and pred[0] == ":":
|
||||
pred = pred[1:]
|
||||
if pred != "" and pred[-1] == ".":
|
||||
pred = pred[:-1]
|
||||
if pred != "" and pred[-1] == "/":
|
||||
pred = pred[:-1]
|
||||
pred = strip_string(pred)
|
||||
return pred
|
||||
|
||||
|
||||
def is_digit(s):
|
||||
try:
|
||||
float(str(s).replace(",", ""))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def math_equal(
|
||||
prediction: Union[bool, float, str],
|
||||
reference: Union[float, str],
|
||||
include_percentage: bool = True,
|
||||
is_close: bool = True,
|
||||
tolerance: float = 1e-4,
|
||||
timeout: bool = False,
|
||||
) -> bool:
|
||||
"""Exact match of math if and only if:
|
||||
|
||||
1. numerical equal: both can convert to float and are equal
|
||||
2. symbolic equal: both can convert to sympy expression and are equal
|
||||
"""
|
||||
try: # 1. numerical equal
|
||||
if is_digit(prediction) and is_digit(reference):
|
||||
prediction = float(str(prediction).replace(",", ""))
|
||||
reference = float(str(reference).replace(",", ""))
|
||||
# number questions
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
for item in gt_result:
|
||||
try:
|
||||
if is_close:
|
||||
if isclose(item, prediction, rel_tol=tolerance):
|
||||
return True
|
||||
else:
|
||||
if item == prediction:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not prediction and prediction not in [0, False]:
|
||||
return False
|
||||
|
||||
# 2. symbolic equal
|
||||
reference = str(reference).strip()
|
||||
prediction = str(prediction).strip()
|
||||
|
||||
## deal with [], (), {}
|
||||
pred_str, ref_str = prediction, reference
|
||||
if (
|
||||
prediction.startswith("[")
|
||||
and prediction.endswith("]")
|
||||
and not reference.startswith("(")
|
||||
) or (
|
||||
prediction.startswith("(")
|
||||
and prediction.endswith(")")
|
||||
and not reference.startswith("[")
|
||||
):
|
||||
pred_str = pred_str.strip("[]()")
|
||||
ref_str = ref_str.strip("[]()")
|
||||
for s in ["{", "}", "(", ")"]:
|
||||
ref_str = ref_str.replace(s, "")
|
||||
pred_str = pred_str.replace(s, "")
|
||||
if pred_str == ref_str:
|
||||
return True
|
||||
|
||||
## [a, b] vs. [c, d], return a==c and b==d
|
||||
if (
|
||||
(prediction.startswith("[") and prediction.endswith("]"))
|
||||
and (reference.startswith("[") and reference.endswith("]"))
|
||||
or (prediction.startswith("(") and prediction.endswith(")"))
|
||||
and (reference.startswith("(") and reference.endswith(")"))
|
||||
):
|
||||
pred_parts = prediction[1:-1].split(",")
|
||||
ref_parts = reference[1:-1].split(",")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if all(
|
||||
[
|
||||
math_equal(
|
||||
pred_parts[i], ref_parts[i], include_percentage, is_close
|
||||
)
|
||||
for i in range(len(pred_parts))
|
||||
]
|
||||
):
|
||||
return True
|
||||
|
||||
# symbolic equal with sympy
|
||||
if timeout:
|
||||
if call_with_timeout(symbolic_equal_process, prediction, reference):
|
||||
return True
|
||||
else:
|
||||
if symbolic_equal(prediction, reference):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def math_equal_process(param):
|
||||
return math_equal(param[-2], param[-1])
|
||||
|
||||
|
||||
def math_equal_process(param):
|
||||
if param[-2] is None:
|
||||
return False
|
||||
return math_equal(param[-2], param[-1])
|
||||
|
||||
|
||||
def symbolic_equal(a, b):
|
||||
|
||||
def _parse(s):
|
||||
for f in [parse_latex, parse_expr]:
|
||||
try:
|
||||
return f(s)
|
||||
except Exception:
|
||||
pass
|
||||
return s
|
||||
|
||||
a = _parse(a)
|
||||
b = _parse(b)
|
||||
|
||||
try:
|
||||
if simplify(a - b) == 0:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
if isclose(N(a), N(b), rel_tol=1e-3):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def symbolic_equal_process(a, b, output_queue):
|
||||
result = symbolic_equal(a, b)
|
||||
output_queue.put(result)
|
||||
|
||||
|
||||
def call_with_timeout(func, *args, timeout=1, **kwargs):
|
||||
output_queue = multiprocessing.Queue()
|
||||
process_args = args + (output_queue,)
|
||||
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
|
||||
process.start()
|
||||
process.join(timeout)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
process.join()
|
||||
return False
|
||||
|
||||
return output_queue.get()
|
||||
|
||||
|
||||
def math_majority_vote(answers: list, majority: Optional[int] = None):
|
||||
# threshold = len(answers) // 2 + 1
|
||||
ans2cnt, ans2idx = Counter(), defaultdict(list)
|
||||
for i, ans in enumerate(answers):
|
||||
if isinstance(ans, str) and ans.strip():
|
||||
for key in ans2cnt.keys():
|
||||
if math_equal(ans, key):
|
||||
ans2cnt[key] += 1
|
||||
ans2idx[key].append(i)
|
||||
break
|
||||
else:
|
||||
ans2cnt[ans] += 1
|
||||
ans2idx[ans].append(i)
|
||||
if ans2cnt:
|
||||
maj, cnt = ans2cnt.most_common(1)[0]
|
||||
if maj and cnt >= (majority or 1):
|
||||
return maj, ans2idx[maj]
|
||||
return None, []
|
||||
Loading…
Add table
Add a link
Reference in a new issue