TFSprint
Like SprintErrorSignals.py, but for TensorFlow.
-
TFSprint.py_get_sprint_automata_for_batch(sprint_opts, tags)
Parameters:
- sprint_opts (dict[str])
- tags (list[str]) – sequence tags, one per sequence in the batch
Returns: (edges, weights, start_end_states)
Return type: (numpy.ndarray, numpy.ndarray, numpy.ndarray)
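To make this return format concrete, the sketch below builds a toy automaton using the layout documented for get_sprint_automata_for_batch_op. In that layout, edges has shape (4, num_edges) with rows (from, to, emission-idx, seq-idx), weights has shape (num_edges,), and start_end_states has shape (2, batch). Several parts of the sketch are assumptions made only for illustration: the linear-chain topology, the placeholder weights of 0.0, the globally offset state numbering, and the helper name build_linear_automata. Real automata come from Sprint.

```python
import numpy as np

def build_linear_automata(label_seqs):
    """Hypothetical helper: one combined linear-chain automaton for a batch.

    Sequence b with labels [l0, ..., l(n-1)] gets states 0..n (shifted by a
    global offset). Edge i -> i+1 emits l_i, and state i+1 has a self-loop
    that repeats l_i. All weights are 0.0 placeholders, not Sprint scores.
    """
    froms, tos, emissions, seq_idxs = [], [], [], []
    start_end = []
    offset = 0  # state indices are global across the whole batch (assumption)
    for seq_idx, labels in enumerate(label_seqs):
        for i, label in enumerate(labels):
            # Forward edge into the next state, then a self-loop on that state.
            for src, dst in ((offset + i, offset + i + 1), (offset + i + 1, offset + i + 1)):
                froms.append(src)
                tos.append(dst)
                emissions.append(label)
                seq_idxs.append(seq_idx)
        start_end.append((offset, offset + len(labels)))
        offset += len(labels) + 1
    edges = np.array([froms, tos, emissions, seq_idxs], dtype=np.int32)  # (4, num_edges)
    weights = np.zeros(edges.shape[1], dtype=np.float32)  # (num_edges,)
    start_end_states = np.array(start_end, dtype=np.int32).reshape(-1, 2).T  # (2, batch)
    return edges, weights, start_end_states

edges, weights, start_end_states = build_linear_automata([[3, 5], [7]])
```

Here the first sequence uses states 0 to 2 and the second uses states 3 to 4, so start_end_states is [[0, 3], [2, 4]].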
-
TFSprint.get_sprint_automata_for_batch_op(sprint_opts, tags)
Parameters:
- sprint_opts (dict[str])
- tags (tf.Tensor) – shape (batch,), dtype string
Returns: (edges, weights, start_end_states), combined into a single automaton for the whole batch.
- edges: shape (4, num_edges), dtype int32. Each column is one edge, and the four rows are (from, to, emission-idx, seq-idx).
- weights: shape (num_edges,), dtype float32.
- start_end_states: shape (2, batch), dtype int32, where batch = len(tags). The two rows hold the (start, stop) state index of each sequence.
Return type: (tf.Tensor, tf.Tensor, tf.Tensor)
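Because all sequences share one automaton, a consumer has to use the seq-idx row to tell them apart. The sketch below splits the combined arrays per sequence and relies only on the shapes documented above. The helper name split_automata_by_seq is hypothetical. The toy arrays stand in for the evaluated outputs of the op, and the real values would come from Sprint.

```python
import numpy as np

def split_automata_by_seq(edges, weights, start_end_states):
    """Hypothetical helper: split the combined batch automaton per sequence.

    Relies only on the documented layout: edges rows are (from, to,
    emission-idx, seq-idx); weights has one entry per edge; start_end_states
    rows are (start, stop).
    """
    edges = np.asarray(edges)
    weights = np.asarray(weights)
    start_end_states = np.asarray(start_end_states)
    parts = []
    for b in range(start_end_states.shape[1]):
        mask = edges[3] == b  # columns (edges) that belong to sequence b
        parts.append({
            "from": edges[0, mask],
            "to": edges[1, mask],
            "emission": edges[2, mask],
            "weights": weights[mask],
            "start": int(start_end_states[0, b]),
            "stop": int(start_end_states[1, b]),
        })
    return parts

# Toy values standing in for the evaluated op outputs (two sequences).
edges = np.array([[0, 1, 1, 2, 3, 4],
                  [1, 1, 2, 2, 4, 4],
                  [3, 3, 5, 5, 7, 7],
                  [0, 0, 0, 0, 1, 1]], dtype=np.int32)
weights = np.zeros(6, dtype=np.float32)
start_end_states = np.array([[0, 3], [2, 4]], dtype=np.int32)
parts = split_automata_by_seq(edges, weights, start_end_states)
```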
-
TFSprint.py_get_sprint_loss_and_error_signal(sprint_opts, log_posteriors, seq_lengths, seq_tags)
Parameters:
- sprint_opts (dict[str])
- log_posteriors (numpy.ndarray) – 3d, shape (time, batch, label)
- seq_lengths (numpy.ndarray) – 1d, shape (batch,)
- seq_tags (list[str]) – sequence names
Returns: (loss, error_signal). loss is a 1d array of shape (batch,), and error_signal has the same shape as log_posteriors.
Return type: (numpy.ndarray, numpy.ndarray)
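The sketch below imitates only the shape contract above: (time, batch, label) log posteriors and (batch,) lengths go in, and a (batch,) loss plus an error signal shaped like log_posteriors come out. It is not what Sprint computes. As a stand-in, it uses framewise cross-entropy against a hypothetical fixed alignment. Its error signal is exp(log_posteriors) minus one-hot(alignment), which is the cross-entropy gradient with respect to the pre-softmax activations when log_posteriors is a log-softmax output. Whether Sprint's error signal is defined the same way is not stated here. The names fake_loss_and_error_signal and alignment are assumptions. Frames beyond each sequence's length are masked out.

```python
import numpy as np

def fake_loss_and_error_signal(log_posteriors, seq_lengths, alignment):
    """Hypothetical stand-in with the same shapes as py_get_sprint_loss_and_error_signal.

    Not Sprint: framewise cross-entropy against a fixed alignment.
    :param numpy.ndarray log_posteriors: (time, batch, label), log-softmax outputs
    :param numpy.ndarray seq_lengths: (batch,)
    :param numpy.ndarray alignment: (time, batch), int label index per frame
    :return: loss of shape (batch,), error_signal of shape (time, batch, label)
    """
    n_time, n_batch, n_label = log_posteriors.shape
    # mask[t, b] is True for valid frames t < seq_lengths[b].
    mask = np.arange(n_time)[:, None] < np.asarray(seq_lengths)[None, :]
    # Log-probability of the aligned label at every frame: (time, batch).
    picked = np.take_along_axis(log_posteriors, alignment[:, :, None], axis=2)[:, :, 0]
    loss = -np.sum(np.where(mask, picked, 0.0), axis=0)  # (batch,)
    one_hot = np.eye(n_label, dtype=log_posteriors.dtype)[alignment]  # (time, batch, label)
    # Padded frames get a zero error signal.
    error_signal = (np.exp(log_posteriors) - one_hot) * mask[:, :, None]
    return loss, error_signal

# Uniform posteriors over 4 labels, two sequences of lengths 3 and 1.
log_post = np.full((3, 2, 4), np.log(0.25))
loss, err = fake_loss_and_error_signal(log_post, np.array([3, 1]), np.zeros((3, 2), dtype=np.int64))
```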
-
TFSprint.get_sprint_loss_and_error_signal(sprint_opts, log_posteriors, seq_lengths, seq_tags)
Parameters:
- sprint_opts (dict[str])
- log_posteriors (tf.Tensor) – 3d, shape (time, batch, label)
- seq_lengths (tf.Tensor) – 1d, shape (batch,)
- seq_tags (tf.Tensor) – 1d, shape (batch,), sequence names
Returns: (loss, error_signal). loss is a 1d tensor of shape (batch,), and error_signal has the same shape as log_posteriors.
Return type: (tf.Tensor, tf.Tensor)
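The TF variant takes seq_tags as a string tensor, while the numpy variant expects list[str]. A string tensor passed to Python code through tf.py_func or tf.numpy_function arrives as a numpy array of bytes, so a wrapper would typically decode it first. The sketch below shows such a bridge together with a check of the documented output shapes. It assumes UTF-8 tags. The helper names decode_seq_tags and outputs_match_contract are hypothetical, and this is not necessarily how TFSprint does it internally.

```python
import numpy as np

def decode_seq_tags(seq_tags):
    """Hypothetical helper: convert a string tensor's numpy value to list[str].

    Elements arriving as bytes (including numpy.bytes_) are decoded as UTF-8;
    elements that are already str are kept as they are.
    """
    return [t.decode("utf8") if isinstance(t, bytes) else str(t)
            for t in np.asarray(seq_tags).ravel()]

def outputs_match_contract(log_posteriors, loss, error_signal):
    """Hypothetical check: loss is (batch,), error_signal matches log_posteriors."""
    n_batch = np.shape(log_posteriors)[1]
    return np.shape(loss) == (n_batch,) and np.shape(error_signal) == np.shape(log_posteriors)

# Example: tags roughly as a py_func callback might receive them.
tags = decode_seq_tags(np.array([b"corpus/seq-1", b"corpus/seq-2"], dtype=object))
```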