TFSprint

Like SprintErrorSignals.py, but for TensorFlow.

TFSprint.py_get_sprint_automata_for_batch(sprint_opts, tags)
Parameters:
  • sprint_opts (dict[str]) –
  • tags (list[str]) – seq names, one per sequence in the batch
Returns:

(edges, weights, start_end_states)

Return type:

(numpy.ndarray, numpy.ndarray, numpy.ndarray)
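To show how the three returned arrays relate to each other, here is a minimal sketch that splits a batched automaton into per-sequence parts. It assumes that the numpy arrays use the same layout that is documented for get_sprint_automata_for_batch_op (edges of shape (4, num_edges) with rows from, to, emission-idx, seq-idx; start_end_states of shape (2, batch)). The helper split_automaton_by_seq and the toy arrays are hypothetical, and all weights are set to 0.0 purely for illustration.

```python
import numpy as np


def split_automaton_by_seq(edges, weights, start_end_states):
    """Hypothetical helper: split a batched automaton into one
    (edges, weights, start, stop) tuple per sequence.

    Assumed layout (as documented for get_sprint_automata_for_batch_op):
      edges: int32 (4, num_edges), rows (from, to, emission-idx, seq-idx)
      weights: float32 (num_edges,)
      start_end_states: int32 (2, batch), rows (start, stop)
    """
    batch = start_end_states.shape[1]
    per_seq = []
    for b in range(batch):
        mask = edges[3] == b  # select edges whose seq-idx row equals b
        per_seq.append((edges[:, mask], weights[mask],
                        int(start_end_states[0, b]), int(start_end_states[1, b])))
    return per_seq


# Toy batch of two sequences: states 0-1 belong to seq 0, states 2-4 to seq 1.
edges = np.array([[0, 0, 2, 3, 3],   # from
                  [0, 1, 3, 3, 4],   # to
                  [5, 7, 5, 6, 7],   # emission-idx
                  [0, 0, 1, 1, 1]],  # seq-idx
                 dtype=np.int32)
weights = np.zeros(edges.shape[1], dtype=np.float32)
start_end_states = np.array([[0, 2], [1, 4]], dtype=np.int32)

parts = split_automaton_by_seq(edges, weights, start_end_states)
print([p[0].shape[1] for p in parts])  # → [2, 3]
```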

TFSprint.get_sprint_automata_for_batch_op(sprint_opts, tags)
Parameters:
  • sprint_opts (dict[str]) –
  • tags (tf.Tensor) – shape (batch,), of dtype string; seq names
Returns:

(edges, weights, start_end_states), all together in one automaton for the whole batch. edges has shape (4, num_edges) and dtype int32; each column is one edge, and the four rows are (from, to, emission-idx, seq-idx). weights has shape (num_edges,) and dtype float32. start_end_states has shape (2, batch) and dtype int32; the two rows are the (start, stop) state indices of each sequence, where batch = len(tags).

Return type:

(tf.Tensor, tf.Tensor, tf.Tensor)
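The phrase "all together in one automaton" can be illustrated by the reverse operation: merging per-sequence automata into the documented (edges, weights, start_end_states) layout. This is a sketch under the assumption that the states of different sequences are disjoint, so each sequence's local state ids are shifted by the number of states of all preceding sequences; the helper merge_automata and its input format are hypothetical and not part of TFSprint.

```python
import numpy as np


def merge_automata(seq_automata):
    """Hypothetical helper: combine per-sequence automata into one batched automaton.

    Each element of seq_automata is (num_states, arcs, start, stop), where arcs
    is a list of (from, to, emission_idx, weight) with sequence-local state ids.
    Assumption: states of different sequences are disjoint, so local ids are
    offset by the total number of states of the preceding sequences.
    """
    cols, ws, starts, stops = [], [], [], []
    offset = 0
    for seq_idx, (num_states, arcs, start, stop) in enumerate(seq_automata):
        for frm, to, emission_idx, weight in arcs:
            cols.append((frm + offset, to + offset, emission_idx, seq_idx))
            ws.append(weight)
        starts.append(start + offset)
        stops.append(stop + offset)
        offset += num_states
    edges = np.array(cols, dtype=np.int32).reshape(-1, 4).T        # (4, num_edges)
    weights = np.array(ws, dtype=np.float32)                       # (num_edges,)
    start_end_states = np.array([starts, stops], dtype=np.int32)   # (2, batch)
    return edges, weights, start_end_states


seq_a = (2, [(0, 0, 5, 0.0), (0, 1, 7, 0.0)], 0, 1)
seq_b = (3, [(0, 1, 5, 0.0), (1, 1, 6, 0.0), (1, 2, 7, 0.0)], 0, 2)
edges, weights, start_end_states = merge_automata([seq_a, seq_b])
print(edges.shape, start_end_states.tolist())  # → (4, 5) [[0, 2], [1, 4]]
```

Keeping all sequences in one flat edge list with a seq-idx row avoids padding edges per sequence, which is presumably why a single automaton is returned for the whole batch.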

TFSprint.py_get_sprint_loss_and_error_signal(sprint_opts, log_posteriors, seq_lengths, seq_tags)
Parameters:
  • sprint_opts (dict[str]) –
  • log_posteriors (numpy.ndarray) – 3d (time, batch, label)
  • seq_lengths (numpy.ndarray) – 1d (batch,)
  • seq_tags (list[str]) – seq names
Returns:

(loss, error_signal): loss is a 1d array of shape (batch,), and error_signal has the same shape as log_posteriors, i.e. (time, batch, label).

Return type:

(numpy.ndarray, numpy.ndarray)
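When testing code that consumes this function's outputs without a Sprint installation, a stand-in with the same input and output shapes can be handy. The sketch below is such a hypothetical stand-in: it computes a frame-wise cross-entropy against a fixed alignment (an extra, hypothetical argument) and returns its gradient with respect to log_posteriors as the error signal. This is NOT what Sprint computes; it only mirrors the documented shape contract (loss of shape (batch,), error_signal shaped like log_posteriors).

```python
import numpy as np


def fake_loss_and_error_signal(log_posteriors, seq_lengths, alignment):
    """Hypothetical Sprint-free stand-in with the documented return shapes.

    log_posteriors: (time, batch, label); seq_lengths: (batch,);
    alignment: (time, batch) label indices (an assumption of this sketch).
    loss[b] = -sum over valid frames t of log_posteriors[t, b, alignment[t, b]],
    error_signal = d(loss)/d(log_posteriors), i.e. -1 at aligned labels, else 0.
    """
    time, batch, num_labels = log_posteriors.shape
    # mask[t, b] is True for frames within each sequence's length.
    mask = np.arange(time)[:, None] < seq_lengths[None, :]
    one_hot = np.eye(num_labels, dtype=log_posteriors.dtype)[alignment]  # (time, batch, label)
    one_hot *= mask[:, :, None]  # zero out padded frames
    loss = -(log_posteriors * one_hot).sum(axis=(0, 2))  # (batch,)
    error_signal = -one_hot  # same shape as log_posteriors
    return loss, error_signal


log_post = np.log(np.full((3, 2, 4), 0.25, dtype=np.float32))
seq_lengths = np.array([3, 2], dtype=np.int32)
alignment = np.zeros((3, 2), dtype=np.int64)
loss, err = fake_loss_and_error_signal(log_post, seq_lengths, alignment)
print(loss)  # ≈ [3*ln(4), 2*ln(4)]
```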

TFSprint.get_sprint_loss_and_error_signal(sprint_opts, log_posteriors, seq_lengths, seq_tags)
Parameters:
  • sprint_opts (dict[str]) –
  • log_posteriors (tf.Tensor) – 3d (time, batch, label)
  • seq_lengths (tf.Tensor) – 1d (batch,)
  • seq_tags (tf.Tensor) – 1d (batch,), seq names
Returns:

(loss, error_signal): loss is a 1d tensor of shape (batch,), and error_signal has the same shape as log_posteriors, i.e. (time, batch, label).

Return type:

(tf.Tensor, tf.Tensor)
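Both loss functions expect time-major, padded input together with per-sequence lengths and names. The following sketch packs variable-length per-sequence log-posterior arrays into that layout using numpy; the resulting arrays could then be turned into tensors, e.g. with tf.convert_to_tensor. The helper pad_time_major and the tag strings are hypothetical, and the padding value 0.0 is arbitrary: those frames lie beyond seq_lengths and are presumably ignored.

```python
import numpy as np


def pad_time_major(seqs, tags):
    """Hypothetical helper: pack a list of (len_i, label) arrays into
    log_posteriors (time, batch, label), seq_lengths (batch,), seq_tags (batch,)."""
    batch = len(seqs)
    num_labels = seqs[0].shape[1]
    seq_lengths = np.array([s.shape[0] for s in seqs], dtype=np.int32)
    max_time = int(seq_lengths.max())
    # Frames beyond each sequence's length are filled with 0.0 (arbitrary padding).
    log_posteriors = np.zeros((max_time, batch, num_labels), dtype=np.float32)
    for b, s in enumerate(seqs):
        log_posteriors[: s.shape[0], b, :] = s
    seq_tags = np.array(tags)  # 1d array of strings, one seq name per sequence
    return log_posteriors, seq_lengths, seq_tags


seqs = [np.log(np.full((3, 4), 0.25, dtype=np.float32)),
        np.log(np.full((2, 4), 0.25, dtype=np.float32))]
log_posteriors, seq_lengths, seq_tags = pad_time_major(seqs, ["corpus/seg-1", "corpus/seg-2"])
print(log_posteriors.shape, seq_lengths.tolist())  # → (3, 2, 4) [3, 2]
```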