Optimizers
We currently have the following optimizers:
| Name | Optimizer | LR Scheduler |
|---|---|---|
| adamw_baseline | AdamW | Cosine Annealing with linear warmup |
| adamcpr | AdamCPR | Cosine Annealing with linear warmup |
| sgd_baseline | Stochastic Gradient Descent | Cosine Annealing |
| sgd_stepwise | Stochastic Gradient Descent | StepLR |
| adafactor | Adafactor | Constant |
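To illustrate what "Cosine Annealing with linear warmup" means, here is a minimal pure-Python sketch of the learning rate as a function of the training step. The function name `warmup_cosine_lr`, its parameters, and the choice to warm up from zero and decay to `min_lr` are assumptions. The schedulers actually used by `adamw_baseline` and `adamcpr` may differ in these details.

```python
import math


def warmup_cosine_lr(step: int, base_lr: float, warmup_steps: int,
                     max_steps: int, min_lr: float = 0.0) -> float:
    """Learning rate at `step`: linear warmup, then cosine annealing.

    Assumptions: the rate rises linearly from 0 to `base_lr` over
    `warmup_steps`, then follows a half cosine down to `min_lr` at `max_steps`.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup phase.
        return base_lr * step / warmup_steps
    # Fraction of the annealing phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    # Half-cosine decay from base_lr (progress 0) to min_lr (progress 1).
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Sample the schedule for a 100-step run with 10 warmup steps.
schedule = [warmup_cosine_lr(s, base_lr=1e-3, warmup_steps=10, max_steps=100)
            for s in range(101)]
```

With `warmup_steps=0` the sketch reduces to plain cosine annealing, the schedule listed for `sgd_baseline`.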
Creating your own optimizer
To add your own optimizer, create a subfolder in the `optimizers` directory. The name of that folder is the name used to invoke the optimizer. The folder must contain two files: `optimizer.py` and `default.yaml`. A template optimizer with explanatory comments is provided and can be used as a starting point.
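Before running an experiment, it can help to check that the new folder has both required files. The helper below, `missing_optimizer_files`, is a hypothetical sketch using only the standard library; it is not part of the framework and only checks that the files exist, not that their contents are valid.

```python
from pathlib import Path

# Files every optimizer folder must contain, per this README.
REQUIRED_FILES = ("optimizer.py", "default.yaml")


def missing_optimizer_files(optimizers_dir: Path, name: str) -> list[str]:
    """Return the required files absent from `optimizers_dir / name`.

    Hypothetical helper: `name` is the folder name, which is also the name
    used to invoke the optimizer, so it must match exactly.
    """
    folder = Path(optimizers_dir) / name
    return [f for f in REQUIRED_FILES if not (folder / f).is_file()]
```

For example, `missing_optimizer_files(Path("optimizers"), "my_awesome_optimizer")` returns an empty list once both files are in place.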
optimizer.py
Here you need to implement a function `configure_optimizers` with the following signature:

```python
configure_optimizers(model: GroupedModel, config: OptimizerConfig) -> OptimizerLRScheduler
```
- The return type, `OptimizerLRScheduler`, is the same as the return type of `configure_optimizers` in PyTorch Lightning's `LightningModule`.
- The `GroupedModel` is a wrapper around a `torch.nn.Module`. It additionally provides a method `grouped_parameters`, which returns the model parameters grouped by their `weight_decay` and `learning_rate` settings. This is useful, for example, for tasks that use lower learning rates for some parts of the model, or to avoid applying weight decay to normalization layers. The underlying `torch.nn.Module` can be accessed with `model.model`.
- The `OptimizerConfig` has the `lr_interval`, `max_steps`, and `max_epochs` attributes. It also gains all attributes provided in the `optimizer` section of the `experiment.yaml`.
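The exact signature of `grouped_parameters` is not documented here. To illustrate the idea, the framework-free sketch below buckets named parameters into groups keyed by (learning rate, weight decay), using the list-of-dicts layout that `torch.optim` optimizers accept as parameter groups. The function `group_parameters`, the prefix-based learning-rate scaling, and the rule that names containing `"norm"` or `"bias"` get no weight decay are all assumptions for illustration; parameters are plain placeholders so the example runs without PyTorch.

```python
from typing import Any, Iterable, Mapping, Optional


def group_parameters(
    named_params: Iterable[tuple[str, Any]],
    base_lr: float,
    weight_decay: float,
    lr_scale_by_prefix: Optional[Mapping[str, float]] = None,
    no_decay_keywords: tuple[str, ...] = ("norm", "bias"),
) -> list[dict[str, Any]]:
    """Hypothetical stand-in for `GroupedModel.grouped_parameters`.

    Parameters whose name contains a `no_decay_keywords` entry get
    weight_decay=0.0; names starting with a key of `lr_scale_by_prefix`
    get their learning rate multiplied by the corresponding value.
    """
    lr_scale_by_prefix = lr_scale_by_prefix or {}
    buckets: dict[tuple[float, float], list[Any]] = {}
    for name, param in named_params:
        wd = 0.0 if any(k in name for k in no_decay_keywords) else weight_decay
        # First matching prefix wins; unmatched parameters keep base_lr.
        scale = next((s for p, s in lr_scale_by_prefix.items()
                      if name.startswith(p)), 1.0)
        buckets.setdefault((base_lr * scale, wd), []).append(param)
    # One param group per distinct (lr, weight_decay) pair, in first-seen order.
    return [{"params": ps, "lr": lr, "weight_decay": wd}
            for (lr, wd), ps in buckets.items()]


named = [("encoder.weight", "p0"), ("encoder.norm.weight", "p1"),
         ("head.weight", "p2"), ("head.bias", "p3"), ("head.norm.bias", "p4")]
groups = group_parameters(named, base_lr=1e-3, weight_decay=0.1,
                          lr_scale_by_prefix={"encoder.": 0.1})
```

Inside a real `configure_optimizers`, a list in this shape (or whatever `model.grouped_parameters` returns) would typically be passed straight to an optimizer such as `torch.optim.AdamW`.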
default.yaml
Here you can provide default values for all the hyperparameters your optimizer needs. These values are added to the `OptimizerConfig` passed to `configure_optimizers`. For example, given the following `default.yaml`:
```yaml
optimizer:
  name: my_awesome_optimizer
  output_dir_name: my_awesome_optimizer
  learning_rate: 1.e-3
  important:
    extra:
      parameter: 42
```
you can access `config.important.extra.parameter` in the `configure_optimizers` function.
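This README does not specify exactly how the defaults are combined with `experiment.yaml`. As a hypothetical sketch, assuming that values from the `optimizer` section of `experiment.yaml` override those in `default.yaml` and that nested YAML mappings become nested attributes, the behavior resembles a recursive dict merge followed by conversion to attribute access. Both YAML files are shown as already-parsed dicts to keep the example dependency-free; `deep_merge` and `to_namespace` are illustrative names, not the framework's API.

```python
from types import SimpleNamespace
from typing import Any


def deep_merge(defaults: dict, overrides: dict) -> dict:
    """Recursively merge `overrides` into a copy of `defaults` (overrides win)."""
    merged = dict(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def to_namespace(value: Any) -> Any:
    """Turn nested dicts into SimpleNamespace objects so keys become attributes."""
    if isinstance(value, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in value.items()})
    return value


# Parsed default.yaml from above, and an assumed experiment.yaml override.
defaults = {"optimizer": {"name": "my_awesome_optimizer",
                          "output_dir_name": "my_awesome_optimizer",
                          "learning_rate": 1e-3,
                          "important": {"extra": {"parameter": 42}}}}
experiment = {"optimizer": {"learning_rate": 3e-4}}

config = to_namespace(deep_merge(defaults, experiment)["optimizer"])
print(config.important.extra.parameter)  # 42
print(config.learning_rate)  # 0.0003
```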