returnn.torch.data.pipeline

Code to create PyTorch datasets that can be used with the PyTorch DataLoader.

We make use of TorchData data pipelines.

Most functionality is implemented as a dataset/datapipe, as this seems to be the common way in PyTorch; it is also commonly done this way in Fairseq.

This is also the intended way for TorchData.

We could potentially also implement some functionality as part of the data loader (v1), but DataLoader2 suggests decoupling this, as we do here.

We also have ChunkShuffleDataset on the RETURNN dataset level. However, having this separate pure-PyTorch implementation is useful because it allows using other PyTorch datasets more directly, including HuggingFace datasets.
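For orientation, the pipes in this module operate on iterable datasets that yield one sequence at a time as a dict data_key -> numpy array. A minimal toy sketch of such a source dataset (illustrative only, not part of this module) could look like this:

    import numpy
    from torch.utils.data import IterableDataset

    class ToySeqDataset(IterableDataset):
        """Toy example: yields sequences as dict data_key -> numpy array."""

        def __iter__(self):
            for seq_len in [7, 13, 5]:
                yield {
                    "data": numpy.zeros((seq_len, 40), dtype="float32"),
                    "classes": numpy.zeros((seq_len,), dtype="int32"),
                }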

returnn.torch.data.pipeline.create_tensor(array: ndarray) → Tensor | ndarray[source]

Adjust non-supported dtypes and convert the numpy array to a torch.Tensor; as the return type indicates, arrays that cannot be converted are returned as-is.

Parameters:

array – numpy array to be converted
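As a hedged illustration (not the actual implementation), dtype adjustment here means something along these lines: dtypes that torch.from_numpy() cannot handle are cast to a supported dtype, and arrays that cannot be converted at all stay as numpy arrays:

    import numpy
    import torch

    def create_tensor_sketch(array: numpy.ndarray):
        # Assumption for illustration: widen dtypes without a torch equivalent.
        if array.dtype == numpy.uint32:
            array = array.astype(numpy.int64)
        try:
            return torch.from_numpy(array)
        except TypeError:
            # e.g. string arrays cannot be converted; keep them as numpy,
            # matching the Tensor | ndarray return type above.
            return array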

returnn.torch.data.pipeline.collate_batch(batch: List[Dict[str, ndarray]]) → Dict[str, Tensor | ndarray][source]

Merges the sequences of a batch into a single (padded) data array per data key (see BatchingIterDataPipe below).

Parameters:

batch – list of sequences, each given as a dict data_key -> numpy array
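A simplified sketch of this kind of collation (illustrative only; the real function also uses create_tensor() for dtype handling, and the exact output key names may differ): sequences are zero-padded to a common length and stacked per data key:

    import numpy
    import torch

    def collate_batch_sketch(batch):
        out = {}
        for key in batch[0].keys():
            arrays = [seq[key] for seq in batch]
            max_len = max(a.shape[0] for a in arrays)
            padded = numpy.zeros(
                (len(arrays), max_len) + arrays[0].shape[1:], dtype=arrays[0].dtype)
            for i, a in enumerate(arrays):
                padded[i, : a.shape[0]] = a
            out[key] = torch.from_numpy(padded)
            # illustrative key name for the per-sequence lengths
            out[key + ":seq_len"] = torch.tensor([a.shape[0] for a in arrays])
        return out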

class returnn.torch.data.pipeline.ChunkingIterDataPipe(dataset: IterableDataset, chunking, *, min_chunk_size=0)[source]

Splits each sequence in the given dataset into chunks according to the ‘chunking’ config option, i.e. it transforms one sequence into multiple sequences.

Parameters:
  • dataset – dataset to apply chunking to

  • chunking (None|int|(int,int)|dict|(dict,dict)) – tuple (chunk_size, chunk_step). If given as a single value, it is used for both. Both chunk_size and chunk_step can be given as a dict data_key -> size/step. This can be used to apply chunking only to a subset of all data keys, or to use different chunking for different data keys. (The number of resulting chunks has to match across all given data keys though, i.e. sequence lengths have to be considered. See the examples below.)
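For illustration, the accepted forms of the chunking option look like this (values are arbitrary examples):

    chunking = 100                  # chunk_size == chunk_step == 100, applied to all data keys
    chunking = (100, 50)            # chunk_size 100, chunk_step 50 -> overlapping chunks
    chunking = (
        {"data": 400, "classes": 100},   # chunk_size per data key
        {"data": 200, "classes": 50},    # chunk_step per data key
    )   # the resulting number of chunks must match across the given data keys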

class returnn.torch.data.pipeline.BatchingIterDataPipe(dataset: IterableDataset, batch_size=1, max_seqs=None)[source]

Converts a dataset yielding sequences (dict data_key -> array per sequence) into a dataset yielding lists of these sequences, i.e. batches. Sequences are grouped in-order according to the ‘batch_size’ and ‘max_seqs’ batch size limits. Note that batches are not yet merged into a single (padded) data array here; this happens in ‘collate_batch()’.

Parameters:
  • dataset – dataset to apply batching to

  • batch_size (int|dict[str,int]|None) – Maximum number of time steps (e.g. audio frames / words) in one batch (padding included). If given as a dict data_key -> value, sets different individual limits per data key. If None, no limit.

  • max_seqs (int|None) – maximum number of sequences in a batch, None means unlimited (also -1 to match TF backend)
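A hedged usage sketch (parameter values are arbitrary; seq_dataset stands for any sequence-level IterableDataset yielding dicts data_key -> numpy array, e.g. the toy dataset sketched above):

    from returnn.torch.data.pipeline import BatchingIterDataPipe

    seq_dataset = ToySeqDataset()  # see the toy sketch at the top of this page
    batches = BatchingIterDataPipe(seq_dataset, batch_size=5000, max_seqs=32)
    # each element yielded by `batches` is a list of dicts data_key -> numpy array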

class returnn.torch.data.pipeline.LenFilterDataPipe(dataset: IterableDataset, min_seq_length: int | NumbersDict | None = None, max_seq_length: int | NumbersDict | None = None)[source]

Removes sequences which are either too long or too short from a dataset, i.e. returns a dataset yielding only those sequences whose data lengths lie within the defined range.

Parameters:
  • dataset – dataset to apply the filter to

  • min_seq_length – minimum sequence length, either in general or per data key via a dict

  • max_seq_length – maximum sequence length, either in general or per data key via a dict
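A hedged usage sketch with an arbitrary limit, dropping overly long sequences before batching (reusing the seq_dataset from the sketch above):

    from returnn.torch.data.pipeline import LenFilterDataPipe

    # drop sequences exceeding the limit (the exact comparison is an implementation detail)
    filtered = LenFilterDataPipe(seq_dataset, max_seq_length=75)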

returnn.torch.data.pipeline.create_data_loader_from_batches(batches_dataset: Dataset, loader_opts: Dict[str, Any] | None = None) → DataLoader[source]

Create a DataLoader based on a dataset over batches, e.g. as produced by BatchingIterDataPipe.
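Putting the pieces together, a hedged end-to-end sketch could look like this (one possible ordering of the pipes; parameter values are arbitrary, and the loader is assumed to apply collate_batch() so that each batch comes out as padded tensors):

    from returnn.torch.data.pipeline import (
        ChunkingIterDataPipe, LenFilterDataPipe, BatchingIterDataPipe,
        create_data_loader_from_batches)

    # seq_dataset: any IterableDataset yielding dicts data_key -> numpy array
    chunks = ChunkingIterDataPipe(seq_dataset, chunking=(100, 50))
    filtered = LenFilterDataPipe(chunks, min_seq_length=1)
    batches = BatchingIterDataPipe(filtered, batch_size=5000, max_seqs=32)
    loader = create_data_loader_from_batches(batches)

    for batch in loader:
        # batch is expected to be a dict data_key -> padded tensor (plus length info),
        # as produced by collate_batch()
        pass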