returnn.torch.engine
Main engine for PyTorch
- class returnn.torch.engine.Engine(config: Config)
PyTorch engine
- Parameters:
config
- init_train_from_config(config: Config | None = None, train_data: Dataset | None = None, dev_data: Dataset | None = None, eval_data: Dataset | None = None)
- Parameters:
config
train_data
dev_data
eval_data
- eval_model(*, skip_already_evaluated: bool = False)
Runs the model on all eval datasets and calculates the loss.
- get_pt_model() → Module | None
- Returns:
The PyTorch module. If the model is defined with RF, this returns the wrapped module.
- forward_with_callback(*, dataset: Dataset, callback: ForwardCallbackIface, dataset_init_epoch: bool = True, allow_skipping_seqs: bool = False)
Runs a forward pass of the model over the given dataset and reports the results through the callback.
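The callback pattern behind forward_with_callback can be sketched without RETURNN itself. The sketch below is illustrative only: the hook names init, process_seq and finish are assumptions about what ForwardCallbackIface expects, and CollectOutputsCallback and run_forward are hypothetical names that stand in for a user callback and for the engine's loop. Check the ForwardCallbackIface documentation for the actual signatures.

```python
from typing import Any, Dict, Iterable, List, Tuple


class CollectOutputsCallback:
    """Duck-typed stand-in for a forward callback (hypothetical, not RETURNN's class)."""

    def __init__(self) -> None:
        self.model: Any = None
        self.seqs: List[Tuple[str, Dict[str, Any]]] = []
        self.finished = False

    def init(self, *, model: Any) -> None:
        # Assumed hook: called once before the first sequence.
        self.model = model

    def process_seq(self, *, seq_tag: str, outputs: Dict[str, Any]) -> None:
        # Assumed hook: called once per sequence with that sequence's outputs.
        self.seqs.append((seq_tag, outputs))

    def finish(self) -> None:
        # Assumed hook: called once after the last sequence.
        self.finished = True


def run_forward(callback: CollectOutputsCallback, model: Any,
                dataset: Iterable[Tuple[str, Any]]) -> None:
    """Hypothetical driver loop that plays the role of the engine."""
    callback.init(model=model)
    for seq_tag, features in dataset:
        callback.process_seq(seq_tag=seq_tag, outputs={"output": model(features)})
    callback.finish()
```

In this sketch the callback only collects outputs in memory; a real callback would more likely write each sequence to a file in process_seq and close the file in finish.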
- returnn.torch.engine.get_device_from_config_opt(device: str | None) → ResultWithReason[str]
- Parameters:
device – as in config
- Returns:
resolved device
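As a rough illustration of what resolving a device "with a reason" can look like, here is a minimal self-contained sketch. It does not reproduce RETURNN's actual logic: DeviceChoice and resolve_device are hypothetical names, the fallback rules (an unset device picks "cuda" when available, "gpu" maps to "cuda") are assumptions, and CUDA availability is passed in as a flag so that the example runs without PyTorch installed.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DeviceChoice:
    """Hypothetical stand-in for ResultWithReason[str]: a value plus why it was chosen."""

    result: str
    reason: str


def resolve_device(device: Optional[str], *, cuda_available: bool) -> DeviceChoice:
    """Resolve a config device option to a concrete device string (assumed rules)."""
    if not device:
        # Nothing configured: fall back to whatever hardware is available.
        if cuda_available:
            return DeviceChoice("cuda", "no device configured, CUDA is available")
        return DeviceChoice("cpu", "no device configured, CUDA is not available")
    if device == "gpu":
        # Assumed alias: treat the generic "gpu" as PyTorch's "cuda".
        return DeviceChoice("cuda", "config value 'gpu' mapped to 'cuda'")
    # Any other value is passed through unchanged.
    return DeviceChoice(device, "taken from config")
```

Returning the reason alongside the value lets the caller log why a device was picked, which helps when a job unexpectedly ends up on the CPU.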