returnn.torch.engine

Main engine for PyTorch

returnn.torch.engine.random()

Returns a random number x in the interval [0, 1).
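This docstring is identical to the one for Python's standard `random.random`. That suggests this entry is the stdlib function imported into the module namespace, although the page does not say so. The sketch below uses the stdlib directly to show the contract:

```python
import random

# Seeding makes the sequence reproducible across runs.
random.seed(0)

# random.random() returns a float x with 0.0 <= x < 1.0.
samples = [random.random() for _ in range(5)]
print(samples)
```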
class returnn.torch.engine.Engine(config: Config)

PyTorch engine

Parameters:

config

init_network_from_config(config: Config | None = None)

Initializes the model (network) from the config.

init_train_from_config(config: Config | None = None, train_data: Dataset | None = None, dev_data: Dataset | None = None, eval_data: Dataset | None = None)
Parameters:
  • config

  • train_data

  • dev_data

  • eval_data

train()

Main training loop.

init_train_epoch()

Initializes a training (sub)epoch, e.g. sets the learning rate (LR).
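One plausible reading of the per-epoch "LR" step is that it computes the learning rate for the new (sub)epoch and writes it into the optimizer. The sketch below assumes a hypothetical linear-warmup schedule. The names `lr_for_epoch`, `base_lr` and `warmup_epochs` are illustrative and are not RETURNN config options. Plain dicts stand in for `torch.optim.Optimizer.param_groups`, whose entries are dicts with an `"lr"` key:

```python
from typing import Dict, List


def lr_for_epoch(epoch: int, base_lr: float = 1e-3, warmup_epochs: int = 5) -> float:
    """Hypothetical schedule: linear warmup, then constant (epochs are 1-based here)."""
    if epoch <= warmup_epochs:
        return base_lr * epoch / warmup_epochs
    return base_lr


def set_lr(param_groups: List[Dict[str, float]], lr: float) -> None:
    """Write the same learning rate into every parameter group."""
    for group in param_groups:
        group["lr"] = lr


# Plain dicts standing in for optimizer.param_groups (hypothetical data).
param_groups = [{"lr": 0.0}, {"lr": 0.0}]
set_lr(param_groups, lr_for_epoch(2))  # epoch 2 of a 5-epoch warmup: 2/5 of base_lr
print(param_groups)
```

With a real PyTorch optimizer, the same loop would run over `optimizer.param_groups`.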

train_epoch()

Trains the model for one (sub)epoch.
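The following toy class shows how `train()`, `init_train_epoch()` and `train_epoch()` can fit together. It only records which hook runs. `ToyEngine` is hypothetical and is not RETURNN's implementation. Placing an evaluation step at the end of each (sub)epoch is also an assumption:

```python
from typing import List


class ToyEngine:
    """Hypothetical stand-in that only records which hook ran, in order."""

    def __init__(self, num_epochs: int) -> None:
        self.num_epochs = num_epochs
        self.epoch = 0
        self.log: List[str] = []

    def init_train_epoch(self) -> None:
        self.log.append(f"init_train_epoch {self.epoch}")  # e.g. set the LR here

    def train_epoch(self) -> None:
        self.log.append(f"train_epoch {self.epoch}")  # iterate over training batches

    def eval_model(self) -> None:
        self.log.append(f"eval_model {self.epoch}")  # compute loss on held-out data

    def train(self) -> None:
        # Main loop: run the hooks once per (sub)epoch.
        while self.epoch < self.num_epochs:
            self.epoch += 1
            self.init_train_epoch()
            self.train_epoch()
            self.eval_model()


engine = ToyEngine(num_epochs=2)
engine.train()
print("\n".join(engine.log))
```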

eval_model(*, skip_already_evaluated: bool = False)

Runs the model on all eval datasets and calculates the loss.
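The page does not describe what `skip_already_evaluated` does. Going only by its name, one plausible behavior is sketched below. Scores are cached per (epoch, dataset name). When the flag is set, any dataset that already has a score for the current epoch is not recomputed. `ToyEvaluator` and its fields are hypothetical:

```python
from typing import Callable, Dict, Tuple


class ToyEvaluator:
    """Hypothetical evaluator caching scores per (epoch, dataset name)."""

    def __init__(self, eval_datasets: Dict[str, Callable[[], float]]) -> None:
        self.eval_datasets = eval_datasets  # dataset name -> function returning a loss
        self.epoch = 1
        self.scores: Dict[Tuple[int, str], float] = {}
        self.num_runs = 0  # how many loss computations actually happened

    def eval_model(self, *, skip_already_evaluated: bool = False) -> Dict[str, float]:
        for name, compute_loss in self.eval_datasets.items():
            key = (self.epoch, name)
            if skip_already_evaluated and key in self.scores:
                continue  # already scored in this epoch, keep the cached value
            self.scores[key] = compute_loss()
            self.num_runs += 1
        return {name: self.scores[(self.epoch, name)] for name in self.eval_datasets}


ev = ToyEvaluator({"dev": lambda: 1.5, "devtrain": lambda: 1.2})
first = ev.eval_model()                             # computes both losses
again = ev.eval_model(skip_already_evaluated=True)  # reuses the cached scores
print(first == again, ev.num_runs)  # True 2
```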

get_pt_model() → Module | None
Returns:

The PyTorch module. If the model uses the RETURNN frontend (RF), this returns the wrapped module.

get_pt_optimizer() → Optimizer | None
Returns:

The PyTorch optimizer.

forward_with_callback(*, dataset: Dataset, callback: ForwardCallbackIface)

Runs the model forward over dataset and passes the outputs to callback.
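A callback of this kind receives the model outputs one sequence at a time. The stand-in below assumes that `ForwardCallbackIface` has the hook names `init`, `process_seq` and `finish`. Those names are an assumption, because this page does not document the interface's signatures. The class is duck-typed and does not import RETURNN. It uses a plain dict where RETURNN would pass a `TensorDict`. `toy_forward` is a hypothetical driver whose "model" doubles each value:

```python
from typing import Any, Dict, List, Optional


class CollectCallback:
    """Duck-typed stand-in for a forward callback (hypothetical; does not import RETURNN)."""

    def __init__(self) -> None:
        self.results: Dict[str, Any] = {}
        self.finished = False

    def init(self, *, model: Optional[Any]) -> None:
        self.results.clear()  # called once before the first sequence

    def process_seq(self, *, seq_tag: str, outputs: Dict[str, Any]) -> None:
        # Called once per sequence; a plain dict stands in for a TensorDict here.
        self.results[seq_tag] = outputs["output"]

    def finish(self) -> None:
        self.finished = True  # called once after the last sequence


def toy_forward(seqs: Dict[str, List[float]], callback: CollectCallback) -> None:
    """Hypothetical driver: the "model" doubles each input value."""
    callback.init(model=None)
    for tag, features in seqs.items():
        callback.process_seq(seq_tag=tag, outputs={"output": [2 * x for x in features]})
    callback.finish()


cb = CollectCallback()
toy_forward({"seq-0": [1.0, 2.0], "seq-1": [3.0]}, cb)
print(cb.results)  # {'seq-0': [2.0, 4.0], 'seq-1': [6.0]}
```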

static delete_model(filename)

Deletes the saved model files for filename.
Parameters:

filename (str)

Returns:

total size in bytes of all deleted files

Return type:

int
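The return value is the combined size of all files that were removed. The sketch below shows that bookkeeping. It assumes a checkpoint consists of `filename` plus the hypothetical suffixes `.pt` and `.opt.pt`. RETURNN's actual file naming is not given here, and `delete_model_files` is an illustrative name:

```python
import os
import tempfile
from typing import Sequence


def delete_model_files(filename: str, suffixes: Sequence[str] = (".pt", ".opt.pt")) -> int:
    """Hypothetical: delete filename+suffix files that exist and return the bytes freed."""
    total = 0
    for suffix in suffixes:
        path = filename + suffix
        if os.path.exists(path):
            total += os.path.getsize(path)  # measure the size before deleting
            os.remove(path)
    return total


with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "epoch.003")
    with open(base + ".pt", "wb") as f:
        f.write(b"\0" * 100)  # 100-byte dummy "model" file
    with open(base + ".opt.pt", "wb") as f:
        f.write(b"\0" * 20)   # 20-byte dummy "optimizer" file
    freed = delete_model_files(base)
    print(freed)  # 120
```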

returnn.torch.engine.get_device_from_config_opt(device: str | None) → ResultWithReason[str]
Parameters:

device – the device option as given in the config (may be None)

Returns:

the resolved device, together with the reason for the choice
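The return type pairs the resolved device with a reason for the choice. The sketch below mirrors that idea using a hypothetical `DeviceChoice` named tuple and hypothetical resolution rules. An explicit setting is used as-is. If there is no setting, the sketch picks CUDA when it is available and falls back to CPU otherwise. RETURNN's actual rules may differ. In real PyTorch code, the `cuda_available` flag would typically come from `torch.cuda.is_available()`:

```python
from typing import NamedTuple, Optional


class DeviceChoice(NamedTuple):
    """Hypothetical analogue of ResultWithReason[str]: the value and why it was chosen."""

    result: str
    reason: Optional[str]


def resolve_device(device: Optional[str], *, cuda_available: bool) -> DeviceChoice:
    """Hypothetical rules: explicit setting wins; otherwise prefer CUDA, fall back to CPU."""
    if device is not None:
        return DeviceChoice(device, f"config device={device!r}")
    if cuda_available:
        return DeviceChoice("cuda", "no device in config, CUDA is available")
    return DeviceChoice("cpu", "no device in config, CUDA is not available")


print(resolve_device(None, cuda_available=False))
print(resolve_device("cuda:1", cuda_available=False).result)  # cuda:1
```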