diff --git a/environments/community/pytorch_optimizer_coding/FOB/pytorch_fob/engine/run.py b/environments/community/pytorch_optimizer_coding/FOB/pytorch_fob/engine/run.py
index c33970ea..03f1695c 100644
--- a/environments/community/pytorch_optimizer_coding/FOB/pytorch_fob/engine/run.py
+++ b/environments/community/pytorch_optimizer_coding/FOB/pytorch_fob/engine/run.py
@@ -81,7 +81,7 @@ class Run:
         self.run_dir.mkdir(parents=True, exist_ok=True)
         self.export_config()
         scores: dict[str, _EVALUATE_OUTPUT] = {}
-        if any([self.engine.train, self.engine.test]):
+        if any([self.engine.train, self.engine.test, self.engine.validate]):
             self._ensure_resume_path()
             self.ensure_max_steps()
             torch.set_float32_matmul_precision("high")
@@ -94,7 +94,8 @@ class Run:
                     "optimizer_time"
                 ].total_mean_optimizer_step_time_ms
             if self.engine.validate:
-                scores["validation"] = self._validate(trainer, model, data_module)
+                validator = self.get_validator()
+                scores["validation"] = self._validate(validator, model, data_module)
             if self.engine.test:
                 tester = self.get_tester()
                 if (
@@ -151,9 +152,9 @@ class Run:
             f.write(str(train_time) + "\n")
 
     def _validate(
-        self, trainer: Trainer, model: LightningModule, data_module: LightningDataModule
+        self, validator: Trainer, model: LightningModule, data_module: LightningDataModule
     ) -> _EVALUATE_OUTPUT:
-        score = trainer.validate(model, datamodule=data_module)
+        score = validator.validate(model, datamodule=data_module)
         return score
 
     def _test(
@@ -233,6 +234,16 @@ class Run:
             accelerator=self.engine.accelerator,
         )
 
+    def get_validator(self) -> Trainer:
+        return Trainer(
+            devices=1,
+            logger=False,
+            enable_progress_bar=(not self.engine.silent),
+            deterministic=self.engine.deterministic,
+            precision=precision_with_fallback(self.engine.precision),  # type: ignore
+            accelerator=self.engine.accelerator,
+        )
+
     def get_best_checkpoint(self) -> Optional[Path]:
         model_checkpoint = self._callbacks.get("best_model_checkpoint", None)
         if model_checkpoint is not None:
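
A minimal standalone sketch of the pattern this diff introduces: validation now gets its own single-device Trainer (get_validator), mirroring the existing get_tester, instead of reusing the training Trainer. The sketch assumes the lightning (PyTorch Lightning 2.x) API; BoringModel and BoringDataModule are Lightning's demo classes standing in for FOB's task-specific model and data module.

# Sketch only, not FOB code. Lightning recommends devices=1 for evaluation
# because distributed strategies such as DDP pad batches across replicas,
# which can skew aggregated validation metrics.
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel

model = BoringModel()
data_module = BoringDataModule()

# Dedicated evaluation Trainer, configured like get_validator() in the diff:
validator = Trainer(devices=1, logger=False, enable_progress_bar=True)
scores = validator.validate(model, datamodule=data_module)  # list of metric dicts

Building a fresh Trainer per evaluation phase also keeps validation free of training-only state (loggers, checkpoint callbacks), which is the same design choice get_tester already makes for the test phase.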