switch to normalized error verify

This commit is contained in:
Jucheng Hu 2025-06-16 15:15:01 +08:00
parent 4a53b3dba0
commit f0255839e2
4 changed files with 28 additions and 30 deletions

View file

@ -12,7 +12,6 @@ class InChI2MRBootCamp(InChI2logPbootcamp):
prompt = instruction + '\n' + instruction_following
return prompt
@classmethod
def _verify_correction(cls, solution, InChI)->bool:
"""
@ -20,9 +19,11 @@ class InChI2MRBootCamp(InChI2logPbootcamp):
"""
mol = Chem.MolFromInchi(InChI)
true_MR = Crippen.MolMR(mol)
print(f"Comparing pred: {solution}, ground_truth: {true_MR}")
return abs(true_MR - float(solution)) <= 0.01 # maybe mse or mae better?
solution_float = float(solution)
# Handle case where true_logp is 0
if true_MR == 0:
return abs(solution_float) <= 0.01 # Just check if solution is close to 0
else:
return abs(true_MR - solution_float)/abs(true_MR) <= 0.01

View file

@ -6,11 +6,11 @@ from rdkit.Chem import Crippen
class InChI2logPbootcamp(Basebootcamp):
def __init__(self, num_numbers=4, max_atoms=15, min_atoms=3, elements=None, seed=None):
def __init__(self, max_atoms=15, min_atoms=3, elements=None, seed=None):
# super.__init__()
self.num_numbers = num_numbers
self.InChIGenerator = InChIGenerator(max_atoms=max_atoms, min_atoms=min_atoms, elements=elements, seed=seed)
self.tolerance_factor = tolerance_factor # 1 for 1% error consider true, 0.1 for 0.1% error true, 10 for 10% error
def case_generator(self) -> str:
"""
生成一组数字和目标值
@ -48,9 +48,10 @@ class InChI2logPbootcamp(Basebootcamp):
"""
mol = Chem.MolFromInchi(InChI)
true_logp = Crippen.MolLogP(mol)
print(f"Comparing pred: {solution}, ground_truth: {true_logp}")
return abs(true_logp - float(solution)) <= 0.01 # maybe mse or mae better?
solution_float = float(solution)
# Handle case where true_logp is 0
if true_logp == 0:
return abs(solution_float) <= 0.01 # Just check if solution is close to 0
else:
return abs(true_logp - solution_float)/abs(true_logp) <= 0.01

View file

@ -8,8 +8,6 @@ from .SMILES2logPBootCamp import SMILES2logPBootCamp
class SMILES2MRBootCamp(SMILES2logPBootCamp):
def prompt_func(self, SMILES) -> str:
instruction = f"Given the SMILES, determine the Molar Refractivity (MR) value of the material. The SMILES is: {SMILES}"
@ -26,9 +24,9 @@ class SMILES2MRBootCamp(SMILES2logPBootCamp):
"""
mol = Chem.MolFromSmiles(SMILES)
true_MR = Crippen.MolMR(mol)
print(f"Comparing pred: {solution}, ground_truth: {true_MR}")
return abs(true_MR - float(solution)) <= 0.01 # maybe mse or mae better?
solution_float = float(solution)
if true_MR == 0:
return abs(solution_float) <= 0.01 # Just check if solution is close to 0
else:
return abs(true_MR - solution_float)/abs(true_MR) <= 0.01

View file

@ -7,11 +7,10 @@ from rdkit.Chem import Crippen
from .InChI2logPBootCamp import InChI2logPbootcamp
class SMILES2logPBootCamp(InChI2logPbootcamp):
def __init__(self, num_numbers=4, min_len=5, max_len=25,
def __init__(self,min_len=5, max_len=25,
seed=None):
# super.__init__()
self.num_numbers = num_numbers
self.SMILESGenerator = SMILESGenerator(min_len=5, max_len=25, seed=None)
self.SMILESGenerator = SMILESGenerator(min_len=min_len, max_len=max_len, seed=seed)
def case_generator(self) -> str:
"""
@ -35,9 +34,8 @@ class SMILES2logPBootCamp(InChI2logPbootcamp):
"""
mol = Chem.MolFromSmiles(SMILES)
true_logp = Crippen.MolLogP(mol)
print(f"Comparing pred: {solution}, ground_truth: {true_logp}")
return abs(true_logp - float(solution)) <= 0.01 # maybe mse or mae better?
solution_float = float(solution)
if true_logp == 0:
return abs(solution_float) <= 0.01 # Just check if solution is close to 0
else:
return abs(true_logp - solution_float)/abs(true_logp) <= 0.01