mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
switch to normalized error verify
This commit is contained in:
parent
4a53b3dba0
commit
f0255839e2
4 changed files with 28 additions and 30 deletions
|
|
@ -6,11 +6,11 @@ from rdkit.Chem import Crippen
|
|||
|
||||
|
||||
class InChI2logPbootcamp(Basebootcamp):
|
||||
def __init__(self, num_numbers=4, max_atoms=15, min_atoms=3, elements=None, seed=None):
|
||||
def __init__(self, max_atoms=15, min_atoms=3, elements=None, seed=None):
|
||||
# super.__init__()
|
||||
self.num_numbers = num_numbers
|
||||
self.InChIGenerator = InChIGenerator(max_atoms=max_atoms, min_atoms=min_atoms, elements=elements, seed=seed)
|
||||
|
||||
self.tolerance_factor = tolerance_factor # 1 for 1% error consider true, 0.1 for 0.1% error true, 10 for 10% error
|
||||
|
||||
def case_generator(self) -> str:
|
||||
"""
|
||||
生成一组数字和目标值。
|
||||
|
|
@ -48,9 +48,10 @@ class InChI2logPbootcamp(Basebootcamp):
|
|||
"""
|
||||
mol = Chem.MolFromInchi(InChI)
|
||||
true_logp = Crippen.MolLogP(mol)
|
||||
print(f"Comparing pred: {solution}, ground_truth: {true_logp}")
|
||||
return abs(true_logp - float(solution)) <= 0.01 # maybe mse or mae better?
|
||||
solution_float = float(solution)
|
||||
|
||||
|
||||
|
||||
|
||||
# Handle case where true_logp is 0
|
||||
if true_logp == 0:
|
||||
return abs(solution_float) <= 0.01 # Just check if solution is close to 0
|
||||
else:
|
||||
return abs(true_logp - solution_float)/abs(true_logp) <= 0.01
|
||||
Loading…
Add table
Add a link
Reference in a new issue