mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
refactor(bootcamp): simplify InChI and SMILES generation
- Reduce the number of generated InChIs and SMILES from 10 to 1
- Remove random selection, always return the first generated structure
- Comment out debug prints and unused code
This commit is contained in:
parent
f4913c6f02
commit
18f47e0a3a
3 changed files with 7 additions and 8 deletions
|
|
@ -20,11 +20,10 @@ class InChI2logPbootcamp(Basebootcamp):
|
||||||
生成一组数字和目标值。
|
生成一组数字和目标值。
|
||||||
"""
|
"""
|
||||||
self.InChIGenerator = InChIGenerator(max_atoms=self.max_atoms, min_atoms=self.min_atoms, elements=None, seed=None)
|
self.InChIGenerator = InChIGenerator(max_atoms=self.max_atoms, min_atoms=self.min_atoms, elements=None, seed=None)
|
||||||
inchis = self.InChIGenerator.generate_n_valid_inchi(10)
|
inchis = self.InChIGenerator.generate_n_valid_inchi(1)
|
||||||
# print(inchis)
|
# print(inchis)
|
||||||
n = random.randint(0, 9)
|
|
||||||
# print(n)
|
# print(n)
|
||||||
return inchis[n]
|
return inchis[0]
|
||||||
|
|
||||||
def prompt_func(self, InChI) -> str:
|
def prompt_func(self, InChI) -> str:
|
||||||
|
|
||||||
|
|
@ -76,7 +75,7 @@ class InChI2logPbootcamp(Basebootcamp):
|
||||||
return 0.0 # Error is too large, score is 0
|
return 0.0 # Error is too large, score is 0
|
||||||
else:
|
else:
|
||||||
# Linear interpolation: score decreases linearly from 1 to 0 as error goes from 0 to max_relative_error
|
# Linear interpolation: score decreases linearly from 1 to 0 as error goes from 0 to max_relative_error
|
||||||
return 1.0
|
# return 1.0
|
||||||
return 1 - (relative_error / max_relative_error) * 0.5 ## For RL
|
return 1 - (relative_error / max_relative_error) * 0.5 ## For RL
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ class SMILES2logPbootcamp(InChI2logPbootcamp):
|
||||||
生成一组数字和目标值。
|
生成一组数字和目标值。
|
||||||
"""
|
"""
|
||||||
self.SMILESGenerator = SMILESGenerator(min_len=self.min_len, max_len=self.max_len, seed=None)
|
self.SMILESGenerator = SMILESGenerator(min_len=self.min_len, max_len=self.max_len, seed=None)
|
||||||
return self.SMILESGenerator.generate_n_valid_smiles(10)[random.randint(0, 9)]
|
return self.SMILESGenerator.generate_n_valid_smiles(1)[0]
|
||||||
|
|
||||||
def prompt_func(self, SMILES) -> str:
|
def prompt_func(self, SMILES) -> str:
|
||||||
|
|
||||||
|
|
@ -63,5 +63,5 @@ class SMILES2logPbootcamp(InChI2logPbootcamp):
|
||||||
return 0.0 # Error is too large, score is 0
|
return 0.0 # Error is too large, score is 0
|
||||||
else:
|
else:
|
||||||
# Linear interpolation: score decreases linearly from 1 to 0 as error goes from 0 to max_relative_error
|
# Linear interpolation: score decreases linearly from 1 to 0 as error goes from 0 to max_relative_error
|
||||||
return 1.0
|
# return 1.0
|
||||||
return 1 - (relative_error / max_relative_error) * 0.5 ## For RL
|
return 1 - (relative_error / max_relative_error) * 0.5 ## For RL
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ class InChIGenerator:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# This can happen if the molecule is somehow malformed even after sanitization,
|
# This can happen if the molecule is somehow malformed even after sanitization,
|
||||||
# or if InChI generation itself encounters an issue (rare).
|
# or if InChI generation itself encounters an issue (rare).
|
||||||
print(f"Debug: MolToInchi failed: {e} for SMILES: {Chem.MolToSmiles(mol)}")
|
# print(f"Debug: MolToInchi failed: {e} for SMILES: {Chem.MolToSmiles(mol)}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return list(valid_inchi_set)
|
return list(valid_inchi_set)
|
||||||
|
|
@ -244,7 +244,7 @@ class SMILESGenerator:
|
||||||
valid_smiles_set = set()
|
valid_smiles_set = set()
|
||||||
total_attempts_overall = 0
|
total_attempts_overall = 0
|
||||||
|
|
||||||
print(f"Attempting to generate {n} valid SMILES (min_len={self.min_len}, max_len={self.max_len})...")
|
# print(f"Attempting to generate {n} valid SMILES (min_len={self.min_len}, max_len={self.max_len})...")
|
||||||
while len(valid_smiles_set) < n:
|
while len(valid_smiles_set) < n:
|
||||||
attempts_for_current_smiles = 0
|
attempts_for_current_smiles = 0
|
||||||
generated_this_round = False
|
generated_this_round = False
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue