returnn.frontend.loss

Loss functions

returnn.frontend.loss.cross_entropy(*, estimated: Tensor, target: Tensor, axis: Dim, estimated_type: str) -> Tensor

target is expected to be in probability space, i.e. normalized over axis. It can also be sparse, i.e. contain class indices instead of dense probabilities. estimated can be probs, log-probs or logits, as specified via estimated_type.

Assuming both are in probability space, the cross entropy is:

H(target, estimated) = -reduce_sum(target * log(estimated), axis=axis)
                     = -matmul(target, log(estimated), reduce=axis)
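As a minimal illustration of the formula above, the following pure-Python sketch computes the cross entropy over a single class axis, once for a dense target and once for a sparse target (a class index). The helper names `cross_entropy_dense` and `cross_entropy_sparse` are hypothetical and not part of returnn.frontend; the sketch also assumes the usual convention that terms with zero target probability contribute nothing.

```python
import math


def cross_entropy_dense(target_probs, estimated_probs):
    """-sum_c target[c] * log(estimated[c]), reduced over the class axis."""
    # Terms with target[c] == 0 are skipped (convention: 0 * log(x) = 0).
    return -sum(
        t * math.log(e) for t, e in zip(target_probs, estimated_probs) if t > 0.0
    )


def cross_entropy_sparse(target_index, estimated_probs):
    """Sparse target: a single class index, equivalent to a one-hot dense target."""
    return -math.log(estimated_probs[target_index])


estimated = [0.7, 0.2, 0.1]  # probabilities over 3 classes
one_hot = [0.0, 1.0, 0.0]    # dense target for class 1

dense = cross_entropy_dense(one_hot, estimated)
sparse = cross_entropy_sparse(1, estimated)
```

Both variants give the same value here, which is why a sparse target can stand in for a one-hot distribution.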

In case you want label smoothing, you can use e.g.:

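import returnn.frontend as rf
# estimated_type="logits" is only an example; pass whatever matches `estimated`.
ce = rf.cross_entropy(
    target=rf.label_smoothing(target, 0.1),
    estimated=estimated, axis=axis, estimated_type="logits")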
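To show what smoothing a target distribution means numerically, here is a small pure-Python sketch of one common convention: mixing the target with a uniform distribution over all classes. This is an assumption made for illustration; the library's own label smoothing function may distribute the smoothing mass differently (for example, only over the non-target classes). The helper name `smooth_uniform` is hypothetical.

```python
def smooth_uniform(target_probs, smoothing):
    """Mix the target distribution with a uniform one (assumed convention)."""
    num_classes = len(target_probs)
    return [(1.0 - smoothing) * t + smoothing / num_classes for t in target_probs]


one_hot = [0.0, 1.0, 0.0]
smoothed = smooth_uniform(one_hot, 0.1)
```

The result is still normalized, so it remains a valid target in probability space.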
Parameters:
  • estimated – probs, log-probs or logits, specified via estimated_type

  • target – probs, normalized, can also be sparse

  • axis – class labels dim over which softmax is computed

  • estimated_type – “probs”, “log-probs” or “logits”

Returns:

cross entropy (same Dims as ‘estimated’ but without ‘axis’)
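The role of estimated_type can be sketched as follows: whatever form estimated comes in, it is brought into log-probability space before the reduction. This is a plain-Python illustration under that assumption, not the library's implementation; `log_softmax`, `to_log_probs` and `cross_entropy_1d` are hypothetical helper names.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a single class axis."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def to_log_probs(estimated, estimated_type):
    """Bring `estimated` into log-prob space according to `estimated_type`."""
    if estimated_type == "probs":
        return [math.log(p) for p in estimated]
    if estimated_type == "log-probs":
        return list(estimated)
    if estimated_type == "logits":
        return log_softmax(estimated)
    raise ValueError(f"invalid estimated_type {estimated_type!r}")


def cross_entropy_1d(target_probs, estimated, estimated_type):
    """Cross entropy over one class axis with a dense target."""
    log_probs = to_log_probs(estimated, estimated_type)
    return -sum(t * lp for t, lp in zip(target_probs, log_probs))


logits = [2.0, 0.5, -1.0]
log_probs = log_softmax(logits)
probs = [math.exp(lp) for lp in log_probs]
target = [0.1, 0.8, 0.1]

ce_from_logits = cross_entropy_1d(target, logits, "logits")
ce_from_log_probs = cross_entropy_1d(target, log_probs, "log-probs")
ce_from_probs = cross_entropy_1d(target, probs, "probs")
```

All three calls describe the same distribution, so they yield the same cross entropy up to floating-point error. Working from logits or log-probs avoids taking the log of very small probabilities.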

returnn.frontend.loss.ctc_loss(*, logits: Tensor, logits_normalized: bool = False, targets: Tensor, input_spatial_dim: Dim, targets_spatial_dim: Dim, blank_index: int, max_approx: bool = False) -> Tensor

Calculates the CTC loss.

Internally, this uses returnn.tf.native_op.ctc_loss(), which is equivalent to tf.nn.ctc_loss but more efficient.

Output is of shape [B…], i.e. one loss value per sequence.

Parameters:
  • logits – unnormalized scores (before softmax), unless logits_normalized is set. shape [B…,input_spatial,C]

  • logits_normalized – whether the logits are already normalized (e.g. via log-softmax)

  • targets – sparse, i.e. class indices in C. shape [B…,targets_spatial] -> C

  • input_spatial_dim – spatial dim of input logits

  • targets_spatial_dim – spatial dim of targets

  • blank_index – vocab index of the blank symbol

  • max_approx – if True, use max instead of sum over alignments (max approx, Viterbi)

Returns:

loss shape [B…]
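For intuition about what the loss measures and what max_approx changes, here is a reference-style sketch of the standard CTC forward recursion for a single sequence in pure Python. It is not the native implementation mentioned above; it assumes the inputs are already normalized log-probs (the case logits_normalized describes), and the names `log_add` and `ctc_loss_single` are hypothetical.

```python
import math

NEG_INF = float("-inf")


def log_add(a, b, max_approx=False):
    """log(exp(a) + exp(b)), or max(a, b) for the max (Viterbi) approximation."""
    if max_approx:
        return max(a, b)
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def ctc_loss_single(log_probs, targets, blank_index, max_approx=False):
    """-log p(targets | input) for one sequence; log_probs has shape [T][C]."""
    # Extended label sequence: blank, l1, blank, l2, ..., blank.
    ext = [blank_index]
    for label in targets:
        ext += [label, blank_index]
    num_states = len(ext)

    # Initialization: start in the first blank or the first label.
    alpha = [NEG_INF] * num_states
    alpha[0] = log_probs[0][ext[0]]
    if num_states > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, len(log_probs)):
        new_alpha = [NEG_INF] * num_states
        for s in range(num_states):
            score = alpha[s]  # stay in the same state
            if s >= 1:
                score = log_add(score, alpha[s - 1], max_approx)  # advance by one
            # Skipping a blank is only allowed between two different labels.
            if s >= 2 and ext[s] != blank_index and ext[s] != ext[s - 2]:
                score = log_add(score, alpha[s - 2], max_approx)
            new_alpha[s] = score + log_probs[t][ext[s]]
        alpha = new_alpha

    # Valid end states: the last blank or the last label.
    total = alpha[-1]
    if num_states > 1:
        total = log_add(total, alpha[-2], max_approx)
    return -total


# Two frames, two classes (0 = blank, 1 = label), uniform probabilities.
uniform = [[math.log(0.5), math.log(0.5)] for _ in range(2)]
loss_sum = ctc_loss_single(uniform, [1], blank_index=0)
loss_max = ctc_loss_single(uniform, [1], blank_index=0, max_approx=True)
```

In this toy case, three alignments (label-label, blank-label, label-blank) map to the target, each with probability 0.25. The full sum therefore yields -log(0.75), while the max approximation keeps only the best single alignment and yields -log(0.25).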

returnn.frontend.loss.edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim, *, dtype: str = 'int32') -> Tensor
Parameters:
  • a – [B,Ta]

  • a_spatial_dim – Ta

  • b – [B,Tb]

  • b_spatial_dim – Tb

  • dtype

Returns:

[B]
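The shapes above ([B,Ta] and [B,Tb] in, [B] out) can be illustrated with a small pure-Python sketch. It assumes the standard Levenshtein distance with unit costs for insertion, deletion and substitution, which the reference above does not state explicitly. It also represents the spatial dims as explicit per-sequence lengths over padded lists. `edit_distance_single` and `edit_distance_batch` are hypothetical names.

```python
def edit_distance_single(a, b):
    """Levenshtein distance between two sequences (unit costs assumed)."""
    prev = list(range(len(b) + 1))  # distance from the empty prefix of a
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(
                min(
                    prev[j] + 1,             # deletion from a
                    cur[j - 1] + 1,          # insertion into a
                    prev[j - 1] + (x != y),  # substitution or match
                )
            )
        prev = cur
    return prev[-1]


def edit_distance_batch(a, a_lens, b, b_lens):
    """a: [B][Ta] padded, b: [B][Tb] padded; returns a list of length B."""
    return [
        edit_distance_single(seq_a[:len_a], seq_b[:len_b])
        for seq_a, len_a, seq_b, len_b in zip(a, a_lens, b, b_lens)
    ]


# Batch of two padded sequences; padding (0) is ignored via the lengths.
a = [[1, 2, 3, 0], [4, 5, 0, 0]]
a_lens = [3, 2]
b = [[1, 3, 0], [4, 5, 6]]
b_lens = [2, 3]
distances = edit_distance_batch(a, a_lens, b, b_lens)
```

Each batch entry is handled independently, so the result has one distance per sequence pair.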