mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix: handle validation without training
Added validation functionality to the training process and refactored validation method to use a dedicated validator instance.
This commit is contained in:
parent
708b42a00f
commit
34c8c87f0f
1 changed files with 15 additions and 4 deletions
|
|
@ -81,7 +81,7 @@ class Run:
|
|||
self.run_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.export_config()
|
||||
scores: dict[str, _EVALUATE_OUTPUT] = {}
|
||||
if any([self.engine.train, self.engine.test]):
|
||||
if any([self.engine.train, self.engine.test, self.engine.validate]):
|
||||
self._ensure_resume_path()
|
||||
self.ensure_max_steps()
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
|
@ -94,7 +94,8 @@ class Run:
|
|||
"optimizer_time"
|
||||
].total_mean_optimizer_step_time_ms
|
||||
if self.engine.validate:
|
||||
scores["validation"] = self._validate(trainer, model, data_module)
|
||||
validator = self.get_validator()
|
||||
scores["validation"] = self._validate(validator, model, data_module)
|
||||
if self.engine.test:
|
||||
tester = self.get_tester()
|
||||
if (
|
||||
|
|
@ -151,9 +152,9 @@ class Run:
|
|||
f.write(str(train_time) + "\n")
|
||||
|
||||
def _validate(
|
||||
self, trainer: Trainer, model: LightningModule, data_module: LightningDataModule
|
||||
self, validator: Trainer, model: LightningModule, data_module: LightningDataModule
|
||||
) -> _EVALUATE_OUTPUT:
|
||||
score = trainer.validate(model, datamodule=data_module)
|
||||
score = validator.validate(model, datamodule=data_module)
|
||||
return score
|
||||
|
||||
def _test(
|
||||
|
|
@ -233,6 +234,16 @@ class Run:
|
|||
accelerator=self.engine.accelerator,
|
||||
)
|
||||
|
||||
def get_validator(self) -> Trainer:
|
||||
return Trainer(
|
||||
devices=1,
|
||||
logger=False,
|
||||
enable_progress_bar=(not self.engine.silent),
|
||||
deterministic=self.engine.deterministic,
|
||||
precision=precision_with_fallback(self.engine.precision), # type: ignore
|
||||
accelerator=self.engine.accelerator,
|
||||
)
|
||||
|
||||
def get_best_checkpoint(self) -> Optional[Path]:
|
||||
model_checkpoint = self._callbacks.get("best_model_checkpoint", None)
|
||||
if model_checkpoint is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue