returnn.learning_rate_control

Provides the learning rate scheduling logic. The base class is LearningRateControl.

class returnn.learning_rate_control.LearningRateControl(default_learning_rate, min_learning_rate=0.0, default_learning_rates=None, error_measure_key=None, relative_error_also_relative_to_learning_rate=False, min_num_epochs_per_new_learning_rate=0, relative_error_div_by_old=False, learning_rate_decay=1.0, learning_rate_growth=1.0, filename=None)[source]

Base class for learning rate control / scheduling.

Parameters:
  • default_learning_rate (float) – default learning rate, usually for epoch 1

  • default_learning_rates (list[float] | dict[int,float]) – explicitly given learning rates, per epoch

  • error_measure_key (str|list[str]|None) – for get_epoch_error_value() the key for EpochData.error, which is a dict

  • min_num_epochs_per_new_learning_rate (int) – if the learning rate was recently changed, keep it for at least this many epochs

  • relative_error_div_by_old (bool) – if True, compute the relative error as (new - old) / old

  • learning_rate_decay (float|(float)->float) – decay factor (applied multiplicatively if a float), or a function mapping the old learning rate to the new one

  • learning_rate_growth (float|(float)->float) – growth factor (applied multiplicatively if a float), or a function mapping the old learning rate to the new one

  • filename (str) – load from and save to this file
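
For illustration, a minimal construction sketch using these kwargs; in practice an instance is usually created via load_learning_rate_control_from_config() (listed further below). ConstantLearningRate is used here because the base class is meant to be subclassed; all values are illustrative:

    from returnn.learning_rate_control import ConstantLearningRate

    lr_control = ConstantLearningRate(
        default_learning_rate=0.001,                   # used from epoch 1
        default_learning_rates={1: 0.0005, 2: 0.001},  # explicit per-epoch values
    )
    print(lr_control.get_learning_rate_for_epoch(1))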

need_error_info = True[source]
class EpochData(*, learning_rate: float | None = None, error: Dict[str, float] | None = None, meta: Dict[str, Any] | None = None, **kwargs)[source]

Encapsulates all relevant information for one epoch, needed to perform learning rate scheduling, such as the individual scores (cv or train; cross-entropy, frame error rate, etc.).

Parameters:
  • learning_rate

  • error – scores (loss values) and errors (frame error rates, etc)

  • meta – any other extra information (e.g. effective learning rate)

Note that this is serialized as EpochData(learningRate=…, error=…); we keep that format for compatibility, which is why there is special handling for kwargs.
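
A small sketch of that compatibility: both the legacy serialized keyword and the current one should construct equivalent objects (EpochData is accessed as a nested class here, as the listing above suggests; the error dict key is illustrative):

    from returnn.learning_rate_control import LearningRateControl

    EpochData = LearningRateControl.EpochData
    # Legacy serialized spelling, kept for compatibility:
    a = EpochData(learningRate=0.001, error={"dev_score": 1.23})
    # Current keyword:
    b = EpochData(learning_rate=0.001, error={"dev_score": 1.23})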

classmethod load_initial_kwargs_from_config(config)[source]
Return type:

dict[str,Any]

classmethod load_initial_from_config(config)[source]
Return type:

LearningRateControl

calc_learning_rate_decay_or_grow(learning_rate, decay, grow=None)[source]
Parameters:
  • learning_rate (float)

  • decay (bool)

  • grow (bool|None) – whether to grow instead; if None, defaults to (not decay)

Returns:

learning rate with decay or growth applied

Return type:

float
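
A sketch of the intended arithmetic, assuming a float decay/growth factor is applied multiplicatively (per the parameter types above, a callable would instead map the old learning rate to the new one):

    from returnn.learning_rate_control import ConstantLearningRate

    lr_control = ConstantLearningRate(
        default_learning_rate=0.001,
        learning_rate_decay=0.9,    # assumed multiplicative factor
        learning_rate_growth=1.05,  # assumed multiplicative factor
    )
    lr_control.calc_learning_rate_decay_or_grow(0.001, decay=True)   # -> ~0.0009
    lr_control.calc_learning_rate_decay_or_grow(0.001, decay=False)  # grow defaults to (not decay) -> ~0.00105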

calc_learning_rate_for_epoch(epoch)[source]
Parameters:

epoch (int)

Returns:

learning rate

Return type:

float

calc_new_learning_rate_for_epoch(epoch)[source]
Parameters:

epoch (int)

Returns:

new learning rate for this epoch

Return type:

float

get_learning_rate_for_epoch(epoch)[source]
Return type:

float

set_default_learning_rate_for_epoch(epoch, learning_rate)[source]
get_last_epoch(epoch)[source]
Parameters:

epoch (int)

Returns:

the last epoch before the given epoch for which we have some epoch data

Return type:

int

get_most_recent_learning_rate(epoch, exclude_current=True)[source]
Parameters:
  • epoch (int)

  • exclude_current (bool)

Returns:

the most recent learning rate before the given epoch (or including it, if exclude_current is False)

Return type:

float

calc_relative_error(old_epoch, new_epoch)[source]
Parameters:
  • old_epoch (int)

  • new_epoch (int)

Returns:

relative error between old epoch and new epoch

Return type:

float
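
A worked sketch of the documented relative_error_div_by_old=True case (the denominator used when it is False is not documented in this listing):

    # Suppose get_epoch_error_value(4) == 2.0 and get_epoch_error_value(5) == 1.8.
    # With relative_error_div_by_old=True:
    #   calc_relative_error(4, 5) -> (1.8 - 2.0) / 2.0 = -0.10
    # A negative relative error thus means the error decreased (here by 10%).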

set_epoch_error(epoch, error)[source]
get_error_key(epoch)[source]
Parameters:

epoch (int)

Returns:

the key to look up in the scores/errors dict for this epoch

Return type:

str

get_epoch_error_dict(epoch)[source]
Parameters:

epoch (int)

Return type:

dict[str,float]

get_epoch_error_value(epoch)[source]
Parameters:

epoch (int)

Returns:

error/score for the given epoch, under the error key; see get_error_key()

Return type:

float

get_epoch_error_key_value(epoch)[source]
Parameters:

epoch (int)

Returns:

key, error

Return type:

(str, float)
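
A sketch of how recorded scores feed into the key/value lookup (the dict keys and the error_measure_key value are illustrative only):

    from returnn.learning_rate_control import ConstantLearningRate

    lr_control = ConstantLearningRate(
        default_learning_rate=0.001,
        error_measure_key="dev_score",  # illustrative key
    )
    lr_control.set_epoch_error(1, {"dev_score": 2.0, "train_score": 2.5})
    lr_control.set_epoch_error(2, {"dev_score": 1.8, "train_score": 2.1})

    key = lr_control.get_error_key(2)             # e.g. "dev_score"
    value = lr_control.get_epoch_error_value(2)   # value under that key, e.g. 1.8
    key, value = lr_control.get_epoch_error_key_value(2)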

get_last_best_epoch(last_epoch, first_epoch=1, only_last_epochs=None, filter_score=inf, only_last_n=-1, min_score_dist=0.0)[source]
Parameters:
  • first_epoch (int) – will check all epochs >= first_epoch

  • last_epoch (int) – inclusive. will check all epochs <= last_epoch

  • only_last_epochs (int|None) – if set, will only check the last N epochs, inclusive

  • filter_score (float) – epochs whose score is above this value are not considered

  • only_last_n (int) – if set (>= 1), from the resulting list, consider only the last only_last_n epochs

  • min_score_dist (float) – filter out epochs where the score difference to the most recent epoch is not large enough

Returns:

the last best epoch. To get its details, you might want to use get_epoch_error_dict().

Return type:

int|None
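
A usage sketch, assuming an lr_control instance with epoch errors recorded as in the sketch above:

    # Best epoch among 1..10; ignore epochs scoring above 5.0, and only
    # look at the last 5 remaining candidates:
    best = lr_control.get_last_best_epoch(
        last_epoch=10,
        first_epoch=1,
        filter_score=5.0,
        only_last_n=5,
    )
    if best is not None:
        print(best, lr_control.get_epoch_error_dict(best))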

save()[source]

Save the current epoch data to file (self.filename).

load()[source]

Loads the saved epoch data from file (self.filename).
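
A round-trip sketch (the file name is hypothetical; both methods operate on self.filename as stated):

    from returnn.learning_rate_control import ConstantLearningRate

    lr_control = ConstantLearningRate(default_learning_rate=0.001, filename="lr_scores.txt")
    lr_control.set_epoch_error(1, {"dev_score": 2.0})
    lr_control.save()  # write the per-epoch data to lr_scores.txt
    lr_control.load()  # read it back from the same file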

class returnn.learning_rate_control.ConstantLearningRate(default_learning_rate, min_learning_rate=0.0, default_learning_rates=None, error_measure_key=None, relative_error_also_relative_to_learning_rate=False, min_num_epochs_per_new_learning_rate=0, relative_error_div_by_old=False, learning_rate_decay=1.0, learning_rate_growth=1.0, filename=None)[source]

Just a constant learning rate.

Parameters:
  • default_learning_rate (float) – default learning rate, usually for epoch 1

  • default_learning_rates (list[float] | dict[int,float]) – explicitly given learning rates, per epoch

  • error_measure_key (str|list[str]|None) – for get_epoch_error_value() the key for EpochData.error, which is a dict

  • min_num_epochs_per_new_learning_rate (int) – if the learning rate was recently changed, keep it for at least this many epochs

  • relative_error_div_by_old (bool) – if True, compute the relative error as (new - old) / old

  • learning_rate_decay (float|(float)->float) – decay factor (applied multiplicatively if a float), or a function mapping the old learning rate to the new one

  • learning_rate_growth (float|(float)->float) – growth factor (applied multiplicatively if a float), or a function mapping the old learning rate to the new one

  • filename (str) – load from and save to this file

need_error_info = False[source]
calc_learning_rate_for_epoch(epoch)[source]

Dummy constant learning rate; returns the initial learning rate.

Parameters:

epoch (int)

Returns:

learning rate

Return type:

float

class returnn.learning_rate_control.NewbobRelative(relative_error_threshold, **kwargs)[source]

If relative diff between old and new error is over some threshold, decay learning rate.

classmethod load_initial_kwargs_from_config(config)[source]
Return type:

dict[str,Any]

calc_learning_rate_for_epoch(epoch)[source]

Newbob+ on train data.

Parameters:

epoch (int)

Returns:

learning rate

Return type:

float
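
Paraphrased as pseudo-Python, the rule described above might look like this (a sketch of the documented behavior, not the actual implementation; lr_control, old_epoch, new_epoch, old_lr, and relative_error_threshold are placeholders):

    rel_err = lr_control.calc_relative_error(old_epoch, new_epoch)
    if rel_err is not None and rel_err > relative_error_threshold:
        # Error did not improve enough -> decay the learning rate.
        new_lr = lr_control.calc_learning_rate_decay_or_grow(old_lr, decay=True)
    else:
        new_lr = old_lr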

class returnn.learning_rate_control.NewbobAbs(error_threshold, **kwargs)[source]

If absolute diff between old and new error is over some threshold, decay learning rate.

classmethod load_initial_kwargs_from_config(config)[source]
Return type:

dict[str,Any]

calc_learning_rate_for_epoch(epoch)[source]

Newbob+ on train data.

Parameters:

epoch (int)

Returns:

learning rate

Return type:

float

class returnn.learning_rate_control.NewbobMultiEpoch(num_epochs, update_interval, relative_error_threshold, relative_error_grow_threshold, **kwargs)[source]

Like NewbobRelative, but looks at the average relative error over multiple epochs. This is useful together with partition_epoch from Dataset.

Parameters:
  • num_epochs (int)

  • update_interval (int)

  • relative_error_threshold (float)

  • relative_error_grow_threshold (float)

classmethod load_initial_kwargs_from_config(config)[source]
Return type:

dict[str,Any]

calc_learning_rate_for_epoch(epoch)[source]

Newbob+ on train data.

Parameters:

epoch (int)

Returns:

learning rate

Return type:

float
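
A typical RETURNN config snippet selecting this scheduler. The global config keys below are as commonly seen in RETURNN setups and are an assumption here; verify them against your RETURNN version:

    learning_rate = 0.001
    learning_rate_control = "newbob_multi_epoch"
    newbob_multi_num_epochs = 6             # num_epochs above, e.g. matching partition_epoch
    newbob_multi_update_interval = 1        # update_interval above
    newbob_learning_rate_decay = 0.9
    learning_rate_file = "learning_rates"   # where scores / learning rates are persisted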

returnn.learning_rate_control.learning_rate_control_type(type_name)[source]
Parameters:

type_name (str)

Return type:

type[LearningRateControl]|LearningRateControl

returnn.learning_rate_control.load_learning_rate_control_from_config(config)[source]
Return type:

LearningRateControl
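
A loading sketch (returnn.config.Config and the config keys shown are assumptions consistent with common RETURNN usage):

    from returnn.config import Config
    from returnn.learning_rate_control import load_learning_rate_control_from_config

    config = Config()
    config.update({
        "learning_rate": 0.001,
        "learning_rate_control": "newbob_multi_epoch",
    })
    # Returns an instance of the matching LearningRateControl subclass.
    lr_control = load_learning_rate_control_from_config(config)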

returnn.learning_rate_control.demo()[source]

Demo run. Given some learning rate file (with scores / existing learning rates), calculates how the learning rates would have been set, given some config.
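
Per the listing above, demo() takes no arguments, so it presumably reads the config and learning rate file from the command line; a hypothetical invocation:

    # Hypothetical; exact argument handling depends on the RETURNN version:
    #   python -m returnn.learning_rate_control my_config.config
    import returnn.learning_rate_control as lrc
    lrc.demo()  # takes no arguments; likely parses sys.argv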