returnn.tf.native_op
TF implementation of NativeOp.
Wrappers for most relevant NativeOp ops.
- class returnn.tf.native_op.OpDescription(in_info, out_info, c_fw_code, c_bw_code=None, c_extra_support_code=None, code_version=None, cpu_support=True, grad_input_map=None, name=None)[source]#
Meta-info about an op, used by OpMaker.
- Parameters:
in_info (list[dict(str)]) –
each dict describes one input var. Attribs in the dict:
int ndim: the ndim.
tuple shape: tuple and can contain None for specific dimensions.
- optional attribs:
str dtype: “float32” by default.
bool need_contiguous: false by default.
int want_inplace: -1 by default. Try to optimize to destroy input, on output-index. “dummy_out” is a special value which will add another output.
bool is_inplace: false by default. Whether the optimization was applied.
str gradient: can be “disconnected”. See grad().
bool bw_input: True by default. Add this param to the bw input.
Other attribs are just ignored.
out_info (list[dict(str)]) –
like in_info. Slightly different behavior for:
shape: we also allow refs to the in_info in the form (in-idx,dim). See infer_shape().
need_contiguous/want_inplace: used for bw, in case bw_input == True.
c_fw_code (str) – C code for forward pass
c_extra_support_code (str|dict[str]) – C support code (for c_support_code)
c_bw_code (str|None) – C code for backward pass (for gradient)
code_version (tuple[int]) – will be returned by c_code_cache_version.
cpu_support (bool) –
grad_input_map (tuple[int]|callable) – selection of grad inputs. by default, we get all inputs + all outputs + all grad outputs.
name (str) – name
- classmethod from_gen_base(gen_base)[source]#
- Parameters:
gen_base (returnn.native_op.NativeOpGenBase|type[returnn.native_op.NativeOpGenBase]) –
- Return type:
- grad()[source]#
- Return type:
OpDescription|None
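As a rough illustration of the in_info / out_info format described above, the following sketch writes down descriptors for a made-up op and resolves (in-idx,dim) shape references against concrete input shapes. The op itself, the "name" attribs, and the helper resolve_out_shapes are assumptions for illustration only; the helper is not the library's infer_shape().

```python
# Hypothetical descriptors for an op with two inputs and one output.
# The "name" attribs are illustrative; only ndim/shape are required per the docs above.
in_info = [
    {"name": "x", "ndim": 2, "shape": (None, None), "need_contiguous": True},
    {"name": "scale", "ndim": 1, "shape": (None,), "gradient": "disconnected"},
]
out_info = [
    # Each output dim is an int, None, or an (in-idx, dim) reference into in_info.
    {"name": "y", "ndim": 2, "shape": ((0, 0), (0, 1))},
]


def resolve_out_shapes(out_info, input_shapes):
    """Resolve (in-idx, dim) references against concrete input shapes.

    Illustrative helper (an assumption), not the library's infer_shape().
    """
    resolved = []
    for info in out_info:
        shape = []
        for d in info["shape"]:
            if isinstance(d, tuple):
                in_idx, in_dim = d
                shape.append(input_shapes[in_idx][in_dim])
            else:
                shape.append(d)  # fixed int, or None for unknown
        assert len(shape) == info["ndim"]
        resolved.append(tuple(shape))
    return resolved


print(resolve_out_shapes(out_info, [(7, 3), (3,)]))
```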
- class returnn.tf.native_op.OpMaker(description, compiler_opts=None, search_for_runtime_blas=True, search_for_numpy_blas=True, search_for_system_blas=True, blas_lib=None)[source]#
https://www.tensorflow.org/guide/extend/op
- Parameters:
description (OpDescription) –
compiler_opts (dict[str]|None) – passed on to OpCodeCompiler as kwargs
- returnn.tf.native_op.load_dump_file(filename)[source]#
See dump_to_file() in NativeOp.cpp.
- Parameters:
filename (str) –
- Return type:
numpy.ndarray
- returnn.tf.native_op.make_op(cls, **kwargs)[source]#
- Parameters:
cls (type[returnn.native_op.NativeOpGenBase]) –
kwargs – passed to OpMaker
- Returns:
op
- Return type:
(tf.Tensor) -> tuple[tf.Tensor]
- returnn.tf.native_op.make_lstm_op(**kwargs)[source]#
See NativeLstmCell for usage.
- Returns:
op
- Return type:
(tf.Tensor) -> tuple[tf.Tensor]
- class returnn.tf.native_op.RecSeqCellOp(n_hidden, n_input_dim=None, n_input_dim_parts=None, input_is_sparse=False, step=None)[source]#
In TF terminology, this is a “fused” cell, i.e. the op loops over time. A similar example is tf.contrib.rnn.LSTMBlockFusedCell.
- Parameters:
n_hidden (int) –
n_input_dim (int) –
n_input_dim_parts (int|list[int]) –
input_is_sparse (bool) –
step (int) – what direction and step to use
- class returnn.tf.native_op.NativeLstmCell(forget_bias=0.0, **kwargs)[source]#
Native LSTM.
- Parameters:
forget_bias (float) –
- classmethod map_layer_inputs_to_op(z, rec_weights, i, initial_state=None)[source]#
Just like NativeOp.LstmGenericBase.map_layer_inputs_to_op().
- Parameters:
z (tf.Tensor) – Z: inputs: shape (time,batch,n_hidden*4)
rec_weights (tf.Tensor) – V_h / W_re: shape (n_hidden,n_hidden*4)
i (tf.Tensor) – index: shape (time,batch)
initial_state (tf.Tensor|None) – shape (batch,n_hidden)
- Return type:
(tf.Tensor,tf.Tensor,tf.Tensor,tf.Tensor)
- class returnn.tf.native_op.NativeLstmLowMemCell(**kwargs)[source]#
Native LSTM, low mem variant.
- Parameters:
n_hidden (int) –
n_input_dim (int) –
n_input_dim_parts (int|list[int]) –
input_is_sparse (bool) –
step (int) – what direction and step to use
- map_layer_inputs_to_op(x, weights, b, i, initial_state=None)[source]#
Just like NativeOp.LstmGenericBase.map_layer_inputs_to_op().
- Parameters:
x (tf.Tensor) – inputs: shape (time,batch,n_input_dim)
weights (tf.Tensor) – shape (n_input_dim+n_hidden,n_hidden*4)
b (tf.Tensor) – shape (n_hidden*4,)
i (tf.Tensor) – index: shape (time,batch)
initial_state (tf.Tensor|None) – shape (batch,n_hidden)
- Return type:
tuple[tf.Tensor]
- class returnn.tf.native_op.NativeLstm2(rec_weight_dropout=0.0, rec_weight_dropout_shape=None, forget_bias=0.0, **kwargs)[source]#
Native LSTM 2. See NativeOp.NativeLstm2.
- Parameters:
rec_weight_dropout (float) – weight dropout in the recurrent matrix, https://openreview.net/pdf?id=SyyGPP0TZ
rec_weight_dropout_shape (tuple[int|None]|None) – e.g. (None,1) to use dropout on all rec inputs (save memory)
forget_bias (float) –
- class returnn.tf.native_op.TwoDNativeLstmCell(pooling, **kwargs)[source]#
Native 2D LSTM.
- Parameters:
n_hidden (int) –
n_input_dim (int) –
n_input_dim_parts (int|list[int]) –
input_is_sparse (bool) –
step (int) – what direction and step to use
- classmethod map_layer_inputs_to_op(X, V_h, V_v, W, i, previous_state=None, previous_output=None, iteration=None)[source]#
Just like NativeOp.LstmGenericBase.map_layer_inputs_to_op().
- Parameters:
X (tf.Tensor) – inputs: shape (timeT,timeS,batch,n_hidden*5)
V_h (tf.Tensor) – W_re: shape (n_hidden,n_hidden*5)
V_v (tf.Tensor) – W_re: shape (n_hidden,n_hidden*5)
W (tf.Tensor) –
i (tf.Tensor) – index: shape (time,batch)
previous_state (tf.Tensor) –
previous_output (tf.Tensor) –
iteration (tf.Tensor) –
- Return type:
(tf.Tensor,tf.Tensor,tf.Tensor,tf.Tensor)
- returnn.tf.native_op.chunk(x, index, chunk_size, chunk_step)[source]#
- Parameters:
x (tf.Tensor) – (time,batch,dim)
index (tf.Tensor) – (time,batch)
chunk_size (int|tf.Tensor) –
chunk_step (int|tf.Tensor) –
- Returns:
out, oindex. out is of shape (chunk_size, n_batch * n_chunks, n_dim), oindex of shape (chunk_size, n_batch * n_chunks).
- Return type:
(tf.Tensor,tf.Tensor)
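To make the output shapes of chunk() more concrete, here is a pure-Python sketch of chunking on nested lists. It drops the feature axis for brevity (each frame is a scalar). The chunk-count formula, the zero padding, and the column ordering b * n_chunks + k are assumptions of this sketch and may differ from what the native op does.

```python
def chunk_reference(x, index, chunk_size, chunk_step, pad=0.0):
    """Chunk x (list[time][batch]) into overlapping windows.

    Returns (out, oindex), both list[chunk_size][n_batch * n_chunks].
    Sketch only: chunk count and column ordering are assumptions.
    """
    n_time, n_batch = len(x), len(x[0])
    # Enough chunks so that every frame is covered at least once.
    n_chunks = max(n_time - chunk_size + chunk_step - 1, 0) // chunk_step + 1
    out = [[pad] * (n_batch * n_chunks) for _ in range(chunk_size)]
    oindex = [[0] * (n_batch * n_chunks) for _ in range(chunk_size)]
    for b in range(n_batch):
        for k in range(n_chunks):
            col = b * n_chunks + k  # assumed ordering of the merged axis
            for c in range(chunk_size):
                t = k * chunk_step + c
                if t < n_time and index[t][b]:
                    out[c][col] = x[t][b]
                    oindex[c][col] = 1
    return out, oindex


x = [[1], [2], [3], [4], [5]]          # time=5, batch=1
index = [[1], [1], [1], [1], [1]]      # all frames valid
out, oindex = chunk_reference(x, index, chunk_size=3, chunk_step=2)
print(out)
```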
- returnn.tf.native_op.unchunk(x, index, chunk_size, chunk_step, n_time, n_batch)[source]#
- Parameters:
x (tf.Tensor) – output e.g. from chunk()
index (tf.Tensor) –
chunk_size (int|tf.Tensor) –
chunk_step (int|tf.Tensor) –
n_time (tf.Tensor) –
n_batch (tf.Tensor) –
- Returns:
out, oindex, ofactors
- Return type:
(tf.Tensor,tf.Tensor,tf.Tensor)
- returnn.tf.native_op.make_fast_baum_welch_op(**kwargs)[source]#
- Returns:
op
- Return type:
(tf.Tensor) -> tuple[tf.Tensor]
- returnn.tf.native_op.fast_baum_welch(am_scores, edges, weights, start_end_states, float_idx, state_buffer=None)[source]#
- Parameters:
am_scores (tf.Tensor) – (time, batch, dim), in -log space
edges (tf.Tensor) – (4,num_edges), edges of the graph (from,to,emission_idx,sequence_idx)
weights (tf.Tensor) – (num_edges,), weights of the edges
start_end_states (tf.Tensor) – (2, batch), (start,end) state idx in automaton. there is only one single automaton.
float_idx (tf.Tensor) – (time, batch) -> 0 or 1 (index mask, via seq lens)
state_buffer (tf.Tensor) – (2, num_states)
- Returns:
(fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
- Return type:
(tf.Tensor, tf.Tensor)
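The edge-list format above can be read as a plain forward-backward pass in -log space. The following pure-Python sketch handles a single sequence (so the sequence_idx entry of each edge is omitted) and is meant only to illustrate the data layout. The function names, the per-frame normalization of fwdbwd, and the meaning of obs_scores as the per-frame total score are assumptions of this sketch, not a description of the native kernel.

```python
import math

INF = float("inf")


def neg_log_add(a, b):
    """Return -log(exp(-a) + exp(-b)) in a numerically stable way."""
    if a == INF:
        return b
    if b == INF:
        return a
    return min(a, b) - math.log1p(math.exp(-abs(a - b)))


def baum_welch_reference(am_scores, edges, weights, start, end, num_states):
    """am_scores: list[time][dim] in -log space; edges: (from, to, emission_idx).

    Assumes the end state is reachable in exactly len(am_scores) steps.
    """
    n_time, dim = len(am_scores), len(am_scores[0])
    fwd = [[INF] * num_states for _ in range(n_time + 1)]
    bwd = [[INF] * num_states for _ in range(n_time + 1)]
    fwd[0][start] = 0.0
    bwd[n_time][end] = 0.0
    for t in range(n_time):  # forward scores
        for (s_from, s_to, emi), w in zip(edges, weights):
            score = fwd[t][s_from] + w + am_scores[t][emi]
            fwd[t + 1][s_to] = neg_log_add(fwd[t + 1][s_to], score)
    for t in reversed(range(n_time)):  # backward scores
        for (s_from, s_to, emi), w in zip(edges, weights):
            score = bwd[t + 1][s_to] + w + am_scores[t][emi]
            bwd[t][s_from] = neg_log_add(bwd[t][s_from], score)
    fwdbwd, obs_scores = [], []
    for t in range(n_time):
        frame = [INF] * dim
        for (s_from, s_to, emi), w in zip(edges, weights):
            score = fwd[t][s_from] + w + am_scores[t][emi] + bwd[t + 1][s_to]
            frame[emi] = neg_log_add(frame[emi], score)
        total = INF
        for v in frame:
            total = neg_log_add(total, v)
        obs_scores.append(total)
        fwdbwd.append([v - total for v in frame])  # normalize per frame
    return fwdbwd, obs_scores


# Tiny automaton: state 0 loops on label 0, moves to state 1 on label 1,
# and state 1 loops on label 1.
edges = [(0, 0, 0), (0, 1, 1), (1, 1, 1)]
weights = [0.0, 0.0, 0.0]
am = [[-math.log(0.5)] * 2 for _ in range(2)]
fwdbwd, obs = baum_welch_reference(am, edges, weights, start=0, end=1, num_states=2)
print([[math.exp(-v) for v in frame] for frame in fwdbwd])
```

Taking exp(-fwdbwd) gives the soft alignment per frame, which in this toy case sums to one over the label axis.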
- returnn.tf.native_op.fast_baum_welch_by_sprint_automata(am_scores, float_idx, tags, sprint_opts, tdp_scale=1.0)[source]#
- Parameters:
am_scores (tf.Tensor) – (time, batch, dim), in -log space
float_idx (tf.Tensor) – (time, batch) -> 0 or 1 (index mask, via seq lens)
tags (tf.Tensor) – (batch,) -> seq name (str)
tdp_scale (float) – weights are multiplied by this
sprint_opts (dict[str]) –
- Returns:
(fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
- Return type:
(tf.Tensor, tf.Tensor)
- returnn.tf.native_op.tf_fast_bw_fsa_staircase(seq_lens, **opts)[source]#
- Parameters:
seq_lens (tf.Tensor) – shape (batch,)
opts – passed to Fsa.fast_bw_fsa_staircase()
- Returns:
edges, weights, start_end_states
- Return type:
(tf.Tensor, tf.Tensor, tf.Tensor)
- returnn.tf.native_op.get_ctc_fsa_fast_bw(targets, seq_lens, blank_idx, label_loop=True)[source]#
See NativeOp.GetCtcFsaFastBwOp. Generates an FSA with CTC topology. The output format is compatible with fast_baum_welch().
- Parameters:
targets (tf.Tensor) – shape (batch,time), int32
seq_lens (tf.Tensor) – shape (batch), int32
blank_idx (int) – vocab index of the blank symbol
label_loop (bool) – True -> normal CTC; False -> RNA-like
- Returns:
edges, weights, start_end_states; edges is (4,num_edges), int32, edges of the graph (from,to,emission_idx,sequence_idx). weights is (num_edges,), float32. all zero. start_end_states is (2,batch), int32, (start,end) state idx in FSA.
- Return type:
(tf.Tensor,tf.Tensor,tf.Tensor)
- returnn.tf.native_op.fast_baum_welch_staircase(am_scores, seq_lens, **opts)[source]#
- Parameters:
am_scores (tf.Tensor) – (time, batch, dim), in -log space
seq_lens (tf.Tensor) – (batch,) -> values in [1, …, dim-1]
opts – passed to Fsa.fast_bw_fsa_staircase()
- Returns:
(fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
- Return type:
(tf.Tensor, tf.Tensor)
- returnn.tf.native_op.ctc_loss(logits, logits_seq_lens, logits_time_major, targets, targets_seq_lens, ctc_merge_repeated=True, logits_normalize=True, grad_wrt_softmax_in=True, blank_index=-1)[source]#
Similar to tf.nn.ctc_loss(). We use our fast_baum_welch(). Also see FastBaumWelchLoss.
- Parameters:
logits (tf.Tensor) – (time,batch,dim) or (batch,time,dim). unnormalized (before softmax)
logits_seq_lens (tf.Tensor) – shape (batch,) of int32|int64
logits_time_major (bool) –
targets (tf.Tensor) – batch-major, [batch,time]
targets_seq_lens (tf.Tensor) – (batch,)
ctc_merge_repeated (bool) –
logits_normalize (bool) – apply log_softmax on logits (default). if False, you might also set grad_wrt_softmax_in=False
grad_wrt_softmax_in (bool) – assume p(s|x) = softmax(logits), and define the gradient w.r.t. logits. This is p(s|x) - bw, where bw is the Baum-Welch soft alignment. If logits are already normalized (e.g. we just use log p(s|x) = logits), the error signal to logits should be -bw.
blank_index (int) – vocab index of the blank symbol
- Returns:
loss, shape (batch,)
- Return type:
tf.Tensor
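The two error signals described for grad_wrt_softmax_in can be written out for a single frame as follows. This is a minimal pure-Python sketch; the helper names softmax and error_signal are made up for illustration, and bw stands for an already computed Baum-Welch soft alignment for that frame.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one frame."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def error_signal(logits, bw, grad_wrt_softmax_in=True):
    """Error signal w.r.t. logits for one frame (illustrative helper).

    grad_wrt_softmax_in=True:  p(s|x) - bw with p(s|x) = softmax(logits)
    grad_wrt_softmax_in=False: -bw (logits are treated as log p(s|x))
    """
    if grad_wrt_softmax_in:
        return [p - b for p, b in zip(softmax(logits), bw)]
    return [-b for b in bw]


print(error_signal([0.0, 0.0], [1.0, 0.0]))
print(error_signal([0.0, 0.0], [1.0, 0.0], grad_wrt_softmax_in=False))
```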
- returnn.tf.native_op.fast_viterbi(am_scores, am_seq_len, edges, weights, start_end_states)[source]#
- Parameters:
am_scores (tf.Tensor) – (time, batch, dim), in +log space (unlike fast_baum_welch)
am_seq_len (tf.Tensor) – (batch,), int32
edges (tf.Tensor) – (4,num_edges), edges of the graph (from,to,emission_idx,sequence_idx)
weights (tf.Tensor) – (num_edges,), weights of the edges
start_end_states (tf.Tensor) – (2, batch), (start,end) state idx in automaton. there is only one single automaton.
- Returns:
(alignment, scores), alignment is (time, batch), scores is (batch,), in +log space
- Return type:
(tf.Tensor, tf.Tensor)
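For comparison with the full-sum case, a best-path search over the same edge-list format can be sketched in pure Python. This handles a single sequence (the sequence_idx entry of each edge is omitted) and assumes that edge weights are added to the +log scores; that sign convention and the function name are assumptions of the sketch, not taken from the native op.

```python
import math

NEG_INF = float("-inf")


def viterbi_reference(am_scores, edges, weights, start, end, num_states):
    """am_scores: list[time][dim] in +log space; edges: (from, to, emission_idx).

    Returns (alignment, score): the emission index per frame on the best
    path from start to end, and the score of that path.
    """
    n_time = len(am_scores)
    score = [NEG_INF] * num_states
    score[start] = 0.0
    backptr = []  # per frame: best incoming edge index for each state
    for t in range(n_time):
        new_score = [NEG_INF] * num_states
        best_edge = [None] * num_states
        for e, ((s_from, s_to, emi), w) in enumerate(zip(edges, weights)):
            cand = score[s_from] + w + am_scores[t][emi]
            if cand > new_score[s_to]:
                new_score[s_to] = cand
                best_edge[s_to] = e
        backptr.append(best_edge)
        score = new_score
    if score[end] == NEG_INF:
        raise ValueError("end state is not reachable in n_time steps")
    alignment = [None] * n_time
    state = end
    for t in reversed(range(n_time)):  # backtrace along the stored edges
        e = backptr[t][state]
        alignment[t] = edges[e][2]
        state = edges[e][0]
    return alignment, score[end]


edges = [(0, 0, 0), (0, 1, 1), (1, 1, 1)]
weights = [0.0, 0.0, 0.0]
am = [[math.log(0.9), math.log(0.1)], [math.log(0.2), math.log(0.8)]]
print(viterbi_reference(am, edges, weights, start=0, end=1, num_states=2))
```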
- returnn.tf.native_op.ctc_loss_viterbi(logits, logits_seq_lens, logits_time_major, targets, targets_seq_lens, blank_index=-1)[source]#
Similar to ctc_loss(). However, instead of using the full sum, we use the best path (i.e. Viterbi instead of Baum-Welch). We use our fast_viterbi().
- Parameters:
logits (tf.Tensor) – (time,batch,dim) or (batch,time,dim). unnormalized (before softmax)
logits_seq_lens (tf.Tensor) – shape (batch,) of int32|int64
logits_time_major (bool) –
targets (tf.Tensor) – batch-major, [batch,time]
targets_seq_lens (tf.Tensor) – (batch,)
blank_index (int) – vocab index of the blank symbol
- Returns:
loss, shape (batch,)
- Return type:
tf.Tensor
- returnn.tf.native_op.edit_distance(a, a_len, b, b_len)[source]#
Wraps NativeOp.EditDistanceOp.
- Parameters:
a (tf.Tensor) – (batch,time1), int32
a_len (tf.Tensor) – (batch,), int32
b (tf.Tensor) – (batch,time2), int32
b_len (tf.Tensor) – (batch,), int32
- Returns:
(batch,) tensor, int32, un-normalized edit distance
- Return type:
tf.Tensor
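As a point of reference for the expected values, the un-normalized edit distance over a padded batch can be computed with the standard dynamic program. This is a plain-Python sketch with a made-up function name, useful for checking small cases by hand; it is not the native implementation.

```python
def edit_distance_reference(a, a_len, b, b_len):
    """Batched Levenshtein distance on padded int sequences.

    a: list[batch][time1], b: list[batch][time2]; lengths give the valid prefix.
    Returns a list of ints, one per batch entry.
    """
    result = []
    for seq_a, n_a, seq_b, n_b in zip(a, a_len, b, b_len):
        row = list(range(n_b + 1))  # distance from empty a-prefix to b[:j]
        for i in range(1, n_a + 1):
            new_row = [i]
            for j in range(1, n_b + 1):
                substitute = row[j - 1] + (seq_a[i - 1] != seq_b[j - 1])
                delete = row[j] + 1
                insert = new_row[j - 1] + 1
                new_row.append(min(substitute, delete, insert))
            row = new_row
        result.append(row[n_b])
    return result


a = [[1, 2, 3, 0], [1, 2, 0, 0]]
b = [[1, 3, 0], [3, 4, 5]]
print(edit_distance_reference(a, [3, 2], b, [2, 3]))
```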
- returnn.tf.native_op.optimal_completion_edit_distance(a, a_len, b, b_len)[source]#
Wraps NativeOp.OptimalCompletionEditDistanceOp.
- Parameters:
a (tf.Tensor) – (batch,time1), int32. prefix
a_len (tf.Tensor) – (batch,), int32
b (tf.Tensor) – (batch,time2), int32
b_len (tf.Tensor) – (batch,), int32
- Returns:
(batch,) tensor, int32, un-normalized edit distance
- Return type:
tf.Tensor
- returnn.tf.native_op.optimal_completion_edit_distance_per_successor(a, a_len, b, b_len, successors)[source]#
Wraps NativeOp.OptimalCompletionEditDistancePerSuccessorOp.
- Parameters:
a (tf.Tensor) – (batch,time1), int32. prefix
a_len (tf.Tensor) – (batch,), int32
b (tf.Tensor) – (batch,time2), int32
b_len (tf.Tensor) – (batch,), int32
successors (tf.Tensor|int) – (n_labels,), int32. scalar means tf.range(successors)
- Returns:
(batch,n_labels) tensor, int32, un-normalized edit distance
- Return type:
tf.Tensor
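One way to read the per-successor variant: for each candidate label, append it to the prefix a and take the minimum edit distance to any prefix of b. The sketch below follows that reading for a single sequence in plain Python. The function names are made up, and the reading itself (in particular how ended sequences are treated) is an assumption rather than a statement about the native op.

```python
def final_row(a, b):
    """Edit distances from the full sequence a to every prefix b[:j]."""
    row = list(range(len(b) + 1))
    for i, sym in enumerate(a, start=1):
        new_row = [i]
        for j in range(1, len(b) + 1):
            new_row.append(min(row[j - 1] + (sym != b[j - 1]),
                               row[j] + 1,
                               new_row[j - 1] + 1))
        row = new_row
    return row


def oc_edit_distance(a, b):
    """Optimal completion edit distance: best match of a against any prefix of b."""
    return min(final_row(a, b))


def oc_edit_distance_per_successor(a, b, successors):
    """Optimal completion edit distance of a + [label] for each successor label."""
    return [oc_edit_distance(list(a) + [s], b) for s in successors]


print(oc_edit_distance([1, 2], [1, 2, 3, 4]))
print(oc_edit_distance_per_successor([1], [1, 2, 3], range(4)))
```

In the second call, label 2 is the only successor that keeps the prefix on track with b, so it is the only entry with distance 0.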
- returnn.tf.native_op.next_edit_distance_row(last_row, a, a_n, a_ended, b, b_len)[source]#
Wraps NativeOp.NextEditDistanceRowOp.
- Parameters:
last_row (tf.Tensor) – 2d (batch,b_time + 1), int32. last edit distances
a (tf.Tensor) – symbols. 1d (batch,), int32. current.
a_n (tf.Tensor) – scalar or 1d (batch,), int32. current position
a_ended (tf.Tensor) – 1d (batch,), int32 (cast from bool, because int32 is easier to handle)
b (tf.Tensor) – symbols. 2d (batch,b_time), int32
b_len (tf.Tensor) – 1d (batch,), int32
- Returns:
2d (batch,b_time + 1), int32, next (unnormalized) edit distance row
- Return type:
tf.Tensor
- returnn.tf.native_op.edit_distance_via_next_edit_distance_row(a, a_len, b, b_len, optimal_completion=False, full_row_output=False)[source]#
This is mostly for demonstration and debugging. Should be equivalent to edit_distance() or optimal_completion_edit_distance() (which should be much faster).
- Parameters:
a (tf.Tensor) – (batch,time1), int32
a_len (tf.Tensor) – (batch,), int32
b (tf.Tensor) – (batch,time2), int32
b_len (tf.Tensor) – (batch,), int32
optimal_completion (bool) – calc optimal completion edit distance instead
full_row_output (bool) – outputs the full final row
- Returns:
(batch,) or (batch,time2+1) tensor, int32, un-normalized edit distance
- Return type:
tf.Tensor
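The row-by-row formulation can be sketched for a single sequence as a loop that feeds one symbol of a at a time into a row update. This is an illustrative plain-Python version with made-up function names; batching, padding, and the a_ended handling of the native op are left out.

```python
def next_row(last_row, sym, b):
    """Given distances from a[:n] to every b[:j], return those for a[:n] + [sym]."""
    new_row = [last_row[0] + 1]
    for j in range(1, len(b) + 1):
        new_row.append(min(last_row[j - 1] + (sym != b[j - 1]),
                           last_row[j] + 1,
                           new_row[j - 1] + 1))
    return new_row


def edit_distance_via_rows(a, b, optimal_completion=False, full_row_output=False):
    """Edit distance computed by repeated row updates (single sequence)."""
    row = list(range(len(b) + 1))  # distances from the empty prefix of a
    for sym in a:
        row = next_row(row, sym, b)
    if full_row_output:
        return row
    return min(row) if optimal_completion else row[-1]


print(edit_distance_via_rows([1, 2, 3], [1, 3]))
print(edit_distance_via_rows([1, 2, 3], [1, 3], full_row_output=True))
print(edit_distance_via_rows([1, 2], [1, 2, 3, 4], optimal_completion=True))
```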
- returnn.tf.native_op.next_edit_distance_reduce(last_row, a, a_n, a_ended, b, b_len, optimal_completion=False, a_blank_idx=None)[source]#
Wraps NativeOp.NextEditDistanceReduceOp.
- Parameters:
last_row (tf.Tensor) – 2d (batch,b_time + 1), int32. last edit distances
a (tf.Tensor) – symbols. 2d (batch|1,n_labels), int32. current.
a_n (tf.Tensor) – scalar or 1d (batch,), int32. current position
a_ended (tf.Tensor) – 1d (batch,), int32 (cast from bool, because int32 is easier to handle)
b (tf.Tensor) – symbols. 2d (batch,b_time), int32
b_len (tf.Tensor) – 1d (batch,), int32
a_blank_idx (tf.Tensor|int|None) – scalar, int32
optimal_completion (bool|tf.Tensor) –
- Returns:
2d (batch,n_labels), int32, next (unnormalized) (optimal completion) edit distance
- Return type:
tf.Tensor
- returnn.tf.native_op.optimal_completion_edit_distance_per_successor_via_next_edit_distance(a, a_len, b, b_len, successors)[source]#
Uses next_edit_distance_reduce() and edit_distance_via_next_edit_distance_row(). Mostly for demonstration/testing. In practice, you would do something similar, but in your own loop. Similar to optimal_completion_edit_distance_per_successor(), but the handling of ended sequences (from a) is different.
- Parameters:
a (tf.Tensor) – (batch,time1), int32. prefix
a_len (tf.Tensor) – (batch,), int32
b (tf.Tensor) – (batch,time2), int32
b_len (tf.Tensor) – (batch,), int32
successors (tf.Tensor|int) – (n_labels,), int32. scalar means tf.range(successors)
- Returns:
(batch,n_labels) tensor, int32, un-normalized edit distance
- Return type:
tf.Tensor
- returnn.tf.native_op.have_blocksparse_requirements()[source]#
- Returns:
whether we can use the OpenAI blocksparse module
- Return type:
bool