returnn.torch.updater
This module covers the optimizer (SGD, Adam, etc.) logic, and model parameter update logic in general.
- returnn.torch.updater.get_optimizer_class(class_name) → Type[Optimizer]
- Parameters:
class_name (str|()->torch.optim.Optimizer|type[torch.optim.Optimizer]) – Optimizer data, e.g. “adam”, torch.optim.Adam…
- Returns:
Optimizer class
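For illustration, a minimal usage sketch; that the string "adam" resolves to torch.optim.Adam follows the example above but is an assumption here, not verified against the implementation:

```python
import torch
from returnn.torch.updater import get_optimizer_class

model = torch.nn.Linear(8, 4)

# Look up the optimizer class by name; passing the class itself is also accepted.
opt_cls = get_optimizer_class("adam")  # assumed to resolve to torch.optim.Adam
optimizer = opt_cls(model.parameters(), lr=1e-3)
```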
- class returnn.torch.updater.Updater(*, config, network, device, initial_learning_rate=1.0)
Wraps a torch.optim.Optimizer and extends it with some further functionality.
- Parameters:
config (returnn.config.Config) – config defining the training conditions.
network (torch.nn.Module) – PyTorch Module defining the network.
device (torch.device|str)
initial_learning_rate (float)
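A minimal construction sketch; the dict-based Config construction and the "optimizer" config entry shown here are assumptions for illustration, not taken from this page:

```python
import torch
from returnn.config import Config
from returnn.torch.updater import Updater

model = torch.nn.Linear(8, 4)

# Hypothetical config content; the exact optimizer settings are an assumption here.
config = Config({"optimizer": {"class": "adam"}})

updater = Updater(
    config=config,
    network=model,
    device="cpu",
    initial_learning_rate=1e-3,
)
```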
- set_learning_rate(value)
Updates the learning rate of the wrapped optimizer; this is called at each (sub)epoch.
- Parameters:
value (float) – New learning rate.
- set_current_train_step(*, global_train_step: int, epoch: int)
Sets the current training step and epoch, and obtains an updated learning rate for this step inside the (sub)epoch.
- step(*, grad_scaler: GradScaler | None = None)
Performs one update step, i.e. updates the parameters via the optimizer, given the currently computed gradients.
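Illustrative only: within RETURNN, the torch engine drives these calls. The sketch below continues the construction example above and assumes the underlying torch.optim.Optimizer has already been set up on the updater; the manual gradient zeroing and the use of a GradScaler are likewise assumptions, not requirements stated on this page:

```python
import torch

# Optional AMP gradient scaler; updater.step() accepts it via grad_scaler=.
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

data, targets = torch.randn(16, 8), torch.randn(16, 4)  # dummy batch
global_step = 0
for epoch in range(1, 3):
    # e.g. a value coming from the learning-rate scheduling of the engine
    updater.set_learning_rate(1e-3)
    for _ in range(10):  # placeholder for iterating over the data of this (sub)epoch
        updater.set_current_train_step(global_train_step=global_step, epoch=epoch)
        model.zero_grad()  # assumption: gradients are cleared here, outside of step()
        loss = torch.nn.functional.mse_loss(model(data), targets)
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        updater.step(grad_scaler=scaler)
        global_step += 1
```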
- load_optimizer(filename)
Loads a torch.optim.Optimizer from disk and stores it in self.optimizer.
- Parameters:
filename (str) – File from which to load the optimizer state.
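Continuing the sketch above, a resume example; the checkpoint filename is hypothetical, and the corresponding saving of the optimizer state is not covered on this page:

```python
# Restores the optimizer state from disk into self.optimizer (see above).
updater.load_optimizer("/path/to/checkpoint.opt.pt")  # hypothetical filename
```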