Update hpo.py

This commit is contained in:
crStiv 2025-06-19 22:59:42 +02:00 committed by GitHub
parent e934094173
commit b65b614132
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -13,7 +13,7 @@ from neps.utils.common import get_initial_directory, load_lightning_checkpoint
from pytorch_fob.engine.engine import Engine, Run
#############################################################
# Definig the seeds for reproducibility
# Defining the seeds for reproducibility
def set_seed(seed=42):
@ -56,7 +56,7 @@ def search_space(run: Run) -> dict:
return space
def create_exmperiment(run: Run, config: dict) -> dict:
def create_experiment(run: Run, config: dict) -> dict:
new_config = run.get_config().copy()
for k, v in config.items():
if k == "one_minus_beta1":
@ -78,7 +78,7 @@ def create_pipline(base_run: Run):
# Initialize the model and checkpoint dir
engine = Engine()
engine.parse_experiment(create_exmperiment(base_run, config))
engine.parse_experiment(create_experiment(base_run, config))
run = next(engine.runs())
run.ensure_max_steps()
model, datamodule = run.get_task()