returnn.learning_rate_control¶
Provides the learning rate scheduling logic.
The base class is LearningRateControl.
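In practice the scheduler is usually not constructed by hand but selected through the RETURNN config. A minimal sketch, assuming common config option names (treat the exact keys as assumptions, not a definitive reference):

```python
# Hypothetical RETURNN config excerpt (option names are assumptions):
learning_rate = 0.001                         # default_learning_rate, epoch 1
learning_rate_control = "newbob_multi_epoch"  # selects the scheduler class
newbob_learning_rate_decay = 0.9              # factor applied on decay
learning_rate_file = "learning_rates"         # persists EpochData across runs
```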
- class returnn.learning_rate_control.LearningRateControl(default_learning_rate, min_learning_rate=0.0, default_learning_rates=None, error_measure_key=None, relative_error_also_relative_to_learning_rate=False, min_num_epochs_per_new_learning_rate=0, relative_error_div_by_old=False, learning_rate_decay=1.0, learning_rate_growth=1.0, filename=None)[source]¶
Base class for learning rate control / scheduling.
- Parameters:
default_learning_rate (float) – default learning rate, usually for epoch 1
default_learning_rates (list[float] | dict[int,float]) – predefined learning rates, either as a list (indexed by epoch) or a dict (epoch -> learning rate)
error_measure_key (str|list[str]|None) – the key to look up in EpochData.error (a dict), used by get_epoch_error_value()
min_num_epochs_per_new_learning_rate (int) – if the lr was recently updated, use it for at least N epochs
relative_error_div_by_old (bool) – if True, compute relative error as (new - old) / old.
learning_rate_decay (float|(float)->float)
learning_rate_growth (float|(float)->float)
filename (str) – load from and save to file
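A minimal construction sketch using only the parameters from the signature above; in a real setup you would typically use a subclass (or the config-driven factory at the bottom of this page) rather than the base class itself:

```python
from returnn.learning_rate_control import LearningRateControl

lrc = LearningRateControl(
    default_learning_rate=0.001,  # used for epoch 1
    min_learning_rate=1e-5,       # lower bound for the scheduled lr
    learning_rate_decay=0.9,      # multiplicative decay factor
    learning_rate_growth=1.05,    # multiplicative growth factor
    filename="learning_rates",    # load from / save to this file
)
```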
- class EpochData(*, learning_rate: float | None = None, error: Dict[str, float] | None = None, meta: Dict[str, Any] | None = None, **kwargs)[source]¶
Encapsulates all relevant information for one epoch, needed to perform learning rate scheduling, such as the individual scores (cv or train; cross-entropy, frame error, etc.).
- Parameters:
learning_rate
error – scores (loss values) and errors (frame error rates, etc)
meta – any other extra information (e.g. effective learning rate)
Note that this is serialized as EpochData(learningRate=…, error=…), and we keep that for compatibility, so that is why we have special handling for kwargs.
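A construction sketch using the documented keywords; the keys inside error are illustrative assumptions (whatever score/error names your setup produces):

```python
from returnn.learning_rate_control import LearningRateControl

data = LearningRateControl.EpochData(
    learning_rate=0.001,
    error={"dev_score": 1.23, "dev_error": 0.42},  # illustrative keys
    meta={"effective_learning_rate": 0.001},
)
# Note: the serialized form uses the legacy keyword learningRate, which the
# special kwargs handling mentioned above accepts for compatibility.
```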
- calc_learning_rate_decay_or_grow(learning_rate, decay, grow=None)[source]¶
- Parameters:
learning_rate (float)
decay (bool)
grow (bool|None) – defaults to not decay, i.e. grow exactly when we do not decay
- Returns:
lr with decay or growth applied
- Return type:
float
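A standalone sketch of the decay-or-grow rule implied by the parameters above; the clipping against min_learning_rate and the handling of both flags at once are assumptions:

```python
def decay_or_grow(lr, decay, grow=None, decay_factor=0.9, grow_factor=1.05,
                  min_lr=0.0):
    """Sketch of the calc_learning_rate_decay_or_grow semantics (assumed).
    learning_rate_decay/growth may also be callables in the real class;
    this sketch only covers the float-factor case."""
    if grow is None:
        grow = not decay  # documented default: grow exactly when not decaying
    if decay:
        lr = max(lr * decay_factor, min_lr)
    if grow:
        lr = lr * grow_factor
    return lr
```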
- calc_new_learning_rate_for_epoch(epoch)[source]¶
- Parameters:
epoch (int)
- Returns:
new learning rate for this epoch
- Return type:
float
- get_last_epoch(epoch)[source]¶
- Parameters:
epoch (int)
- Returns:
last epoch before epoch where we have some epoch data
- Return type:
int
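Conceptually this is a lookup over the internally kept epoch -> EpochData mapping; a standalone sketch under that assumption:

```python
def get_last_epoch_sketch(epoch_data, epoch):
    """Return the last epoch strictly before `epoch` that has data,
    or None if there is none. `epoch_data` is assumed to be a dict
    mapping epoch numbers to EpochData-like entries."""
    earlier = [e for e in epoch_data if e < epoch]
    return max(earlier) if earlier else None
```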
- get_most_recent_learning_rate(epoch, exclude_current=True)[source]¶
- Parameters:
epoch (int)
exclude_current (bool)
- Returns:
most recent learning rate before or including epoch
- Return type:
float
- calc_relative_error(old_epoch, new_epoch)[source]¶
- Parameters:
old_epoch (int)
new_epoch (int)
- Returns:
relative error between old epoch and new epoch
- Return type:
float
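A sketch of the formula. With relative_error_div_by_old the divisor is the old error, per the parameter docs; taking the new error as the default divisor (and the abs/zero handling) are assumptions:

```python
def relative_error_sketch(old_error, new_error, div_by_old=False):
    """Relative error between two epochs' error values (sketch)."""
    divisor = old_error if div_by_old else new_error
    if divisor == 0:
        return 0.0  # degenerate case; real handling is an assumption
    return (new_error - old_error) / abs(divisor)
```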
- get_error_key(epoch)[source]¶
- Parameters:
epoch (int)
- Returns:
the key to look up in scores/errors for this epoch
- Return type:
str
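A sketch of the key selection; the fallback when no error_measure_key is configured is an assumption (the real class has its own preference order over the score/error names):

```python
def pick_error_key_sketch(error, error_measure_key=None):
    """Choose which entry of EpochData.error (a dict) to use."""
    if error_measure_key:
        keys = ([error_measure_key] if isinstance(error_measure_key, str)
                else error_measure_key)
        for key in keys:
            if key in error:
                return key
    # Fallback: any stable choice; the real preference order differs.
    return min(error) if error else None
```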
- get_epoch_error_value(epoch)[source]¶
- Parameters:
epoch (int)
- Returns:
error/score for the specific epoch, given the error key; see get_error_key()
- Return type:
float
- get_epoch_error_key_value(epoch)[source]¶
- Parameters:
epoch (int)
- Returns:
key, error
- Return type:
(str, float)
- get_last_best_epoch(last_epoch, first_epoch=1, only_last_epochs=None, filter_score=inf, only_last_n=-1, min_score_dist=0.0)[source]¶
- Parameters:
first_epoch (int) – will check all epochs >= first_epoch
last_epoch (int) – inclusive. will check all epochs <= last_epoch
only_last_epochs (int|None) – if set, will only check the last N epochs, inclusive
filter_score (float) – epochs with a score above this value are not considered
only_last_n (int) – if set (>=1), only the last only_last_n epochs of the resulting candidates are considered
min_score_dist (float) – filter out epochs where the diff to the most recent is not big enough
- Returns:
the last best epoch. To get the details, you might want to use getEpochErrorDict.
- Return type:
int|None
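A simplified sketch of the core selection (ignoring only_last_epochs, only_last_n and min_score_dist); the tie-breaking toward the later epoch is an assumption:

```python
def last_best_epoch_sketch(scores, last_epoch, first_epoch=1,
                           filter_score=float("inf")):
    """Among epochs in [first_epoch, last_epoch] whose score is below
    filter_score, return the epoch with the lowest score, or None.
    `scores` is assumed to be a dict mapping epoch -> score."""
    candidates = [
        (score, epoch) for epoch, score in scores.items()
        if first_epoch <= epoch <= last_epoch and score < filter_score
    ]
    if not candidates:
        return None
    best_score = min(score for score, _ in candidates)
    # Prefer the most recent epoch achieving the best score (assumption).
    return max(epoch for score, epoch in candidates if score == best_score)
```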
- class returnn.learning_rate_control.ConstantLearningRate(default_learning_rate, min_learning_rate=0.0, default_learning_rates=None, error_measure_key=None, relative_error_also_relative_to_learning_rate=False, min_num_epochs_per_new_learning_rate=0, relative_error_div_by_old=False, learning_rate_decay=1.0, learning_rate_growth=1.0, filename=None)[source]¶
Just a constant learning rate.
- Parameters:
default_learning_rate (float) – default learning rate, usually for epoch 1
default_learning_rates (list[float] | dict[int,float]) – predefined learning rates, either as a list (indexed by epoch) or a dict (epoch -> learning rate)
error_measure_key (str|list[str]|None) – the key to look up in EpochData.error (a dict), used by get_epoch_error_value()
min_num_epochs_per_new_learning_rate (int) – if the lr was recently updated, use it for at least N epochs
relative_error_div_by_old (bool) – if True, compute relative error as (new - old) / old.
learning_rate_decay (float|(float)->float)
learning_rate_growth (float|(float)->float)
filename (str) – load from and save to file
- class returnn.learning_rate_control.NewbobRelative(relative_error_threshold, **kwargs)[source]¶
If relative diff between old and new error is over some threshold, decay learning rate.
- class returnn.learning_rate_control.NewbobAbs(error_threshold, **kwargs)[source]¶
If absolute diff between old and new error is over some threshold, decay learning rate.
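The two Newbob variants differ only in how the error difference is measured before it is compared against the threshold; a sketch of that decision (sign conventions assumed, relative variant as in the calc_relative_error sketch above):

```python
def newbob_should_decay(old_error, new_error, threshold, relative=True):
    """True when the learning rate should be decayed (sketch)."""
    if relative:  # NewbobRelative-style test; assumes new_error != 0
        diff = (new_error - old_error) / abs(new_error)
    else:         # NewbobAbs-style test
        diff = new_error - old_error
    return diff > threshold
```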
- class returnn.learning_rate_control.NewbobMultiEpoch(num_epochs, update_interval, relative_error_threshold, relative_error_grow_threshold, **kwargs)[source]¶
Like NewbobRelative, but looks at the average relative error over multiple epochs. This is useful together with partition_epoch from Dataset.
- Parameters:
num_epochs (int)
update_interval (int)
relative_error_threshold (float)
relative_error_grow_threshold (float)
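Conceptually, the decision is based on the mean of the per-epoch relative errors over the last num_epochs epochs; a standalone sketch under that reading:

```python
def mean_relative_error_sketch(errors, epoch, num_epochs):
    """Average relative error over the last `num_epochs` epochs (sketch).
    `errors` is assumed to be a dict mapping epoch -> error value."""
    rel = [
        (errors[e] - errors[e - 1]) / abs(errors[e])
        for e in range(epoch - num_epochs + 1, epoch + 1)
        if e in errors and e - 1 in errors and errors[e] != 0
    ]
    return sum(rel) / len(rel) if rel else None
```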
- returnn.learning_rate_control.learning_rate_control_type(type_name)[source]¶
- Parameters:
type_name (str)
- Return type:
type[LearningRateControl]
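Usage sketch; the accepted type-name strings are assumptions inferred from the class names above:

```python
from returnn.learning_rate_control import learning_rate_control_type

# "newbob_multi_epoch" is an assumed type name, matching NewbobMultiEpoch.
cls = learning_rate_control_type("newbob_multi_epoch")
lrc = cls(default_learning_rate=0.001, num_epochs=6, update_interval=1,
          relative_error_threshold=-0.01, relative_error_grow_threshold=-0.01)
```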