mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
fix rdkit version
This commit is contained in:
parent
f4913c6f02
commit
bd2933d3a5
5 changed files with 16 additions and 16 deletions
|
|
@ -20,11 +20,10 @@ class InChI2logPbootcamp(Basebootcamp):
|
||||||
生成一组数字和目标值。
|
生成一组数字和目标值。
|
||||||
"""
|
"""
|
||||||
self.InChIGenerator = InChIGenerator(max_atoms=self.max_atoms, min_atoms=self.min_atoms, elements=None, seed=None)
|
self.InChIGenerator = InChIGenerator(max_atoms=self.max_atoms, min_atoms=self.min_atoms, elements=None, seed=None)
|
||||||
inchis = self.InChIGenerator.generate_n_valid_inchi(10)
|
inchis = self.InChIGenerator.generate_n_valid_inchi(1)
|
||||||
# print(inchis)
|
# print(inchis)
|
||||||
n = random.randint(0, 9)
|
|
||||||
# print(n)
|
# print(n)
|
||||||
return inchis[n]
|
return inchis[0]
|
||||||
|
|
||||||
def prompt_func(self, InChI) -> str:
|
def prompt_func(self, InChI) -> str:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ class SMILES2logPbootcamp(InChI2logPbootcamp):
|
||||||
生成一组数字和目标值。
|
生成一组数字和目标值。
|
||||||
"""
|
"""
|
||||||
self.SMILESGenerator = SMILESGenerator(min_len=self.min_len, max_len=self.max_len, seed=None)
|
self.SMILESGenerator = SMILESGenerator(min_len=self.min_len, max_len=self.max_len, seed=None)
|
||||||
return self.SMILESGenerator.generate_n_valid_smiles(10)[random.randint(0, 9)]
|
return self.SMILESGenerator.generate_n_valid_smiles(1)[0]
|
||||||
|
|
||||||
def prompt_func(self, SMILES) -> str:
|
def prompt_func(self, SMILES) -> str:
|
||||||
|
|
||||||
|
|
@ -63,5 +63,5 @@ class SMILES2logPbootcamp(InChI2logPbootcamp):
|
||||||
return 0.0 # Error is too large, score is 0
|
return 0.0 # Error is too large, score is 0
|
||||||
else:
|
else:
|
||||||
# Linear interpolation: score decreases linearly from 1 to 0 as error goes from 0 to max_relative_error
|
# Linear interpolation: score decreases linearly from 1 to 0 as error goes from 0 to max_relative_error
|
||||||
return 1.0
|
# return 1.0
|
||||||
return 1 - (relative_error / max_relative_error) * 0.5 ## For RL
|
return 1 - (relative_error / max_relative_error) * 0.5 ## For RL
|
||||||
|
|
|
||||||
|
|
@ -138,9 +138,9 @@ Based on the above data, please infer the possible formula. Ensure that your inf
|
||||||
x, y_true = data[:, :var_num], data[:, -1]
|
x, y_true = data[:, :var_num], data[:, -1]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# import traceback
|
# import traceback
|
||||||
print("Exception while parsing symbolic formulas:", e)
|
# print("Exception while parsing symbolic formulas:", e)
|
||||||
print("Infer formula:", infer_formula)
|
# print("Infer formula:", infer_formula)
|
||||||
print("Ground truth formula:", gt_formula)
|
# print("Ground truth formula:", gt_formula)
|
||||||
# traceback.print_exc()
|
# traceback.print_exc()
|
||||||
return 0.0
|
return 0.0
|
||||||
if func_pred is not None:
|
if func_pred is not None:
|
||||||
|
|
@ -157,7 +157,7 @@ Based on the above data, please infer the possible formula. Ensure that your inf
|
||||||
metrics['R2'] = r2_score(y_true, y_pred)
|
metrics['R2'] = r2_score(y_true, y_pred)
|
||||||
metrics['NMSE'] = np.mean((y_true - y_pred) ** 2) / np.var(y_true)
|
metrics['NMSE'] = np.mean((y_true - y_pred) ** 2) / np.var(y_true)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Exception: {e}")
|
# print(f"Exception: {e}")
|
||||||
try:
|
try:
|
||||||
x0_vals, x1_vals = generate_samples()
|
x0_vals, x1_vals = generate_samples()
|
||||||
gt_vals = func_gt(x0_vals, x1_vals)
|
gt_vals = func_gt(x0_vals, x1_vals)
|
||||||
|
|
@ -174,7 +174,8 @@ Based on the above data, please infer the possible formula. Ensure that your inf
|
||||||
metrics['R2'] = 1 - np.sum((gt_valid - pred_valid) ** 2) / np.var(gt_valid)
|
metrics['R2'] = 1 - np.sum((gt_valid - pred_valid) ** 2) / np.var(gt_valid)
|
||||||
metrics['NMSE'] = np.mean((gt_valid - pred_valid) ** 2) / np.var(gt_valid)
|
metrics['NMSE'] = np.mean((gt_valid - pred_valid) ** 2) / np.var(gt_valid)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
# print(e)
|
||||||
|
pass
|
||||||
# 判断方程等价性
|
# 判断方程等价性
|
||||||
metrics['SymbolicMatch'] = is_symbolically_equivalent(infer_formula, gt_formula, var_num)
|
metrics['SymbolicMatch'] = is_symbolically_equivalent(infer_formula, gt_formula, var_num)
|
||||||
|
|
||||||
|
|
@ -215,7 +216,7 @@ def _send_request(messages, mllm='gpt-4o'):
|
||||||
content = response.json()['choices'][0]['message']['content']
|
content = response.json()['choices'][0]['message']['content']
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}, {response.json()}")
|
# print(f"Error: {e}, {response.json()}")
|
||||||
pass
|
pass
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
@ -299,7 +300,7 @@ def parse_formula(formula_str: str):
|
||||||
expr_str = formula_str.strip()
|
expr_str = formula_str.strip()
|
||||||
|
|
||||||
if not expr_str:
|
if not expr_str:
|
||||||
print(f"[Parse Error] 公式字符串为空或剥离后为空: '{formula_str}'")
|
# print(f"[Parse Error] 公式字符串为空或剥离后为空: '{formula_str}'")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
local_dict = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp, "sqrt": sp.sqrt, "log": sp.log,
|
local_dict = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp, "sqrt": sp.sqrt, "log": sp.log,
|
||||||
|
|
@ -316,12 +317,12 @@ def parse_formula(formula_str: str):
|
||||||
func = sp.lambdify(symbols, expr, modules=numpy_modules)
|
func = sp.lambdify(symbols, expr, modules=numpy_modules)
|
||||||
return func, variable_names
|
return func, variable_names
|
||||||
except (SyntaxError, TypeError, AttributeError, sp.SympifyError) as e:
|
except (SyntaxError, TypeError, AttributeError, sp.SympifyError) as e:
|
||||||
print(f'[Parse Error] 无法解析公式 "{formula_str}": {e}')
|
# print(f'[Parse Error] 无法解析公式 "{formula_str}": {e}')
|
||||||
# import traceback
|
# import traceback
|
||||||
# traceback.print_exc()
|
# traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'[Parse Error] 解析公式 "{formula_str}" 时发生意外错误: {e}')
|
# print(f'[Parse Error] 解析公式 "{formula_str}" 时发生意外错误: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -244,7 +244,7 @@ class SMILESGenerator:
|
||||||
valid_smiles_set = set()
|
valid_smiles_set = set()
|
||||||
total_attempts_overall = 0
|
total_attempts_overall = 0
|
||||||
|
|
||||||
print(f"Attempting to generate {n} valid SMILES (min_len={self.min_len}, max_len={self.max_len})...")
|
# print(f"Attempting to generate {n} valid SMILES (min_len={self.min_len}, max_len={self.max_len})...")
|
||||||
while len(valid_smiles_set) < n:
|
while len(valid_smiles_set) < n:
|
||||||
attempts_for_current_smiles = 0
|
attempts_for_current_smiles = 0
|
||||||
generated_this_round = False
|
generated_this_round = False
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -20,7 +20,7 @@ setuptools.setup(
|
||||||
"langdetect",
|
"langdetect",
|
||||||
"pympler",
|
"pympler",
|
||||||
"shortuuid",
|
"shortuuid",
|
||||||
"rdkit"
|
"rdkit==2024.3.2"
|
||||||
],
|
],
|
||||||
|
|
||||||
package_data={
|
package_data={
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue