mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
fix bugs for symbolic regression bootcamp
This commit is contained in:
parent
33101ef068
commit
5eb513f014
11 changed files with 148 additions and 64 deletions
|
|
@ -1 +1 @@
|
||||||
{"bootcamp_name": "medcalculator", "sample_number": 100, "config_file": "med_calculator", "bootcamp_cls_name": "Medcalculatorbootcamp"}
|
{"bootcamp_name": "SymbolicRegression", "sample_number": 100, "config_file": "Symbolic_Regression", "bootcamp_cls_name": "SymbolicRegressionbootcamp"}
|
||||||
|
|
@ -1 +1 @@
|
||||||
{"bootcamp_name": "medcalculator", "sample_number": 100000, "config_file": "med_calculator", "bootcamp_cls_name": "Medcalculatorbootcamp"}
|
{"bootcamp_name": "SymbolicRegression", "sample_number": 30000, "config_file": "Symbolic_Regression", "bootcamp_cls_name": "SymbolicRegressionbootcamp"}
|
||||||
|
|
@ -65,26 +65,27 @@ def main_pipeline(
|
||||||
print("bootcamp_name:", bootcamp_cls_name,"+", bootcamp_cls)
|
print("bootcamp_name:", bootcamp_cls_name,"+", bootcamp_cls)
|
||||||
count = 0
|
count = 0
|
||||||
failure = 0
|
failure = 0
|
||||||
|
bootcamp = bootcamp_cls(**config)
|
||||||
while count < _n:
|
while count < _n:
|
||||||
try:
|
try:
|
||||||
bootcamp = bootcamp_cls(**config)
|
|
||||||
bootcamp_case = bootcamp.case_generator()
|
bootcamp_case = bootcamp.case_generator()
|
||||||
prompt = bootcamp.prompt_func(bootcamp_case)
|
prompt = bootcamp.prompt_func(bootcamp_case)
|
||||||
if tokenizer is not None:
|
if tokenizer is not None:
|
||||||
length = len(tokenizer.encode(prompt))
|
length = len(tokenizer.encode(prompt))
|
||||||
if length > max_prompt_len:
|
if length > max_prompt_len:
|
||||||
continue
|
continue
|
||||||
failure = 0
|
|
||||||
writer.write(json.dumps({
|
writer.write(json.dumps({
|
||||||
"data_source": bootcamp_cls_name.replace("bootcamp", ""),
|
"data_source": bootcamp_cls_name.replace("bootcamp", ""),
|
||||||
"prompt": prompt.strip(),
|
"prompt": prompt.strip(),
|
||||||
"ground_truth": bootcamp_case
|
"ground_truth": bootcamp_case
|
||||||
}, ensure_ascii=False) + "\n")
|
}, ensure_ascii=False) + "\n")
|
||||||
bar.update()
|
bar.update()
|
||||||
|
|
||||||
count += 1
|
count += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
failure += 1
|
failure += 1
|
||||||
if failure > 1000:
|
if failure > 512:
|
||||||
print(config, f"seems to be a too challenging config to generate cases , because of {e}")
|
print(config, f"seems to be a too challenging config to generate cases , because of {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"data_path":"./internbootcamp/libs/symbolic_regression/test_data.pkl",
|
||||||
|
"sample_num_range":[64,144]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"data_path":"./internbootcamp/libs/symbolic_regression/train_data.pkl",
|
||||||
|
"sample_num_range":[64,144]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
@ -10,15 +10,17 @@ fi
|
||||||
# 时间戳
|
# 时间戳
|
||||||
timestamp=$(date +"%Y-%m-%d-%H:%M:%S")
|
timestamp=$(date +"%Y-%m-%d-%H:%M:%S")
|
||||||
# cipher输入集
|
# cipher输入集
|
||||||
cipher_input_file='internbootcamp/libs/data/words_alpha_370000.txt'
|
|
||||||
|
|
||||||
tokenizer="/cpfs01/shared/llm_ddd/lipeiji/hf_hub_1/models--Qwen--Qwen2.5-32B-Instruct/snapshots/afb2829595f63efa3548e9d6b13aa66e61aa0f38" # tokenizer is used to calculate the sequence length of the prompt
|
tokenizer="/cpfs01/shared/llm_ddd/lipeiji/hf_hub_1/models--Qwen--Qwen2.5-32B-Instruct/snapshots/afb2829595f63efa3548e9d6b13aa66e61aa0f38" # tokenizer is used to calculate the sequence length of the prompt
|
||||||
max_prompt_len=4096
|
max_prompt_len=4096
|
||||||
max_jobs=60 # 设置最大并发进程数
|
max_jobs=64 # 设置最大并发进程数
|
||||||
jobs=() # 用于存储后台进程的PID
|
jobs=() # 用于存储后台进程的PID
|
||||||
|
|
||||||
|
|
||||||
# initialize, do not modify this
|
|
||||||
|
# Initialization: do not modify anything below this line
|
||||||
|
cipher_input_file='internbootcamp/libs/data/words_alpha_370000.txt'
|
||||||
cipher_test_nums_for_single_cipher=0
|
cipher_test_nums_for_single_cipher=0
|
||||||
cipher_train_nums_for_single_cipher=0
|
cipher_train_nums_for_single_cipher=0
|
||||||
|
|
||||||
|
|
@ -64,7 +66,8 @@ while IFS= read -r line || [ -n "$line" ]; do
|
||||||
|
|
||||||
pid=$! # 获取后台进程的PID
|
pid=$! # 获取后台进程的PID
|
||||||
jobs+=("$pid") # 将PID加入数组
|
jobs+=("$pid") # 将PID加入数组
|
||||||
|
# 打印当前进程总数
|
||||||
|
# echo "Current running jobs: ${#jobs[@]}"
|
||||||
# 控制并发数量
|
# 控制并发数量
|
||||||
while [ ${#jobs[@]} -ge $max_jobs ]; do
|
while [ ${#jobs[@]} -ge $max_jobs ]; do
|
||||||
wait -n # 等待任意一个子进程结束
|
wait -n # 等待任意一个子进程结束
|
||||||
|
|
@ -79,6 +82,9 @@ while IFS= read -r line || [ -n "$line" ]; do
|
||||||
done
|
done
|
||||||
done < examples/pipelines/data_configs/data_config_train.jsonl
|
done < examples/pipelines/data_configs/data_config_train.jsonl
|
||||||
|
|
||||||
|
wait
|
||||||
|
|
||||||
|
echo "train set generation finished, start test generation."
|
||||||
|
|
||||||
while IFS= read -r line || [ -n "$line" ]; do
|
while IFS= read -r line || [ -n "$line" ]; do
|
||||||
# 跳过空行
|
# 跳过空行
|
||||||
|
|
@ -125,9 +131,10 @@ while IFS= read -r line || [ -n "$line" ]; do
|
||||||
done
|
done
|
||||||
done < examples/pipelines/data_configs/data_config_test.jsonl
|
done < examples/pipelines/data_configs/data_config_test.jsonl
|
||||||
|
|
||||||
# 等待所有后台任务完成
|
|
||||||
wait
|
wait
|
||||||
|
|
||||||
|
echo "test set generation finished"
|
||||||
|
|
||||||
# cipher test-set gen
|
# cipher test-set gen
|
||||||
python examples/pipelines/cipher_data_generator.py \
|
python examples/pipelines/cipher_data_generator.py \
|
||||||
--nums $cipher_test_nums_for_single_cipher \
|
--nums $cipher_test_nums_for_single_cipher \
|
||||||
|
|
|
||||||
|
|
@ -250,15 +250,15 @@ async def main():
|
||||||
help='Base URL of the OpenAI API compatible service. Default format is http://{ip}:{port}/v1.')
|
help='Base URL of the OpenAI API compatible service. Default format is http://{ip}:{port}/v1.')
|
||||||
parser.add_argument('--api_key', default='EMPTY',
|
parser.add_argument('--api_key', default='EMPTY',
|
||||||
help='API key for accessing the model service. Set to "EMPTY" if no key is required.')
|
help='API key for accessing the model service. Set to "EMPTY" if no key is required.')
|
||||||
parser.add_argument('--model_name', default='Qwen2.5-32B-Instruct',
|
parser.add_argument('--model_name', default='r1_32b',
|
||||||
help='Name of the model to be evaluated, e.g., r1_32B or other custom model name.')
|
help='Name of the model to be evaluated, e.g., r1_32B or other custom model name.')
|
||||||
parser.add_argument('--test_dir', default='/cpfs01/shared/llm_ddd/lipeiji/InternBootcamp/examples/bootcamp_generator_outputs/2025-06-12-14:29:13/test',
|
parser.add_argument('--test_dir', default='/cpfs01/shared/llm_ddd/lipeiji/InternBootcamp/examples/bootcamp_generator_outputs/2025-06-16-16:47:31/test',
|
||||||
help='Path to the directory containing test JSONL files for evaluation.')
|
help='Path to the directory containing test JSONL files for evaluation.')
|
||||||
parser.add_argument('--max_concurrent_requests', type=int, default=144,
|
parser.add_argument('--max_concurrent_requests', type=int, default=144,
|
||||||
help='Maximum number of concurrent requests allowed globally.')
|
help='Maximum number of concurrent requests allowed globally.')
|
||||||
parser.add_argument('--template', default='internbootcamp_v2',choices=['r1', 'qwen', 'internthinker', 'chatml','internbootcamp'],
|
parser.add_argument('--template', default='r1',choices=['r1', 'qwen', 'internthinker', 'chatml','internbootcamp'],
|
||||||
help='Predefined conversation template used to format prompts. Only valid when api_mode is completion.')
|
help='Predefined conversation template used to format prompts. Only valid when api_mode is completion.')
|
||||||
parser.add_argument('--max_tokens', type=int, default=8192,
|
parser.add_argument('--max_tokens', type=int, default=16384,
|
||||||
help='Maximum number of tokens the model can generate.')
|
help='Maximum number of tokens the model can generate.')
|
||||||
parser.add_argument('--temperature', type=float, default=0,
|
parser.add_argument('--temperature', type=float, default=0,
|
||||||
help='Controls randomness in text generation. Lower values produce more deterministic outputs.')
|
help='Controls randomness in text generation. Lower values produce more deterministic outputs.')
|
||||||
|
|
|
||||||
|
|
@ -1085,3 +1085,4 @@ from .ddistinctpaths.ddistinctpaths import Ddistinctpathsbootcamp
|
||||||
from .eereaderdisplay.eereaderdisplay import Eereaderdisplaybootcamp
|
from .eereaderdisplay.eereaderdisplay import Eereaderdisplaybootcamp
|
||||||
from .clunarnewyearandnumberdivision.clunarnewyearandnumberdivision import Clunarnewyearandnumberdivisionbootcamp
|
from .clunarnewyearandnumberdivision.clunarnewyearandnumberdivision import Clunarnewyearandnumberdivisionbootcamp
|
||||||
from .med_calculator.med_calculator import Medcalculatorbootcamp
|
from .med_calculator.med_calculator import Medcalculatorbootcamp
|
||||||
|
from .symbolic_regression.symbolic_regression import SymbolicRegressionbootcamp
|
||||||
|
|
@ -85,30 +85,29 @@ class Arcbootcamp(Basebootcamp):
|
||||||
verifiers_mapper = get_verifiers()
|
verifiers_mapper = get_verifiers()
|
||||||
def __init__(self, task_key_file: str = None):
|
def __init__(self, task_key_file: str = None):
|
||||||
task_key_file = "/".join(__file__.split('/')[:-4]) + "/" + task_key_file
|
task_key_file = "/".join(__file__.split('/')[:-4]) + "/" + task_key_file
|
||||||
task_key = [json.loads(f) for f in open(task_key_file, 'r').readlines()]
|
self.task_keys = [json.loads(f) for f in open(task_key_file, 'r').readlines()]
|
||||||
self.task_key = random.choice(task_key)['key']
|
|
||||||
self.generators = get_generators()
|
self.generators = get_generators()
|
||||||
self.hint_examples = []
|
|
||||||
self.current_example = None
|
self.current_example = None
|
||||||
|
|
||||||
|
|
||||||
def case_generator(self):
|
def case_generator(self):
|
||||||
if self.task_key not in self.generators:
|
task_key = random.choice(self.task_keys)['key']
|
||||||
raise ValueError(f"Task key '{self.task_key}' not found in generators.")
|
if task_key not in self.generators:
|
||||||
generator = self.generators[self.task_key]
|
raise ValueError(f"Task key '{task_key}' not found in generators.")
|
||||||
|
generator = self.generators[task_key]
|
||||||
self.current_example = generator(0, 1)
|
self.current_example = generator(0, 1)
|
||||||
|
hint_examples = []
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.hint_examples.append(generator(0, 1))
|
hint_examples.append(generator(0, 1))
|
||||||
input_grid = self.current_example['input']
|
input_grid = self.current_example['input']
|
||||||
return {'input_grid': input_grid, 'task_key': self.task_key}
|
return {'hint_examples':hint_examples ,'input_grid': input_grid, 'task_key': task_key}
|
||||||
|
|
||||||
|
|
||||||
def prompt_func(self, identity) -> str:
|
def prompt_func(self, identity) -> str:
|
||||||
"""
|
"""
|
||||||
Process the input_data and return the processed prompt.
|
Process the input_data and return the processed prompt.
|
||||||
"""
|
"""
|
||||||
return generate_arc_puzzle(self.hint_examples, identity['input_grid'])
|
return generate_arc_puzzle(identity['hint_examples'], identity['input_grid'])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_output(output:str)->Grid:
|
def extract_output(output:str)->Grid:
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class Basebootcamp:
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def verify_score(cls, model_output, identity: dict, format_score=0, short_penalty=False, short_threshold=256, ans_threshold=128, format_penalty=False) -> float:
|
def verify_score(cls, model_output, identity: dict, format_score=0, short_penalty=False, short_threshold=256, think_threshold=128, ans_threshold=128, format_penalty=False) -> float:
|
||||||
"""
|
"""
|
||||||
Verify the output against the ground truth.
|
Verify the output against the ground truth.
|
||||||
|
|
||||||
|
|
@ -83,10 +83,11 @@ class Basebootcamp:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
ans_output = model_output.rsplit("</think>", 1)[1] if "</think>" in model_output else ""
|
ans_output = model_output.rsplit("</think>", 1)[1] if "</think>" in model_output else ""
|
||||||
|
think_length = len(model_output) - len(ans_output)
|
||||||
if (short_penalty and len(model_output) < short_threshold) or (short_penalty and len(ans_output) < ans_threshold):
|
score = max(0, score) # Ensure score is non-negative
|
||||||
|
if (short_penalty and len(model_output) < short_threshold) or (short_penalty and len(ans_output) < ans_threshold) or (short_penalty and think_length < think_threshold):
|
||||||
# if the output is too short, consider it incorrect
|
# if the output is too short, consider it incorrect
|
||||||
return min(score * len(model_output) / short_threshold, score * len(ans_output) / ans_threshold)
|
return min(score * len(model_output) / short_threshold, score * len(ans_output) / ans_threshold, score * think_length / think_threshold)
|
||||||
|
|
||||||
# This for training Debug
|
# This for training Debug
|
||||||
if random.randint(1,1024) == 1:
|
if random.randint(1,1024) == 1:
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,80 @@
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
|
import random
|
||||||
from internbootcamp.bootcamp.base import Basebootcamp
|
from internbootcamp.bootcamp.base import Basebootcamp
|
||||||
from sklearn.metrics import r2_score, root_mean_squared_error
|
from sklearn.metrics import r2_score, root_mean_squared_error
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
import pickle
|
import pickle
|
||||||
|
def last_boxed_only_string(string):
    """Return the last ``\\boxed`` (or ``\\fbox``) expression found in *string*.

    Two spellings are recognised:

    * brace form, ``\\boxed{...}`` -- the result runs from the command up to
      and including the brace that balances the first ``{`` after it, so
      nested braces are kept intact;
    * brace-less form, ``\\boxed <answer>`` -- the result is ``"\\boxed "``
      followed by everything up to the next ``$`` (or the end of the text).

    ``\\fbox{...}`` is only considered when the text contains no ``\\boxed``
    at all.

    Args:
        string: Text to search, e.g. a raw model response.

    Returns:
        The matched substring including the leading command, or ``None`` when
        no command is present or its braces never balance.
    """
    spaced_prefix = "\\boxed "
    idx = string.rfind("\\boxed")

    # Use the brace-less rule only when the *last* \boxed is written that way.
    # Testing `"\\boxed " in string` instead would let an earlier brace-less
    # occurrence shadow a later \boxed{...} that holds the final answer.
    if idx >= 0 and string.startswith(spaced_prefix, idx):
        answer = string[idx + len(spaced_prefix):].split("$")[0]
        return spaced_prefix + answer

    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    # Scan forward from the command and stop at the brace that brings the
    # nesting depth back to zero.
    depth = 0
    for pos in range(idx, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[idx:pos + 1]

    # Braces never balanced (or there were none): nothing usable to return.
    return None
|
||||||
|
|
||||||
|
|
||||||
class SymblocRegression(Basebootcamp):
|
def remove_boxed(s):
|
||||||
def __init__(self, data_path):
|
if "\\boxed " in s:
|
||||||
|
left = "\\boxed "
|
||||||
|
assert s[:len(left)] == left
|
||||||
|
return s[len(left):]
|
||||||
|
|
||||||
|
left = "\\boxed{"
|
||||||
|
|
||||||
|
assert s[:len(left)] == left
|
||||||
|
assert s[-1] == "}"
|
||||||
|
|
||||||
|
return s[len(left):-1]
|
||||||
|
|
||||||
|
class SymbolicRegressionbootcamp(Basebootcamp):
|
||||||
|
def __init__(self, data_path='./internbootcamp/libs/symbolic_regression/train_data.pkl', sample_num_range=[64,144]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.data_path = data_path
|
self.data_path = data_path
|
||||||
|
self.sample_num_range = sample_num_range
|
||||||
|
with open(f'{self.data_path}', 'rb') as f:
|
||||||
|
self.formula_data = pickle.load(f)
|
||||||
|
|
||||||
def case_generator(self, sample_num=300) -> object:
|
def case_generator(self) -> object:
|
||||||
"""
|
"""
|
||||||
生成一组数字和目标值。
|
生成一组数字和目标值。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with open(f'{self.data_path}', 'rb') as f:
|
i = random.choice(range(len(self.formula_data)))
|
||||||
formula_data = pickle.load(f)
|
true_formula = self.formula_data[i]['formula']
|
||||||
data_list = []
|
dataset = self.formula_data[i]['data']
|
||||||
for i in range(len(formula_data)):
|
sample_num = np.random.randint(self.sample_num_range[0], self.sample_num_range[1])
|
||||||
true_formula = formula_data[i]['formula']
|
rand_idx = np.random.choice(dataset.shape[0], sample_num, replace=False)
|
||||||
dataset = formula_data[i]['data']
|
dataset = dataset[rand_idx]
|
||||||
rand_idx = np.random.choice(dataset.shape[0], sample_num, replace=False)
|
return {
|
||||||
dataset = dataset[rand_idx]
|
# 'id': formula_data[i]['id'],
|
||||||
data_list.append({
|
'true_formula': true_formula,
|
||||||
'id': formula_data[i]['id'],
|
'data':dataset.tolist(),
|
||||||
'true_formula': true_formula,
|
}
|
||||||
'data':dataset,
|
|
||||||
})
|
|
||||||
return data_list
|
|
||||||
|
|
||||||
def prompt_func(self, identity) -> str:
|
def prompt_func(self, identity) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -43,10 +86,13 @@ class SymblocRegression(Basebootcamp):
|
||||||
Returns:
|
Returns:
|
||||||
str: The processed prompt.
|
str: The processed prompt.
|
||||||
"""
|
"""
|
||||||
data = identity['data']
|
data = np.array(identity['data'])
|
||||||
length_data = data.shape[0]
|
length_data = data.shape[0]
|
||||||
split_idx = int(length_data * 0.97)
|
split_idx = int(length_data * 0.97)
|
||||||
prompt = f"""You will be provided with a set of input-output pairs. Based on these data, infer the mathematical relationship between y and multiple input variables. Please note that the possible mathematical operations include: +, -, *, /, exp, sqrt, sin, arcsin, and constant terms. The input sample data are as follows: {change_data_to_prompt(data[:split_idx, :])} Based on the above data, please infer the possible formula. Ensure that your inference applies to all the provided data points, and consider both linear and nonlinear combinations. Verify whether your formula applies to the following new data point and adjust it to ensure accuracy: {change_data_to_prompt(data[split_idx:, :])} Finally, please output only the formula string you inferred (e.g. z=x_0 * x_1), without any additional information."""
|
prompt = f"""You will be provided with a set of input-output pairs. Based on these data, infer the mathematical relationship between y and multiple input variables. Please note that the possible mathematical operations include: +, -, *, /, exp, sqrt, sin, arcsin, and constant terms. The input sample data are as follows:
|
||||||
|
{change_data_to_prompt(data[:split_idx, :])}
|
||||||
|
Based on the above data, please infer the possible formula. Ensure that your inference applies to all the provided data points, and consider both linear and nonlinear combinations. Verify whether your formula applies to the following new data point and adjust it to ensure accuracy:
|
||||||
|
{change_data_to_prompt(data[split_idx:, :])}""" + """Finally, please output the formula string you inferred within \\boxed{}(e.g. \\boxed{y=sqrt(x0 + x1) / (2 * pi)}). Note that you should express mathematical formulas using Python syntax(sqrt(x0)) instead of LaTeX format(\sqrt(x_0))."""
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -60,9 +106,11 @@ class SymblocRegression(Basebootcamp):
|
||||||
Returns:
|
Returns:
|
||||||
The processed output.
|
The processed output.
|
||||||
"""
|
"""
|
||||||
infer_formula = llm_translate(output, mllm='gpt-4o') # gpt-4o Qwen2.5-vl-72b
|
# infer_formula = llm_translate(output, mllm='gpt-4o') # gpt-4o Qwen2.5-vl-72b
|
||||||
infer_formula = clean_formula_string(infer_formula)
|
output = last_boxed_only_string(output)
|
||||||
return infer_formula
|
if output is None:
|
||||||
|
return None
|
||||||
|
return remove_boxed(output)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _verify_correction(self, infer_formula, gt_case, mllm='gpt-4o')->bool:
|
def _verify_correction(self, infer_formula, gt_case, mllm='gpt-4o')->bool:
|
||||||
|
|
@ -70,9 +118,9 @@ class SymblocRegression(Basebootcamp):
|
||||||
Verify the correction of the solution.
|
Verify the correction of the solution.
|
||||||
"""
|
"""
|
||||||
gt_formula = gt_case['true_formula']
|
gt_formula = gt_case['true_formula']
|
||||||
data = gt_case['data']
|
data = np.array(gt_case['data'])
|
||||||
metrics = {
|
metrics = {
|
||||||
'LLM_Score': None,
|
# 'LLM_Score': None,
|
||||||
'RMSE': None,
|
'RMSE': None,
|
||||||
'NMSE': None, # 新增:Normalized MSE
|
'NMSE': None, # 新增:Normalized MSE
|
||||||
'SymbolicMatch': False,
|
'SymbolicMatch': False,
|
||||||
|
|
@ -80,13 +128,21 @@ class SymblocRegression(Basebootcamp):
|
||||||
}
|
}
|
||||||
|
|
||||||
# 结构评分(用 LLM)
|
# 结构评分(用 LLM)
|
||||||
metrics['LLM_Score'] = llm_evaluate(infer_formula, gt_formula, mllm=mllm)
|
# metrics['LLM_Score'] = llm_evaluate(infer_formula, gt_formula, mllm=mllm)
|
||||||
|
|
||||||
# 数值拟合
|
# 数值拟合
|
||||||
func_pred, variable_names = parse_formula(infer_formula)
|
try:
|
||||||
func_gt, variable_names = parse_formula(gt_formula)
|
func_pred, variable_names = parse_formula(infer_formula)
|
||||||
var_num = len(variable_names)
|
func_gt, variable_names = parse_formula(gt_formula)
|
||||||
x, y_true = data[:, :var_num], data[:, -1]
|
var_num = len(variable_names)
|
||||||
|
x, y_true = data[:, :var_num], data[:, -1]
|
||||||
|
except Exception as e:
|
||||||
|
# import traceback
|
||||||
|
print("Exception while parsing symbolic formulas:", e)
|
||||||
|
print("Infer formula:", infer_formula)
|
||||||
|
print("Ground truth formula:", gt_formula)
|
||||||
|
# traceback.print_exc()
|
||||||
|
return 0.0
|
||||||
if func_pred is not None:
|
if func_pred is not None:
|
||||||
try:
|
try:
|
||||||
x_vars = [x[:, i] for i in range(var_num)]
|
x_vars = [x[:, i] for i in range(var_num)]
|
||||||
|
|
@ -122,7 +178,10 @@ class SymblocRegression(Basebootcamp):
|
||||||
# 判断方程等价性
|
# 判断方程等价性
|
||||||
metrics['SymbolicMatch'] = is_symbolically_equivalent(infer_formula, gt_formula, var_num)
|
metrics['SymbolicMatch'] = is_symbolically_equivalent(infer_formula, gt_formula, var_num)
|
||||||
|
|
||||||
return metrics
|
if metrics['SymbolicMatch']:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return max(0, metrics['R2'])
|
||||||
|
|
||||||
|
|
||||||
def _send_request(messages, mllm='gpt-4o'):
|
def _send_request(messages, mllm='gpt-4o'):
|
||||||
|
|
@ -258,6 +317,8 @@ def parse_formula(formula_str: str):
|
||||||
return func, variable_names
|
return func, variable_names
|
||||||
except (SyntaxError, TypeError, AttributeError, sp.SympifyError) as e:
|
except (SyntaxError, TypeError, AttributeError, sp.SympifyError) as e:
|
||||||
print(f'[Parse Error] 无法解析公式 "{formula_str}": {e}')
|
print(f'[Parse Error] 无法解析公式 "{formula_str}": {e}')
|
||||||
|
# import traceback
|
||||||
|
# traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'[Parse Error] 解析公式 "{formula_str}" 时发生意外错误: {e}')
|
print(f'[Parse Error] 解析公式 "{formula_str}" 时发生意外错误: {e}')
|
||||||
|
|
@ -292,11 +353,13 @@ def change_data_to_prompt(points):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# example
|
# example
|
||||||
data_path = 'test_data.pkl'
|
random.seed(42) # For reproducibility
|
||||||
bootcamp = SymblocRegression(data_path)
|
bootcamp = SymbolicRegressionbootcamp()
|
||||||
case = bootcamp.case_generator()[0] # 选取1个case
|
case = bootcamp.case_generator() # 选取1个case
|
||||||
print(bootcamp.prompt_func(case))
|
print(bootcamp.prompt_func(case))
|
||||||
example_answer = "y = x0 * x1"
|
example_answer = """这道问题的解是:\\boxed{ sqrt(x0)} hahaha"""
|
||||||
print(f"answer: {example_answer}")
|
print(f"answer: {example_answer}")
|
||||||
|
example_answer = bootcamp.extract_output(example_answer)
|
||||||
|
print(f'Extracted answer: {example_answer}')
|
||||||
metrics = bootcamp._verify_correction(example_answer, case)
|
metrics = bootcamp._verify_correction(example_answer, case)
|
||||||
print(f'GT: {case['true_formula'].ljust(40)} | Pred: {example_answer.ljust(40)} | Score: {metrics["LLM_Score"]} | RMSE: {metrics["RMSE"]} | NMSE: {metrics["NMSE"]} | R2: {metrics["R2"]} | Match: {metrics["SymbolicMatch"]}')
|
print(f'GT: {case["true_formula"].ljust(40)} | Pred: {example_answer.ljust(40)} | Metrics: {metrics}')
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue