fix: handle validation without training

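Include `engine.validate` in the condition that guards the resume-path and max-steps setup, so validation can be run without training or testing enabled. Refactor `_validate` to use a dedicated validator `Trainer` created by the new `get_validator()` method instead of reusing the training trainer.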
This commit is contained in:
Gengar 2026-02-21 15:53:37 +02:00 committed by GitHub
parent 708b42a00f
commit 34c8c87f0f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -81,7 +81,7 @@ class Run:
self.run_dir.mkdir(parents=True, exist_ok=True)
self.export_config()
scores: dict[str, _EVALUATE_OUTPUT] = {}
if any([self.engine.train, self.engine.test]):
if any([self.engine.train, self.engine.test, self.engine.validate]):
self._ensure_resume_path()
self.ensure_max_steps()
torch.set_float32_matmul_precision("high")
@ -94,7 +94,8 @@ class Run:
"optimizer_time"
].total_mean_optimizer_step_time_ms
if self.engine.validate:
scores["validation"] = self._validate(trainer, model, data_module)
validator = self.get_validator()
scores["validation"] = self._validate(validator, model, data_module)
if self.engine.test:
tester = self.get_tester()
if (
@ -151,9 +152,9 @@ class Run:
f.write(str(train_time) + "\n")
def _validate(
    self, validator: Trainer, model: LightningModule, data_module: LightningDataModule
) -> _EVALUATE_OUTPUT:
    """Run validation for ``model`` and return the resulting scores.

    Args:
        validator: Trainer instance whose ``validate`` method is invoked.
        model: Module to evaluate.
        data_module: Forwarded to ``validate`` as the ``datamodule`` argument.

    Returns:
        The value returned by ``validator.validate``.
    """
    return validator.validate(model, datamodule=data_module)
def _test(
@ -233,6 +234,16 @@ class Run:
accelerator=self.engine.accelerator,
)
def get_validator(self) -> Trainer:
    """Build a ``Trainer`` for running validation.

    The instance uses a single device and has logging disabled. The
    progress bar is shown unless ``self.engine.silent`` is set. The
    deterministic flag, precision (resolved via
    ``precision_with_fallback``) and accelerator are taken from
    ``self.engine``.

    Returns:
        A newly constructed ``Trainer``.
    """
    engine = self.engine
    show_progress = not engine.silent
    deterministic = engine.deterministic
    precision = precision_with_fallback(engine.precision)
    return Trainer(
        devices=1,
        logger=False,
        enable_progress_bar=show_progress,
        deterministic=deterministic,
        precision=precision,  # type: ignore
        accelerator=engine.accelerator,
    )
def get_best_checkpoint(self) -> Optional[Path]:
model_checkpoint = self._callbacks.get("best_model_checkpoint", None)
if model_checkpoint is not None: