returnn.sprint.error_signals
This module provides the Theano Op SprintErrorSigOp, which obtains a loss and an error signal calculated by Sprint. It also contains helper classes that communicate with the Sprint subprocess: they transfer the posteriors and receive the loss and error signal in return. The communication uses the SprintControl Sprint interface.
- class returnn.sprint.error_signals.SprintSubprocessInstance(sprintExecPath, minPythonControlVersion=2, sprintConfigStr='', sprintControlConfig=None, usePythonSegmentOrder=True)
The Sprint instance which is used to calculate the error signal. Communication is over a pipe; the file descriptors are passed to the child process via the command line. Basic protocol with the subprocess (encoded via pickle):
- Commands:
“init”, name, version -> “ok”, child_name, version
“exit” -> (exit)
“get_loss_and_error_signal”, seg_name, seg_len, posteriors -> “ok”, loss, error_signal
Numpy arrays are encoded via TaskSystem.Pickler (which is optimized for Numpy).
On the Sprint side, this is handled via the SprintControl Sprint interface.
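The request/reply pattern above can be sketched with the standard library alone. This is only an illustration under several assumptions: it uses plain `pickle` instead of TaskSystem.Pickler, nested lists instead of Numpy arrays, a thread in place of the Sprint child process, and a hypothetical `PROTOCOL_VERSION`. The names `child_loop` and `request` are made up for this sketch and are not part of the module.

```python
import os
import pickle
import threading

PROTOCOL_VERSION = 1  # hypothetical value, for illustration only


def child_loop(pipe_in, pipe_out):
    """Mock child: answers pickled commands until it receives "exit"."""
    while True:
        cmd, *args = pickle.load(pipe_in)
        if cmd == "init":
            name, version = args
            reply = ("ok", "mock-child", version)
        elif cmd == "get_loss_and_error_signal":
            seg_name, seg_len, posteriors = args
            # Dummy result: zero loss, all-zero error signal with the input's shape.
            reply = ("ok", 0.0, [[0.0] * len(row) for row in posteriors])
        elif cmd == "exit":
            return  # "exit" has no reply
        else:
            reply = ("error", "unknown command %r" % (cmd,))
        pickle.dump(reply, pipe_out)
        pipe_out.flush()


def request(pipe_out, pipe_in, *msg):
    """Send one command tuple and block until the reply arrives."""
    pickle.dump(msg, pipe_out)
    pipe_out.flush()
    return pickle.load(pipe_in)


# Two pipes: parent -> child and child -> parent.
p2c_r, p2c_w = os.pipe()
c2p_r, c2p_w = os.pipe()
child_in, parent_out = os.fdopen(p2c_r, "rb"), os.fdopen(p2c_w, "wb")
parent_in, child_out = os.fdopen(c2p_r, "rb"), os.fdopen(c2p_w, "wb")

child = threading.Thread(target=child_loop, args=(child_in, child_out))
child.start()

init_reply = request(parent_out, parent_in, "init", "mock-parent", PROTOCOL_VERSION)
loss_reply = request(
    parent_out, parent_in,
    "get_loss_and_error_signal", "seg-0", 2, [[-0.7, -0.7], [-1.2, -0.4]])
pickle.dump(("exit",), parent_out)
parent_out.flush()
child.join()

for f in (parent_out, parent_in, child_in, child_out):
    f.close()
```

In the real setup the two ends live in different processes, which is why the file descriptors have to be handed to the child on its command line.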
- Parameters:
sprintExecPath (str) – this executable will be called for the subprocess.
minPythonControlVersion (int) – checked in the subprocess via Sprint PythonControl.
sprintConfigStr (str) – passed to Sprint as command-line args. Can have a “config:” prefix, in which case it is looked up in the config. Handled via eval_shell_str(), so it can have lazy content (if it is callable, it will be called).
sprintControlConfig (dict[str]|None) – passed to SprintControl.init().
- get_loss_and_error_signal__send(seg_name, seg_len, log_posteriors)
- Parameters:
seg_name (str) – the segment name (seq_tag)
seg_len (int) – the segment length in frames
log_posteriors (numpy.ndarray) – 2d (time,label) float array, log probs
- class returnn.sprint.error_signals.ReaderThread(instance, instance_idx, batch_idxs, tags, seq_lengths, log_posteriors, batch_loss, batch_error_signal)
Sprint reader thread.
- Parameters:
instance (SprintSubprocessInstance)
instance_idx (int)
batch_idxs (list[int])
tags (list[str]) – seq names, length = batch
seq_lengths (numpy.ndarray) – 1d (batch)
log_posteriors (numpy.ndarray) – 3d (time,batch,label)
batch_loss (numpy.ndarray) – 1d (batch). will write result into it.
batch_error_signal (numpy.ndarray) – 3d (time,batch,label). will write results into it.
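The parameter shapes above suggest the following data flow: each thread slices the 2d (time,label) array of one sequence out of the 3d batch array, hands it to its instance, and writes the results back into the shared output arrays. The sketch below illustrates only that flow. `MockInstance`, its single `get_loss_and_error_signal` method, the `worker` function, and the round-robin split of batch indices are assumptions made for this example, not the module's implementation.

```python
import threading

import numpy as np


class MockInstance:
    """Hypothetical stand-in for a Sprint subprocess instance."""

    def get_loss_and_error_signal(self, seg_name, seg_len, log_posteriors):
        assert log_posteriors.shape[0] == seg_len
        loss = float(-log_posteriors.sum())    # dummy loss
        error_signal = np.exp(log_posteriors)  # dummy signal, same shape as input
        return loss, error_signal


def worker(instance, batch_idxs, tags, seq_lengths, log_posteriors,
           batch_loss, batch_error_signal):
    """Process the given batch entries and write results in place."""
    for b in batch_idxs:
        seg_len = int(seq_lengths[b])
        # 2d (time,label) slice of sequence b, without the padded frames.
        loss, err = instance.get_loss_and_error_signal(
            tags[b], seg_len, log_posteriors[:seg_len, b])
        batch_loss[b] = loss
        batch_error_signal[:seg_len, b] = err


n_time, n_batch, n_label = 4, 3, 2
tags = ["seq-0", "seq-1", "seq-2"]
seq_lengths = np.array([4, 2, 3])
log_posteriors = np.full((n_time, n_batch, n_label), np.log(0.5), dtype="float32")
batch_loss = np.zeros((n_batch,), dtype="float32")
batch_error_signal = np.zeros_like(log_posteriors)

instances = [MockInstance(), MockInstance()]
threads = []
for i, instance in enumerate(instances):
    # Assumed round-robin assignment of batch entries to instances.
    batch_idxs = list(range(i, n_batch, len(instances)))
    t = threading.Thread(
        target=worker,
        args=(instance, batch_idxs, tags, seq_lengths, log_posteriors,
              batch_loss, batch_error_signal))
    threads.append(t)
    t.start()
for t in threads:
    t.join()
```

Because every thread writes to disjoint batch indices, the shared arrays need no locking in this sketch; frames beyond a sequence's length stay zero.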
- class returnn.sprint.error_signals.SprintInstancePool(sprint_opts)
This is a pool of Sprint instances. For each unique sprint_opts, there is a singleton, which can be accessed via get_global_instance.
- The pool can then be used in multiple ways:
get_batch_loss_and_error_signal.
…
- Parameters:
sprint_opts (dict[str])
- classmethod get_global_instance(sprint_opts)
- Parameters:
sprint_opts (dict[str])
- Return type:
SprintInstancePool
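One way to keep a singleton per unique options dict is to convert the dict into a hashable key and cache instances under it. The sketch below shows that idea only. The class `InstancePool`, the `_freeze` helper, and the `numInstances` key are hypothetical and are not claimed to match how SprintInstancePool does it.

```python
class InstancePool:
    """Hypothetical stand-in for a pool class, for illustration only."""

    _global_instances = {}  # frozen opts -> pool instance

    def __init__(self, opts):
        self.opts = opts

    @staticmethod
    def _freeze(obj):
        """Turn nested dicts/lists into hashable tuples, usable as dict keys."""
        if isinstance(obj, dict):
            return tuple(sorted((k, InstancePool._freeze(v)) for k, v in obj.items()))
        if isinstance(obj, (list, tuple)):
            return tuple(InstancePool._freeze(v) for v in obj)
        return obj

    @classmethod
    def get_global_instance(cls, opts):
        """Return the cached pool for these opts, creating it on first use."""
        key = cls._freeze(opts)
        if key not in cls._global_instances:
            cls._global_instances[key] = cls(opts)
        return cls._global_instances[key]


# Same content in a different key order yields the same instance.
a = InstancePool.get_global_instance({"sprintExecPath": "/path/to/sprint", "numInstances": 2})
b = InstancePool.get_global_instance({"numInstances": 2, "sprintExecPath": "/path/to/sprint"})
# Different content yields a different instance.
c = InstancePool.get_global_instance({"sprintExecPath": "/path/to/sprint", "numInstances": 4})
```

Sorting the items makes the key independent of dict insertion order, so two option dicts with equal content map to one pool.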
- get_batch_loss_and_error_signal(log_posteriors, seq_lengths, tags=None)
- Parameters:
log_posteriors (numpy.ndarray) – 3d (time,batch,label)
seq_lengths (numpy.ndarray) – 1d (batch)
tags (list[str]) – seq names, length = batch
- Returns:
(loss, error_signal). error_signal has the same shape as log_posteriors. loss is a 1d array (batch).
- Return type:
(numpy.ndarray, numpy.ndarray)
Note that this accesses some global references, like the global current segment info, via the current Device instance. Thus it is expected to be run from the Device host process, from inside SprintErrorSigOp.perform.
This also expects that the sequences are not chunked.
- get_automata_for_batch(tags)
- Parameters:
tags (list[str]|numpy.ndarray) – sequence names, used for Sprint (ndarray of shape (batch, max_str_len))
- Returns:
(edges, weights, start_end_states). all together in one automaton. edges are of shape (4, num_edges), each (from, to, emission-idx, seq-idx), of dtype uint32. weights are of shape (num_edges,), of dtype float32. start_end_states are of shape (2, batch), each (start,stop) state idx, batch = len(tags), of dtype uint32.
- Return type:
(numpy.ndarray, numpy.ndarray, numpy.ndarray)
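To make the described array layout concrete, the following sketch builds the three arrays by hand for a batch of two sequences and selects the edges of one sequence. The values are toy numbers, not real Sprint output, and `edges_of_sequence` is a hypothetical helper introduced only for this example.

```python
import numpy as np

# Columns are edges; rows are (from, to, emission-idx, seq-idx).
edges = np.array([
    [0, 1, 2, 3],  # from state
    [1, 1, 3, 3],  # to state
    [5, 5, 7, 7],  # emission idx
    [0, 0, 1, 1],  # seq idx
], dtype="uint32")
weights = np.array([0.0, 0.5, 0.0, 0.25], dtype="float32")  # one weight per edge
start_end_states = np.array([
    [0, 2],  # start state idx per sequence
    [1, 3],  # end state idx per sequence
], dtype="uint32")


def edges_of_sequence(edges, weights, seq_idx):
    """Return the (from, to, emission-idx) rows and weights of one sequence."""
    mask = edges[3] == seq_idx
    return edges[:3, mask], weights[mask]


seq1_edges, seq1_weights = edges_of_sequence(edges, weights, 1)
```

Since all sequences share one automaton, the seq-idx row is what separates them; state indices are global across the batch.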