returnn.torch.engine
Main engine for PyTorch
- class returnn.torch.engine.Engine(config: Config)
PyTorch engine.
- Parameters:
config – the RETURNN Config to set up the engine from
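- init_train_from_config(config: Config | None = None, train_data: Dataset | None = None, dev_data: Dataset | None = None, eval_data: Dataset | None = None)
Initializes the engine for training.
- Parameters:
config – Config to initialize from (optional)
train_data – training dataset (optional)
dev_data – development dataset (optional)
eval_data – evaluation dataset (optional)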
- eval_model(*, skip_already_evaluated: bool = False)
Runs the model on all eval datasets and calculates the loss.
- Parameters:
skip_already_evaluated – if True, skip datasets that have already been evaluated
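As a rough illustration of the loop eval_model describes, here is a framework-free sketch. It is not RETURNN's implementation: datasets are hypothetical lists of (input, target) pairs, the model is any callable, the loss is mean squared error, and skip_already_evaluated is assumed to mean "skip datasets that already have a stored score". All names in the sketch are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for a dataset: a list of (input, target) pairs.
ToyDataset = List[Tuple[float, float]]


def eval_model_sketch(
    model: Callable[[float], float],
    eval_datasets: Dict[str, ToyDataset],
    scores: Dict[str, float],
    *,
    skip_already_evaluated: bool = False,
) -> Dict[str, float]:
    """Run the model on every eval dataset and store its mean squared error in scores."""
    for name, data in eval_datasets.items():
        # Assumption: a dataset counts as "already evaluated" if it has a stored score.
        if skip_already_evaluated and name in scores:
            continue
        if not data:
            continue  # nothing to average over
        total = sum((model(x) - target) ** 2 for x, target in data)
        scores[name] = total / len(data)
    return scores


datasets = {"dev": [(1.0, 2.0), (2.0, 4.0)], "test": [(3.0, 5.0)]}
print(eval_model_sketch(lambda x: 2 * x, datasets, {}))  # → {'dev': 0.0, 'test': 1.0}
```

Keeping scores in a dict keyed by dataset name makes the skip check a simple membership test in this sketch.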
- get_pt_model() → Module | None
- Returns:
The PyTorch module. If the model is defined with RF (the RETURNN frontend), this returns the wrapped PyTorch module.
- forward_with_callback(*, dataset: Dataset, callback: ForwardCallbackIface)
Runs the model forward over dataset and passes the outputs to callback.
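The callback pattern behind forward_with_callback can be sketched without RETURNN or PyTorch. This is a hypothetical stand-in, not the real ForwardCallbackIface: it assumes the callback has an init hook called once before the first sequence, a process_seq hook called once per sequence with its tag and outputs, and a finish hook called at the end (check returnn.forward_iface for the actual signatures). The dataset here is just a list of (seq_tag, input) pairs.

```python
from typing import Any, Callable, Dict, List, Tuple


class ForwardCallbackSketch:
    """Hypothetical stand-in for a forward callback interface.

    Assumed hooks: init() before the first sequence, process_seq() once per
    sequence, finish() after the last one.
    """

    def init(self, *, model: Any) -> None:
        pass

    def process_seq(self, *, seq_tag: str, outputs: Dict[str, Any]) -> None:
        raise NotImplementedError

    def finish(self) -> None:
        pass


class CollectOutputs(ForwardCallbackSketch):
    """Collects each sequence's outputs in memory, keyed by sequence tag."""

    def init(self, *, model: Any) -> None:
        self.results: Dict[str, Dict[str, Any]] = {}
        self.finished = False

    def process_seq(self, *, seq_tag: str, outputs: Dict[str, Any]) -> None:
        self.results[seq_tag] = outputs

    def finish(self) -> None:
        self.finished = True


def forward_with_callback_sketch(
    model: Callable[[Any], Dict[str, Any]],
    dataset: List[Tuple[str, Any]],
    callback: ForwardCallbackSketch,
) -> None:
    """Run the model on each (seq_tag, input) pair and hand the outputs to the callback."""
    callback.init(model=model)
    for seq_tag, data in dataset:
        callback.process_seq(seq_tag=seq_tag, outputs=model(data))
    callback.finish()


cb = CollectOutputs()
forward_with_callback_sketch(lambda x: {"output": x * 10}, [("seq-0", 1), ("seq-1", 2)], cb)
print(cb.results)  # → {'seq-0': {'output': 10}, 'seq-1': {'output': 20}}
```

Splitting the work into init / per-sequence / finish hooks lets a callback stream results (for example, to a file) instead of holding everything in memory; the in-memory collector above is only the simplest case.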
- returnn.torch.engine.get_device_from_config_opt(device: str | None) → ResultWithReason[str]
- Parameters:
device – device option as given in the config (may be None)
- Returns:
the resolved device, together with the reason it was chosen
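To show what "resolved device with a reason" means, here is a minimal sketch of one possible resolution policy. It is hypothetical and not necessarily what RETURNN does: ResultWithReasonSketch is a local stand-in for ResultWithReason, "gpu" is assumed to be an alias for PyTorch's "cuda", and CUDA availability is passed in as a flag instead of querying torch, so the sketch runs with the standard library only.

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

T = TypeVar("T")


@dataclass
class ResultWithReasonSketch(Generic[T]):
    """Hypothetical stand-in: a value plus a human-readable reason for it."""

    result: T
    reason: Optional[str] = None


def resolve_device_sketch(device: Optional[str], *, cuda_available: bool) -> ResultWithReasonSketch[str]:
    """Resolve a config device option to a concrete device name (assumed policy)."""
    if device:
        # Assumed alias: treat the generic "gpu" as PyTorch's "cuda".
        if device == "gpu":
            device = "cuda"
        return ResultWithReasonSketch(device, "explicitly set in config")
    # No device configured: fall back to auto-detection.
    if cuda_available:
        return ResultWithReasonSketch("cuda", "CUDA is available")
    return ResultWithReasonSketch("cpu", "no GPU available")


print(resolve_device_sketch("gpu", cuda_available=False))
# → ResultWithReasonSketch(result='cuda', reason='explicitly set in config')
```

Returning the reason alongside the device makes it easy to log why a particular device was picked, which helps when a run unexpectedly lands on the CPU.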