mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
Merge remote-tracking branch 'remotes/origin/main' into feature/add-typhoon
This commit is contained in:
commit
e454b2eeb6
3 changed files with 69 additions and 18 deletions
|
|
@@ -1,6 +1,8 @@
|
||||||
|
import random
|
||||||
|
|
||||||
from internbootcamp.bootcamp.base import Basebootcamp
|
from internbootcamp.bootcamp.base import Basebootcamp
|
||||||
from internbootcamp.libs.chemStructure2Property.ChemStructureGenerator import InChIGenerator
|
from internbootcamp.libs.chemStructure2Property.ChemStructureGenerator import InChIGenerator
|
||||||
from .utils import last_boxed_only_string, remove_boxed
|
from internbootcamp.bootcamp.ChemStructure2Property.utils import last_boxed_only_string, remove_boxed
|
||||||
from rdkit import Chem
|
from rdkit import Chem
|
||||||
from rdkit.Chem import Crippen
|
from rdkit.Chem import Crippen
|
||||||
|
|
||||||
|
|
@@ -18,7 +20,11 @@ class InChI2logPbootcamp(Basebootcamp):
|
||||||
生成一组数字和目标值。
|
生成一组数字和目标值。
|
||||||
"""
|
"""
|
||||||
self.InChIGenerator = InChIGenerator(max_atoms=self.max_atoms, min_atoms=self.min_atoms, elements=None, seed=None)
|
self.InChIGenerator = InChIGenerator(max_atoms=self.max_atoms, min_atoms=self.min_atoms, elements=None, seed=None)
|
||||||
return self.InChIGenerator.generate_n_valid_inchi(1)[0]
|
inchis = self.InChIGenerator.generate_n_valid_inchi(10)
|
||||||
|
# print(inchis)
|
||||||
|
n = random.randint(0, 9)
|
||||||
|
# print(n)
|
||||||
|
return inchis[n]
|
||||||
|
|
||||||
def prompt_func(self, InChI) -> str:
|
def prompt_func(self, InChI) -> str:
|
||||||
|
|
||||||
|
|
@@ -45,9 +51,10 @@ class InChI2logPbootcamp(Basebootcamp):
|
||||||
return remove_boxed(output)
|
return remove_boxed(output)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _verify_correction(cls, solution, InChI)->bool:
|
def _verify_correction(cls, solution, InChI) -> float:
|
||||||
"""
|
"""
|
||||||
Verify the correction of the solution.
|
Verify the correction of the solution and return a score between 0 and 1.
|
||||||
|
The score is based on the relative error with respect to a maximum relative error of 0.1.
|
||||||
"""
|
"""
|
||||||
mol = Chem.MolFromInchi(InChI)
|
mol = Chem.MolFromInchi(InChI)
|
||||||
true_logp = Crippen.MolLogP(mol)
|
true_logp = Crippen.MolLogP(mol)
|
||||||
|
|
@@ -55,6 +62,27 @@ class InChI2logPbootcamp(Basebootcamp):
|
||||||
|
|
||||||
# Handle case where true_logp is 0
|
# Handle case where true_logp is 0
|
||||||
if true_logp == 0:
|
if true_logp == 0:
|
||||||
return abs(solution_float) <= 0.01 # Just check if solution is close to 0
|
# If true_logp is 0, we check how close the solution is to 0
|
||||||
|
relative_error = abs(solution_float)
|
||||||
else:
|
else:
|
||||||
return abs(true_logp - solution_float)/abs(true_logp) <= 0.01
|
# Calculate the relative error
|
||||||
|
relative_error = abs(true_logp - solution_float) / abs(true_logp)
|
||||||
|
|
||||||
|
# Define the maximum allowed relative error
|
||||||
|
max_relative_error = 0.1
|
||||||
|
|
||||||
|
# Calculate the score based on the relative error
|
||||||
|
if relative_error >= max_relative_error:
|
||||||
|
return 0.0 # Error is too large, score is 0
|
||||||
|
else:
|
||||||
|
# Linear interpolation: score decreases linearly from 1 to 0 as error goes from 0 to max_relative_error
|
||||||
|
return 1.0
|
||||||
|
return 1 - (relative_error / max_relative_error) * 0.5 ## For RL
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bootcamp = InChI2logPbootcamp()
|
||||||
|
while True:
|
||||||
|
case = bootcamp.case_generator()
|
||||||
|
print('case')
|
||||||
|
print(case)
|
||||||
|
input()
|
||||||
|
|
@@ -1,3 +1,5 @@
|
||||||
|
import random
|
||||||
|
|
||||||
from internbootcamp.bootcamp.base import Basebootcamp
|
from internbootcamp.bootcamp.base import Basebootcamp
|
||||||
from internbootcamp.libs.chemStructure2Property.ChemStructureGenerator import SMILESGenerator
|
from internbootcamp.libs.chemStructure2Property.ChemStructureGenerator import SMILESGenerator
|
||||||
from .utils import last_boxed_only_string, remove_boxed
|
from .utils import last_boxed_only_string, remove_boxed
|
||||||
|
|
@@ -19,7 +21,7 @@ class SMILES2logPbootcamp(InChI2logPbootcamp):
|
||||||
生成一组数字和目标值。
|
生成一组数字和目标值。
|
||||||
"""
|
"""
|
||||||
self.SMILESGenerator = SMILESGenerator(min_len=self.min_len, max_len=self.max_len, seed=None)
|
self.SMILESGenerator = SMILESGenerator(min_len=self.min_len, max_len=self.max_len, seed=None)
|
||||||
return self.SMILESGenerator.generate_n_valid_smiles(1)[0]
|
return self.SMILESGenerator.generate_n_valid_smiles(10)[random.randint(0, 9)]
|
||||||
|
|
||||||
def prompt_func(self, SMILES) -> str:
|
def prompt_func(self, SMILES) -> str:
|
||||||
|
|
||||||
|
|
@@ -31,14 +33,35 @@ class SMILES2logPbootcamp(InChI2logPbootcamp):
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _verify_correction(cls, solution, SMILES)->bool:
|
def _verify_correction(cls, solution, SMILES) -> float:
|
||||||
"""
|
"""
|
||||||
Verify the correction of the solution.
|
Verify the correction of the solution and return a score between 0 and 1.
|
||||||
|
The score is based on the relative error with respect to a maximum relative error of 0.1.
|
||||||
"""
|
"""
|
||||||
mol = Chem.MolFromSmiles(SMILES)
|
mol = Chem.MolFromSmiles(SMILES)
|
||||||
|
if mol is None:
|
||||||
|
raise ValueError("Invalid SMILES string provided.")
|
||||||
|
|
||||||
true_logp = Crippen.MolLogP(mol)
|
true_logp = Crippen.MolLogP(mol)
|
||||||
solution_float = float(solution)
|
solution_float = float(solution)
|
||||||
|
|
||||||
|
# print('true_logp: ', true_logp, ' solution_float: ', solution_float)
|
||||||
|
|
||||||
|
# Handle case where true_logp is 0
|
||||||
if true_logp == 0:
|
if true_logp == 0:
|
||||||
return abs(solution_float) <= 0.01 # Just check if solution is close to 0
|
# If true_logp is 0, we check how close the solution is to 0
|
||||||
|
relative_error = abs(solution_float)
|
||||||
else:
|
else:
|
||||||
return abs(true_logp - solution_float)/abs(true_logp) <= 0.01
|
# Calculate the relative error
|
||||||
|
relative_error = abs(true_logp - solution_float) / abs(true_logp)
|
||||||
|
|
||||||
|
# Define the maximum allowed relative error
|
||||||
|
max_relative_error = 0.1
|
||||||
|
|
||||||
|
# Calculate the score based on the relative error
|
||||||
|
if relative_error >= max_relative_error:
|
||||||
|
return 0.0 # Error is too large, score is 0
|
||||||
|
else:
|
||||||
|
# Linear interpolation: score decreases linearly from 1 to 0 as error goes from 0 to max_relative_error
|
||||||
|
return 1.0
|
||||||
|
return 1 - (relative_error / max_relative_error) * 0.5 ## For RL
|
||||||
|
|
|
||||||
|
|
@@ -9,7 +9,7 @@ class InChIGenerator:
|
||||||
def __init__(self, max_atoms=15, min_atoms=3, elements=None,
|
def __init__(self, max_atoms=15, min_atoms=3, elements=None,
|
||||||
seed=None):
|
seed=None):
|
||||||
RDLogger.DisableLog('rdApp.*')
|
RDLogger.DisableLog('rdApp.*')
|
||||||
random.seed(42) if seed is None else random.seed(seed)
|
random.seed(seed)
|
||||||
self.max_atoms = max_atoms
|
self.max_atoms = max_atoms
|
||||||
self.min_atoms = min_atoms
|
self.min_atoms = min_atoms
|
||||||
if elements is None:
|
if elements is None:
|
||||||
|
|
@@ -123,7 +123,7 @@ class SMILESGenerator:
|
||||||
def __init__(self, min_len=5, max_len=25,
|
def __init__(self, min_len=5, max_len=25,
|
||||||
seed=None):
|
seed=None):
|
||||||
RDLogger.DisableLog('rdApp.*')
|
RDLogger.DisableLog('rdApp.*')
|
||||||
random.seed(42) if seed is None else random.seed(seed)
|
random.seed(seed)
|
||||||
self.min_len = min_len
|
self.min_len = min_len
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue