@classmethod
def _verify_correction(cls, solution, InChI) -> float:
    """
    Score a predicted logP against the Crippen logP computed from ``InChI``.

    The score is binary: 1.0 when the prediction is within tolerance and
    0.0 otherwise.

    Args:
        solution: The predicted logP; anything ``float()`` accepts.
        InChI: InChI string of the molecule the prediction refers to.

    Returns:
        1.0 if the relative error ``|true - predicted| / |true|`` is below
        0.1, otherwise 0.0. When the reference logP is exactly 0 the
        absolute value of the prediction is used as the error instead.

    Raises:
        ValueError: If ``InChI`` cannot be parsed into a molecule, or if
            ``solution`` is a string that is not a number.
        TypeError: If ``solution`` is of a type ``float()`` rejects.
    """
    mol = Chem.MolFromInchi(InChI)
    if mol is None:
        # Fail with a clear message instead of handing None to MolLogP.
        raise ValueError("Invalid InChI string provided.")

    true_logp = Crippen.MolLogP(mol)
    solution_float = float(solution)

    # Maximum relative error that still counts as a correct answer.
    max_relative_error = 0.1

    if true_logp == 0:
        # Relative error is undefined for a zero reference value, so fall
        # back to the distance of the prediction from 0.
        relative_error = abs(solution_float)
    else:
        relative_error = abs(true_logp - solution_float) / abs(true_logp)

    # Phrased as "error < tolerance" on purpose: a NaN prediction makes
    # every comparison False, so it scores 0.0 rather than slipping past an
    # "error >= tolerance" rejection test and earning full credit.
    return 1.0 if relative_error < max_relative_error else 0.0
@classmethod
def _verify_correction(cls, solution, SMILES) -> float:
    """
    Score a predicted logP against the Crippen logP computed from ``SMILES``.

    The score is binary: 1.0 when the prediction is within tolerance and
    0.0 otherwise.

    Args:
        solution: The predicted logP; anything ``float()`` accepts.
        SMILES: SMILES string of the molecule the prediction refers to.

    Returns:
        1.0 if the relative error ``|true - predicted| / |true|`` is below
        0.1, otherwise 0.0. When the reference logP is exactly 0 the
        absolute value of the prediction is used as the error instead.

    Raises:
        ValueError: If ``SMILES`` cannot be parsed into a molecule, or if
            ``solution`` is a string that is not a number.
        TypeError: If ``solution`` is of a type ``float()`` rejects.
    """
    mol = Chem.MolFromSmiles(SMILES)
    if mol is None:
        raise ValueError("Invalid SMILES string provided.")

    true_logp = Crippen.MolLogP(mol)
    solution_float = float(solution)

    # Maximum relative error that still counts as a correct answer.
    max_relative_error = 0.1

    if true_logp == 0:
        # Relative error is undefined for a zero reference value, so fall
        # back to the distance of the prediction from 0.
        relative_error = abs(solution_float)
    else:
        relative_error = abs(true_logp - solution_float) / abs(true_logp)

    # Phrased as "error < tolerance" on purpose: a NaN prediction makes
    # every comparison False, so it scores 0.0 rather than slipping past an
    # "error >= tolerance" rejection test and earning full credit.
    return 1.0 if relative_error < max_relative_error else 0.0
__init__(self, max_atoms=15, min_atoms=3, elements=None, seed=None): RDLogger.DisableLog('rdApp.*') - random.seed(42) if seed is None else random.seed(seed) + random.seed(seed) self.max_atoms = max_atoms self.min_atoms = min_atoms if elements is None: @@ -123,7 +123,7 @@ class SMILESGenerator: def __init__(self, min_len=5, max_len=25, seed=None): RDLogger.DisableLog('rdApp.*') - random.seed(42) if seed is None else random.seed(seed) + random.seed(seed) self.min_len = min_len self.max_len = max_len