fix(ChemStructure2Property): implement scoring system for logP prediction

- Add random selection of InChI and SMILES strings
- Implement relative error-based scoring for logP prediction
- Update verification functions to return scores instead of boolean
- Refactor InChI and SMILES generation for better randomness
This commit is contained in:
chenyongkang 2025-06-17 12:07:14 +08:00
parent 9773357be4
commit fb4009b871
3 changed files with 69 additions and 18 deletions

View file

@ -1,6 +1,8 @@
import random
from internbootcamp.bootcamp.base import Basebootcamp
from internbootcamp.libs.chemStructure2Property.ChemStructureGenerator import InChIGenerator
from .utils import last_boxed_only_string, remove_boxed
from internbootcamp.bootcamp.ChemStructure2Property.utils import last_boxed_only_string, remove_boxed
from rdkit import Chem
from rdkit.Chem import Crippen
@ -18,7 +20,11 @@ class InChI2logPbootcamp(Basebootcamp):
生成一组数字和目标值
"""
self.InChIGenerator = InChIGenerator(max_atoms=self.max_atoms, min_atoms=self.min_atoms, elements=None, seed=None)
return self.InChIGenerator.generate_n_valid_inchi(1)[0]
inchis = self.InChIGenerator.generate_n_valid_inchi(10)
# print(inchis)
n = random.randint(0, 9)
# print(n)
return inchis[n]
def prompt_func(self, InChI) -> str:
@ -44,17 +50,39 @@ class InChI2logPbootcamp(Basebootcamp):
return None
return remove_boxed(output)
@classmethod
def _verify_correction(cls, solution, InChI) -> float:
    """
    Score a predicted logP against the Crippen logP computed from an InChI.

    Args:
        solution: Predicted logP; any value accepted by ``float()``.
        InChI: InChI string of the molecule the prediction refers to.

    Returns:
        1.0 when the relative error is below 0.1, otherwise 0.0. When the
        reference logP is exactly 0 the absolute value of the prediction
        is used as the error. Predictions that cannot be parsed as a
        number, or that are NaN/infinite, score 0.0.
    """
    import math

    # Parse the prediction first so malformed answers are rejected before
    # any molecule parsing is attempted.
    try:
        predicted = float(solution)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN compares False against every threshold, so without this guard a
    # "nan" answer would fall through to the full-score branch.
    if not math.isfinite(predicted):
        return 0.0

    mol = Chem.MolFromInchi(InChI)
    true_logp = Crippen.MolLogP(mol)

    if true_logp == 0:
        # Relative error is undefined for a zero reference; fall back to
        # the distance of the prediction from 0.
        relative_error = abs(predicted)
    else:
        relative_error = abs(true_logp - predicted) / abs(true_logp)

    # Largest relative error that still earns credit.
    max_relative_error = 0.1

    # NOTE(review): a graded reward,
    # 1 - (relative_error / max_relative_error) * 0.5, tagged "For RL",
    # used to sit unreachable behind this return. Only the pass/fail score
    # was ever executed; switch deliberately if partial credit is wanted.
    return 1.0 if relative_error < max_relative_error else 0.0
if __name__ == "__main__":
    # Manual smoke test: print one generated case per Enter key press.
    # Runs until the process is interrupted.
    demo = InChI2logPbootcamp()
    while True:
        sample = demo.case_generator()
        print('case', sample, sep='\n')
        input()