returnn.native_op#

Generic interface which automatically creates:

  • CPU and GPU (CUDA) op

  • inplace and not inplace

  • grad variants

See returnn.tf.native_op and returnn.theano.native_op for usage in TensorFlow and Theano.

See Native operations for more background.

class returnn.native_op.NativeOpBaseMixin(in_info, out_info, c_fw_code, c_bw_code=None, c_extra_support_code=None, code_version=None, cpu_support=True, grad_input_map=None, name=None)[source]#

The purpose of having this as a separate base class is to make this independent of any Theano specific functionality so that we can also use this base for example for TensorFlow.

Parameters:
  • in_info (list[dict(str)]) –

    each dict describes one input var. attribs in the dict:

    int ndim: the ndim.
    tuple shape: tuple and can contain None for specific dimensions.

    optional attribs:

    str dtype: “float32” by default.
    bool need_contiguous: false by default.
    int want_inplace: -1 by default. try to optimize to destroy input, on output-index.
    “dummy_out” is a special value which will add another output.
    bool is_inplace: false by default. whether the optimization was applied.
    str gradient: can be “disconnected”. see grad().
    bool bw_input: True by default. add this param to the bw input.

    other attribs are just ignored.

  • out_info (list[dict(str)]) –

    like in_info. slightly different behavior for:

    shape: we also allow refs to the in_info in the form (in-idx,dim). see infer_shape().
    need_contiguous/want_inplace: used for bw, in case bw_input == True.

  • c_fw_code (str) – C code for forward pass

  • c_extra_support_code (str|dict[str]) – C support code (for c_support_code)

  • c_bw_code (str|None) – C code for backward pass (for gradient)

  • code_version (tuple[int]) – will be returned by c_code_cache_version.

  • cpu_support (bool) –

  • grad_input_map (tuple[int]|callable) – selection of grad inputs. by default, we get all inputs + all outputs + all grad outputs.

  • name (str) – name
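
For orientation, a minimal Python sketch of what such in_info/out_info entries can look like. The op, names and shapes here are purely illustrative, not taken from the module:

    # Hypothetical in_info/out_info for a simple elementwise op (illustrative only).
    in_info = (
        {"name": "X", "ndim": 2, "shape": (None, None),  # (batch, dim), both dims dynamic
         "dtype": "float32", "need_contiguous": True,
         "want_inplace": 0},                             # try to write output 0 into X's memory
        {"name": "sizes", "ndim": 1, "shape": (None,),
         "gradient": "disconnected"},                    # no gradient flows into this input
    )
    out_info = (
        {"name": "Y", "ndim": 2,
         "shape": ((0, 0), (0, 1))},                     # same shape as input 0, see infer_shape()
    )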

infer_shape(node, input_shapes)[source]#
Parameters:
  • node

  • input_shapes

Return type:

list[tuple[int]]
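
The (in-idx,dim) references allowed in out_info shapes are resolved against the concrete input shapes. A rough sketch of that resolution (simplified; this method is the actual implementation):

    def resolve_output_shapes(out_info, input_shapes):
        """Replace (in_idx, dim) refs in each output shape by the concrete input dim."""
        result = []
        for info in out_info:
            shape = []
            for d in info["shape"]:
                if isinstance(d, (tuple, list)):   # reference into an input shape
                    in_idx, dim = d
                    shape.append(input_shapes[in_idx][dim])
                else:                              # literal dim or None
                    shape.append(d)
            result.append(tuple(shape))
        return result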

kwargs_for_grad_op()[source]#
Returns:

the kwargs for creating a NativeOp for the gradient op. e.g. includes in_info, out_info, etc

Return type:

dict[str]

Note: The inputs of the gradient are by default: fwd_op.inputs + fwd_op.outputs + output_grads. We filter them via self._filter_grad_inputs.
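
For example, for LstmGenericBase below, the default gradient inputs are the forward inputs (Z, V_h, c, i), then the forward outputs (Y, H, d), then the output gradients (DY, DH, Dd), which matches the signature of its grad_input_map(). A simplified sketch of how the default list and a grad_input_map selection fit together (an assumption for illustration, not the actual _filter_grad_inputs code):

    # Assumed simplification of how the gradient-op inputs are assembled.
    def default_grad_inputs(fwd_inputs, fwd_outputs, output_grads, grad_input_map=None):
        all_inputs = list(fwd_inputs) + list(fwd_outputs) + list(output_grads)
        if grad_input_map is None:
            return all_inputs
        if callable(grad_input_map):
            return list(grad_input_map(*all_inputs))
        return [all_inputs[i] for i in grad_input_map]  # tuple of indices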

make_results_of_gradient(grad_op_outputs, disconnected_type=None)[source]#
Parameters:
  • grad_op_outputs (list[T]|tuple[T]) – this is already with dummy outputs removed

  • disconnected_type (S) –

Returns:

gradient for each input of our op

Return type:

list[T|S]

class returnn.native_op.NativeOpGenBase[source]#

Base interface for op generation. See NativeOp.__init__() for attribs.

in_info: Tuple[Dict[str]] = None[source]#
out_info: Tuple[Dict[str]] = None[source]#
c_fw_code: str = None[source]#
c_bw_code: str = None[source]#
c_extra_support_code: Dict[str, str] = None[source]#
code_version: Union[Tuple[int], int] = None[source]#
grad_input_map = None[source]#
theano_custom_grad = None[source]#
cpu_support = True[source]#
classmethod map_layer_inputs_to_op(*inputs)[source]#
Parameters:

inputs

Returns:

inputs

classmethod map_layer_output_from_op(*outputs)[source]#
Parameters:

outputs

Returns:

outputs[0]
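
A minimal sketch of how such an op generator class is typically written. The op itself is hypothetical and the C code is elided; only the documented attribs are used:

    class MyScaleOp(NativeOpGenBase):
        """Hypothetical example op (Y = 2 * X), just to illustrate the attribs."""
        in_info = (
            {"name": "X", "ndim": 2, "shape": (None, None), "need_contiguous": True},
        )
        out_info = (
            {"name": "Y", "ndim": 2, "shape": ((0, 0), (0, 1)), "need_contiguous": True},
        )
        c_fw_code = """
          assert(n_inputs == 1); assert(n_outputs == 1);
          // ... kernel launch for the elementwise scaling would go here ...
        """
        c_bw_code = None  # no gradient in this sketch
        code_version = ()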

class returnn.native_op.LstmGenericBase[source]#
inputs:
param Z:

{input,output,forget} gate + cell state. 3d (time,batch,dim*4)

param V_h:

recurrent matrix. 2d (dim,dim*4)

param c:

initial cell state. 2d (batch,dim)

param i:

index. 2d (time,batch) -> 0 or 1

outputs:
param Y:

output. 3d (time,batch,dim)

param H:

gates and cell state. 3d (time,batch,dim*4)

param d:

final cell state. 2d (batch,dim)
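
For reference, a plain NumPy sketch of one time step as computed by the lstm_kernel in c_extra_support_code below (H[t] = Z[t] + Y[t-1]·V_h holds the pre-activations; layout per the kernel: cell-in, input gate, forget gate, output gate). This only illustrates the math, it is not the op itself:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_generic_step(H_t, c_prev, i_t):
        """H_t: (batch, dim*4) pre-activations. c_prev: (batch, dim). i_t: (batch,) 0/1 mask."""
        n = H_t.shape[1] // 4
        cell_in = np.tanh(H_t[:, 0 * n:1 * n])
        inp_gate = sigmoid(H_t[:, 1 * n:2 * n])
        fgt_gate = sigmoid(H_t[:, 2 * n:3 * n])
        out_gate = sigmoid(H_t[:, 3 * n:4 * n])
        m = i_t[:, None]
        c = inp_gate * cell_in + fgt_gate * c_prev
        c = c * m + c_prev * (1.0 - m)          # keep the old cell state where the mask is 0
        y = out_gate * np.tanh(c) * m
        return y, c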

in_info: Tuple[Dict[str]] = ({'bw_out_var': {'shape': ((2, 0), (2, 1), (0, 1))}, 'name': 'Z', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None), 'want_inplace': 1}, {'name': 'V_h', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'c', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'i', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)})[source]#
out_info: Tuple[Dict[str]] = ({'bw_grad_var': {'want_inplace': 'dummy_out'}, 'name': 'Y', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (1, 0))}, {'bw_in_var': {'want_inplace': 0}, 'name': 'H', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': 'd', 'ndim': 2, 'need_contiguous': True, 'shape': ((2, 0), (2, 1))})[source]#
classmethod grad_input_map(Z, V_h, c, i, Y, H, d, DY, DH, Dd)[source]#

Map grads.

c_extra_support_code: Dict[str, str] = {'lstm_bwd_kernel': '\n      DEF_KERNEL\n      void lstm_bwd_kernel(\n            float* delta, float* epsilon, const float* next_epsilon, const float* old_state,\n            bool old_state_strided, const float* Y, int n_cells, int n_batch, const float* i) {\n        //layout:\n        //delta[0*n_cells..1*n_cells-1] : input gate\n        //delta[1*n_cells..2*n_cells-1] : forget gate\n        //delta[2*n_cells..3*n_cells-1] : output gate\n        //delta[3*n_cells..4*n_cells-1] : cell state\n        //epsilon[0*n_cells..1*n_cells-1]: cell output derivative (later overwritten, see below)\n        //next_epsilon[0*n_cells..1*n_cells-1]: cell state derivative * forget_gate (of next timestep)\n        //repeated for every mini-batch\n\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_cells * n_batch) {\n          int batch_idx = idx / n_cells;\n          int batch_offset = batch_idx * 4 * n_cells;\n          int cell_offset = idx % n_cells;\n          int start = batch_offset + cell_offset;\n          float i_batch = i[batch_idx];\n\n          float inpGate = delta[start + n_cells];\n          float fgtGate = delta[start + 2 * n_cells];\n          float outGate = delta[start + 3 * n_cells];\n          float oldState = old_state_strided ? old_state[start] : old_state[idx];\n          float state = delta[start];\n          float eps = epsilon[idx];\n\n          //avoid division by 0\n          float gc = tanhf(state); //g(c(t))\n          float gzc = (state - fgtGate * oldState) / fmaxf(inpGate, float(1e-16)); //g(z_c(t))\n\n          //delta_output\n          delta[start + 3 * n_cells] = outGate * (1.f - outGate) * gc * eps * i_batch;\n\n          //epsilon_c\n          float epsilon_c = (1.f - (gc * gc)) * outGate * eps;\n          epsilon_c += next_epsilon[idx];\n          epsilon[idx] = epsilon_c * fgtGate * i_batch + next_epsilon[idx] * (1.f - i_batch);\n\n          //delta_cell\n          delta[start] = inpGate * (1.f - (gzc * gzc)) * epsilon_c * i_batch;\n\n          //delta_forget\n          delta[start + 2 * n_cells] = fgtGate * (1.f - fgtGate) * oldState * epsilon_c * i_batch;\n\n          //delta_input\n          delta[start + n_cells] = inpGate * (1.f - inpGate) * gzc * epsilon_c * i_batch;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n      ', 'lstm_kernel': '\n      DEF_KERNEL\n      void lstm_kernel(float* data, const float* old_state, bool old_state_strided,\n                       float* output, float* state_out, int n_cells, int n_batch, const float* i) {\n        //layout:\n        //data[0*n_cells..1*n_cells-1] : cell state\n        //data[1*n_cells..2*n_cells-1] : input gate\n        //data[2*n_cells..3*n_cells-1] : forget gate\n        //data[3*n_cells..4*n_cells-1] : output gate\n        //output[0*n_cells..1*n_cells-1]: cell output\n        //repeated for every mini-batch\n\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_cells * n_batch) {\n          int batch_idx = idx / n_cells;\n          int start = batch_idx * 4 * n_cells + idx % n_cells;\n          float i_batch = i[batch_idx];\n\n          //input, forget and output gates\n          float inpGate = 1.f / (1.f + expf(-data[start + n_cells]));\n          float fgtGate = 1.f / (1.f + expf(-data[start + 2 * n_cells]));\n          float outGate = 1.f / (1.f + expf(-data[start + 3 * n_cells]));\n          float state = inpGate * tanhf(data[start]);\n          float old_state_batch = 
old_state_strided ? old_state[start] : old_state[idx];\n\n          state += fgtGate * old_state_batch;\n          state = state * i_batch + old_state_batch * (1.f - i_batch);\n\n          //cell output\n          output[idx] = outGate * tanhf(state) * i_batch;\n\n          data[start] = state;\n          data[start + n_cells] = inpGate;\n          data[start + 2 * n_cells] = fgtGate;\n          data[start + 3 * n_cells] = outGate;\n          if(state_out)\n            state_out[idx] = state;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    // Z*, V_h, c, i = input_names (*: inplace)\n    // Y, H, d = output_names\n    assert(n_inputs == 4);\n    assert(n_outputs == 3);\n    Ndarray* V_h = inputs[1];\n    Ndarray* c = inputs[2];\n    Ndarray* i = inputs[3];\n    Ndarray* Y = *outputs[0];\n    Ndarray* H = *outputs[1]; // inplace on Z\n    Ndarray* d = *outputs[2];\n\n    long T = Ndarray_DIMS(i)[0];\n    int n_batch = Ndarray_DIMS(i)[1];\n    assert(Ndarray_DIMS(H)[2] %% 4 == 0); // 3 gates + cell\n    int n_cells = Ndarray_DIMS(H)[2] / 4;\n\n    assert(T > 0);\n    for(int x = 0; x < T; ++x) {\n      if(x > 0) {\n        //H += Y[x-1]*V_h\n        affine_y_x(x-1, Y,  x, V_h,  x, H);\n      }\n\n      start_dev_kernel(lstm_kernel, (\n        data_ptr(H, x),\n        x > 0 ? data_ptr(H, x - 1) : Ndarray_DEV_DATA(c),\n        x > 0,\n        data_ptr(Y, x),\n        (x == T - 1) ? Ndarray_DEV_DATA(d) : 0,\n        n_cells,\n        n_batch,\n        Ndarray_DEV_DATA(i) + x * n_batch\n      ));\n    }\n    HANDLE_LAST_ERROR();\n  '[source]#
c_bw_code: str = '\n    // V_h, c, i,   Y, H*,   DY*, Dd = input_names (*: inplace)\n    // DZ, DV_h, Dc, tmpDc = output_names\n    assert(n_inputs == 7);\n    assert(n_outputs == 4);\n    Ndarray* V_h = inputs[0];\n    Ndarray* c = inputs[1];\n    Ndarray* i = inputs[2];\n    Ndarray* Y = inputs[3];\n    Ndarray* Dd = inputs[6];\n    Ndarray* DZ = *outputs[0]; // inplace on H\n    Ndarray* DV_h = *outputs[1];\n    Ndarray* Dc = *outputs[2];\n    Ndarray* tmpDc = *outputs[3]; // (old DY), inplace buffer\n\n    long T = Ndarray_DIMS(i)[0];\n    int n_batch = Ndarray_DIMS(i)[1];\n    assert(Ndarray_DIMS(DZ)[2] %% 4 == 0); // 3 gates + cell\n    int n_cells = Ndarray_DIMS(DZ)[2] / 4;\n\n    assert(T > 0);\n    for(int x = T - 1; x >= 0; --x) {\n      // add recurrent\n      bool rightBorder = (x == T - 1);\n      if(!rightBorder)\n        affine_y_x(x+1, DZ,  x, V_h,  x, tmpDc,  false, true);\n\n      start_dev_kernel(lstm_bwd_kernel, (\n        data_ptr(DZ, x),\n        data_ptr(tmpDc, x),\n        rightBorder ? Ndarray_DEV_DATA(Dd) : data_ptr(tmpDc, x + 1),\n        x > 0 ? data_ptr(DZ, x - 1) : Ndarray_DEV_DATA(c),\n        x > 0,\n        data_ptr(Y, x),\n        n_cells,\n        n_batch,\n        Ndarray_DEV_DATA(i) + x * n_batch\n      ));\n    }\n\n    //DV_h = Y[0..end-1]^T * DZ[1..end]\n    affine_global(Y, DZ, DV_h, true, false, 1, 0.0f);\n\n    Ndarray_DIMS_Type Dc_dim = Ndarray_HOST_DIMS(Dc);\n    Ndarray_memcpy(\n      Ndarray_DEV_DATA(Dc), Ndarray_DEV_DATA(tmpDc),\n      Dc_dim[0] * Dc_dim[1] * sizeof(float));\n    HANDLE_LAST_ERROR();\n  '[source]#
code_version: Union[Tuple[int], int] = ()[source]#
class returnn.native_op.LstmLowMem[source]#

This is designed to require minimal memory during training. It only stores the outputs and the cell states, i.e. it requires time * cells * 2 floats for memory in total.

inputs:
param X:

(time,batch,in_dim)

param W:

forward+recurrent matrix. 2d (in_dim+dim,dim*4)

param b:

bias. 1d (dim*4,)

param y0:

initial output|hidden state. 2d (batch,dim)

param c0:

initial cell state. 2d (batch,dim)

param i:

index. 2d (time,batch) -> 0 or 1

param start:

where to start. must be >=0, default is usually 0. dtype int, scalar.

param step:

+1 for fwd, -1 for bwd direction. can also be |step|>1 for wider steps. dtype int, scalar. for bwd (<0), will start at T-start-1.

outputs:
param Y:

output. 3d (time,batch,dim)

param C:

cell states. 3d (time,batch,dim). gradient ignored!

param d:

final cell state. 2d (batch,dim)
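
The intermediate pre-activations (intern) are not stored; the backward pass recomputes them from X, Y and W, which is what keeps the stored state down to the outputs and the cell states. Note also the start/step semantics: e.g. with start=0 and step=-1, the first frame processed is t = T - 0 - 1 = T - 1. A NumPy sketch of one forward step, as computed by copy_x_h_kernel, affine_raw/add_bias_kernel and lstm_kernel below (illustrative only):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_low_mem_step(x_t, y_prev, c_prev, W, b, i_t):
        """x_t: (batch, n_in); y_prev, c_prev: (batch, n_cells); W: (n_in+n_cells, n_cells*4);
        b: (n_cells*4,); i_t: (batch,) 0/1 mask. Returns (y_t, c_t)."""
        n = y_prev.shape[1]
        x_h = np.concatenate([x_t, y_prev], axis=1)    # copy_x_h_kernel
        intern = x_h.dot(W) + b                        # affine_raw + add_bias_kernel
        cell_in = np.tanh(intern[:, 0 * n:1 * n])
        inp = sigmoid(intern[:, 1 * n:2 * n])
        fgt = sigmoid(intern[:, 2 * n:3 * n])
        out = sigmoid(intern[:, 3 * n:4 * n])
        m = i_t[:, None]
        c_t = (c_prev * fgt + cell_in * inp) * m + c_prev * (1.0 - m)
        y_t = np.tanh(c_t) * out * m
        return y_t, c_t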

in_info: Tuple[Dict[str]] = ({'name': 'X', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'name': 'W', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'b', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'name': 'y0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'c0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'i', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'start', 'ndim': 0, 'shape': ()}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'step', 'ndim': 0, 'shape': ()})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'Y', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (4, 1))}, {'name': 'C', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (4, 1))}, {'name': 'd', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 1), (4, 1))})[source]#
classmethod grad_input_map(X, W, b, y0, c0, i, start, step, Y, C, d, DY, DC, Dd)[source]#

Map args.

c_extra_support_code: Dict[str, str] = {'add_bias_kernel': '\n      DEF_KERNEL\n      void add_bias_kernel(int n_batch, int n_dim, float* x, float* b) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_batch * n_dim) {\n          int dim_idx = idx % n_dim;\n          x[idx] += b[dim_idx];\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', 'copy_x_h_kernel': '\n      DEF_KERNEL\n      void copy_x_h_kernel(\n        int n_batch, int n_in, int n_cells,\n        float* x_h,\n        float* x,\n        float* h)\n      {\n        int n_total_in = n_in + n_cells;\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_batch * n_total_in) {\n          int batch_idx = idx / n_total_in;\n          int in_dim_idx = idx % n_total_in;\n\n          if(in_dim_idx < n_in)\n            x_h[idx] = x[batch_idx * n_in + in_dim_idx];\n          else\n            x_h[idx] = h[batch_idx * n_cells + in_dim_idx - n_in];\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n      ', 'inv_copy_x_h_kernel': '\n    DEF_KERNEL\n    void inv_copy_x_h_kernel(\n      int n_batch, int n_in, int n_cells,\n      float* x_h,\n      float* x,\n      float* h)\n    {\n      int n_total_in = n_in + n_cells;\n      int idx = threadIdx.x + blockDim.x * blockIdx.x;\n      while (idx < n_batch * n_total_in) {\n        int batch_idx = idx / n_total_in;\n        int in_dim_idx = idx % n_total_in;\n\n        if(in_dim_idx < n_in)\n          x[batch_idx * n_in + in_dim_idx] = x_h[idx];\n        else\n          h[batch_idx * n_cells + in_dim_idx - n_in] = x_h[idx];\n\n        idx += gridDim.x * blockDim.x;\n      }\n    }\n    ', 'lstm_bwd_kernel': '\n      DEF_KERNEL\n      void lstm_bwd_kernel(\n        int n_batch, int n_in, int n_cells, const float* mask,\n        float* x_h,\n        float* intern,\n        float* prev_c,\n        float* y,\n        float* c,\n        float* d_y,\n        float* d_h,\n        float* d_c,\n        float* d_intern,\n        float* d_b)\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_cells * n_batch) {\n          int batch_idx = idx / n_cells;\n          int cell_idx = idx % n_cells;\n          int intern_offset = batch_idx * 4 * n_cells + cell_idx;\n          float mask_b = mask[batch_idx];\n          float d_y_b = d_y[idx] * mask_b + d_h[idx];\n          float d_c_b = d_c[idx] * mask_b;\n          float prev_c_b = prev_c[idx];\n\n          // cell-in + input, forget and output gates\n          float cellIn = tanhf(intern[intern_offset]);\n          float inpGate = 1.f / (1.f + expf(-intern[intern_offset + n_cells]));\n          float fgtGate = 1.f / (1.f + expf(-intern[intern_offset + 2 * n_cells]));\n          float outGate = 1.f / (1.f + expf(-intern[intern_offset + 3 * n_cells]));\n\n          float c_b = prev_c_b * fgtGate + cellIn * inpGate;\n          float gc = tanhf(c_b);\n          float d_outGate_in = (1.f - outGate) * outGate * gc * d_y_b;\n          float d_c2 = d_c_b + outGate * d_y_b * (1.f - gc * gc);\n          float d_cellIn_in = (1.f - cellIn * cellIn) * inpGate * d_c2;\n          float d_inpGate_in = (1.f - inpGate) * inpGate * cellIn * d_c2;\n          float d_fgtGate_in = (1.f - fgtGate) * fgtGate * prev_c_b * d_c2;\n          d_c[idx] = fgtGate * d_c2 + d_c[idx] * (1.f - mask_b);\n\n          d_intern[intern_offset] = d_cellIn_in;\n          d_intern[intern_offset + n_cells] = d_inpGate_in;\n          d_intern[intern_offset + 2 * n_cells] = 
d_fgtGate_in;\n          d_intern[intern_offset + 3 * n_cells] = d_outGate_in;\n\n          elem_atomic_add(&d_b[cell_idx], d_cellIn_in);\n          elem_atomic_add(&d_b[cell_idx + n_cells], d_inpGate_in);\n          elem_atomic_add(&d_b[cell_idx + 2 * n_cells], d_fgtGate_in);\n          elem_atomic_add(&d_b[cell_idx + 3 * n_cells], d_outGate_in);\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n      ', 'lstm_kernel': '\n      DEF_KERNEL\n      void lstm_kernel(\n        int n_batch, int n_cells, const float* mask,\n        float* intern,\n        float* prev_c,\n        float* y,\n        float* c)\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_cells * n_batch) {\n          int batch_idx = idx / n_cells;\n          int cell_idx = idx % n_cells;\n          int intern_offset = batch_idx * 4 * n_cells + cell_idx;\n          float prev_c_b = prev_c[idx];\n          float mask_b = mask[batch_idx];\n\n          // cell-in + input, forget and output gates\n          float cellIn = tanhf(intern[intern_offset]);\n          float inpGate = 1.f / (1.f + expf(-intern[intern_offset + n_cells]));\n          float fgtGate = 1.f / (1.f + expf(-intern[intern_offset + 2 * n_cells]));\n          float outGate = 1.f / (1.f + expf(-intern[intern_offset + 3 * n_cells]));\n\n          float c_b = (prev_c_b * fgtGate + cellIn * inpGate) * mask_b\n                      + prev_c_b * (1.f - mask_b);\n          c[idx] = c_b;\n          y[idx] = tanhf(c_b) * outGate * mask_b;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n      '}[source]#
c_fw_code: str = '\n    // X, W, b, y0, c0, i, start, step = input_names\n    // Y, C, d = output_names\n    assert(n_inputs == 8);\n    assert(n_outputs == 3);\n    Ndarray* X = inputs[0];\n    Ndarray* W = inputs[1];\n    Ndarray* b = inputs[2];\n    Ndarray* y0 = inputs[3];\n    Ndarray* c0 = inputs[4];\n    Ndarray* i = inputs[5];\n    assert_cmp(Ndarray_NDIM(inputs[6]), ==, 0);\n    assert_cmp(Ndarray_NDIM(inputs[7]), ==, 0);\n    int start = Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n    int step = Ndarray_DEV_DATA_int32_scalar(inputs[7]);\n    Ndarray* Y = *outputs[0];\n    Ndarray* C = *outputs[1];\n    Ndarray* d = *outputs[2];\n\n    assert_cmp(Ndarray_NDIM(X), ==, 3);\n    assert_cmp(Ndarray_NDIM(W), ==, 2);\n    assert_cmp(Ndarray_NDIM(b), ==, 1);\n    assert_cmp(Ndarray_NDIM(y0), ==, 2);\n    assert_cmp(Ndarray_NDIM(c0), ==, 2);\n    assert_cmp(Ndarray_NDIM(i), ==, 2);\n    assert_cmp(Ndarray_NDIM(Y), ==, 3);\n    assert_cmp(Ndarray_NDIM(C), ==, 3);\n    assert_cmp(Ndarray_NDIM(d), ==, 2);\n    long T = Ndarray_DIMS(i)[0];\n    int n_batch = Ndarray_DIMS(i)[1];\n    int n_cells = Ndarray_DIMS(y0)[1];\n    int n_in = Ndarray_DIMS(X)[2];\n    assert_cmp(Ndarray_DIMS(X)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(X)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(W)[0], ==, n_in + n_cells);\n    assert_cmp(Ndarray_DIMS(W)[1], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(b)[0], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(y0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(y0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(c0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(c0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(Y)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(Y)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Y)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(C)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(C)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(C)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(d)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(d)[1], ==, n_cells);\n\n    float* x_h = (float*) device_malloc(n_batch * (n_in + n_cells) * sizeof(float));\n    float* intern = (float*) device_malloc(n_batch * n_cells * 4 * sizeof(float));  // 3 gates + in\n\n    assert_cmp(T, >, 0);\n    assert_cmp(start, >=, 0);\n    assert_cmp(start, <, T);\n    assert_cmp(step, !=, 0);\n    int end = T - 1;\n    if(step < 0) {\n      end = start;\n      start = T - start - 1;\n    }\n    int t = start;\n    for(; (step > 0) ? (t <= end) : (t >= end); t += step) {\n      // x_h = X[t], Y[t-1]\n      start_dev_kernel(copy_x_h_kernel,\n        (n_batch, n_in, n_cells, x_h, data_ptr(X, t), (t != start) ? data_ptr(Y, t-step) : Ndarray_DEV_DATA(y0)));\n      // intern = x_h * W\n      affine_raw(\n        x_h, n_batch, n_in + n_cells,\n        Ndarray_DEV_DATA(W), n_in + n_cells, n_cells * 4,\n        intern, n_batch, n_cells * 4,\n        false, false, 0.0);\n      // intern += b\n      start_dev_kernel(add_bias_kernel, (\n        n_batch, n_cells * 4, intern, Ndarray_DEV_DATA(b)));\n\n      start_dev_kernel(lstm_kernel, (\n        n_batch,\n        n_cells,\n        Ndarray_DEV_DATA(i) + t * n_batch,\n        intern,\n        (t != start) ? data_ptr(C, t-step) : Ndarray_DEV_DATA(c0),\n        data_ptr(Y, t),  // out\n        data_ptr(C, t)  // out\n      ));\n    }\n    HANDLE_LAST_ERROR();\n\n    device_free(x_h);\n    device_free(intern);\n\n    Ndarray_memcpy(Ndarray_DEV_DATA(d), data_ptr(C, t - step), n_batch * n_cells * sizeof(float));\n  '[source]#
c_bw_code: str = '\n    // X, W, b, y0, c0, i, start, step,   Y, C,   DY, Dd = input_names\n    // DX, DW, Db, Dh, Dc = output_names\n    assert(n_inputs == 12);\n    assert(n_outputs == 5);\n    Ndarray* X = inputs[0];\n    Ndarray* W = inputs[1];\n    Ndarray* b = inputs[2];\n    Ndarray* y0 = inputs[3];\n    Ndarray* c0 = inputs[4];\n    Ndarray* i = inputs[5];\n    assert_cmp(Ndarray_NDIM(inputs[6]), ==, 0);\n    assert_cmp(Ndarray_NDIM(inputs[7]), ==, 0);\n    int start = Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n    int step = Ndarray_DEV_DATA_int32_scalar(inputs[7]);\n    Ndarray* Y = inputs[8];\n    Ndarray* C = inputs[9];\n    Ndarray* DY = inputs[10];\n    Ndarray* Dd = inputs[11];\n    Ndarray* DX = *outputs[0];\n    Ndarray* DW = *outputs[1];\n    Ndarray* Db = *outputs[2];\n    Ndarray* Dh = *outputs[3];\n    Ndarray* Dc = *outputs[4];\n\n    assert_cmp(Ndarray_NDIM(X), ==, 3);\n    assert_cmp(Ndarray_NDIM(W), ==, 2);\n    assert_cmp(Ndarray_NDIM(b), ==, 1);\n    assert_cmp(Ndarray_NDIM(y0), ==, 2);\n    assert_cmp(Ndarray_NDIM(c0), ==, 2);\n    assert_cmp(Ndarray_NDIM(i), ==, 2);\n    assert_cmp(Ndarray_NDIM(Y), ==, 3);\n    assert_cmp(Ndarray_NDIM(C), ==, 3);\n    assert_cmp(Ndarray_NDIM(DY), ==, 3);\n    assert_cmp(Ndarray_NDIM(Dd), ==, 2);\n    assert_cmp(Ndarray_NDIM(DX), ==, 3);\n    assert_cmp(Ndarray_NDIM(DW), ==, 2);\n    assert_cmp(Ndarray_NDIM(Db), ==, 1);\n    assert_cmp(Ndarray_NDIM(Dh), ==, 2);\n    assert_cmp(Ndarray_NDIM(Dc), ==, 2);\n    long T = Ndarray_DIMS(i)[0];\n    int n_batch = Ndarray_DIMS(i)[1];\n    int n_cells = Ndarray_DIMS(y0)[1];\n    int n_in = Ndarray_DIMS(X)[2];\n    assert_cmp(Ndarray_DIMS(X)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(X)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(W)[0], ==, n_in + n_cells);\n    assert_cmp(Ndarray_DIMS(W)[1], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(b)[0], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(y0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(y0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(c0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(c0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(Y)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(Y)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Y)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(C)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(C)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(C)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(DY)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(DY)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(DY)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(Dd)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Dd)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(DX)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(DX)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(DX)[2], ==, n_in);\n    assert_cmp(Ndarray_DIMS(DW)[0], ==, n_in + n_cells);\n    assert_cmp(Ndarray_DIMS(DW)[1], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(Db)[0], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(Dh)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Dh)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(Dc)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Dc)[1], ==, n_cells);\n\n    float* x_h = (float*) device_malloc(n_batch * (n_in + n_cells) * sizeof(float));\n    float* intern = (float*) device_malloc(n_batch * n_cells * 4 * sizeof(float));  // 3 gates + in\n    float* Dx_h = (float*) device_malloc(n_batch * (n_in + n_cells) * sizeof(float));\n    float* Dintern = (float*) device_malloc(n_batch * n_cells * 4 * sizeof(float));  // 3 gates + in\n\n    // We will work 
inplace on DX/DW/Db.\n    Ndarray_memset(Ndarray_DEV_DATA(DX), 0, T * n_batch * n_in * sizeof(float));\n    Ndarray_memset(Ndarray_DEV_DATA(DW), 0, (n_in + n_cells) * n_cells * 4 * sizeof(float));\n    Ndarray_memset(Ndarray_DEV_DATA(Db), 0, n_cells * 4 * sizeof(float));\n    // We will work inplace on Dh.\n    Ndarray_memset(Ndarray_DEV_DATA(Dh), 0, n_batch * n_cells * sizeof(float));\n    // We will work inplace on Dc, and init it with Dd.\n    Ndarray_memcpy(Ndarray_DEV_DATA(Dc), Ndarray_DEV_DATA(Dd), n_batch * n_cells * sizeof(float));\n\n    assert_cmp(T, >, 0);\n    assert_cmp(start, >=, 0);\n    assert_cmp(start, <, T);\n    assert_cmp(step, !=, 0);\n    int end = T - 1;\n    if(step < 0) {\n      end = start;\n      start = T - start - 1;\n    }\n    int t = end;  // go backwards\n    for(; (step > 0) ? (t >= start) : (t <= start); t -= step) {\n      bool right = (step > 0) ? (t - step >= start) : (t - step <= start);\n\n      // TODO: correct handling of mask in grad, fwd, initial cell,hidden, etc\n      // x_h = X[t], Y[t-1]\n      start_dev_kernel(copy_x_h_kernel,\n        (n_batch, n_in, n_cells,\n         x_h, data_ptr(X, t), right ? data_ptr(Y, t-step) : Ndarray_DEV_DATA(y0)));\n\n      // intern = x_h * W\n      affine_raw(\n        x_h, n_batch, n_in + n_cells,\n        Ndarray_DEV_DATA(W), n_in + n_cells, n_cells * 4,\n        intern, n_batch, n_cells * 4,\n        false, false, 0.0);\n      // intern += b\n      start_dev_kernel(add_bias_kernel, (\n        n_batch, n_cells * 4, intern, Ndarray_DEV_DATA(b)));\n\n      start_dev_kernel(lstm_bwd_kernel, (\n        n_batch,\n        n_in,\n        n_cells,\n        Ndarray_DEV_DATA(i) + t * n_batch,\n        x_h,\n        intern,\n        right ? data_ptr(C, t-step) : Ndarray_DEV_DATA(c0),\n        data_ptr(Y, t),\n        data_ptr(C, t),\n        data_ptr(DY, t),\n        Ndarray_DEV_DATA(Dh),  // error from prev frame, excluding DY. updated below\n        Ndarray_DEV_DATA(Dc),  // in+out, working inplace. also error from prev frame, initially Dd\n        Dintern,  // out\n        Ndarray_DEV_DATA(Db)  // out\n      ));\n\n      // Dx_h = Dintern * W^T\n      affine_raw(\n        Dintern, n_batch, n_cells * 4,\n        Ndarray_DEV_DATA(W), n_in + n_cells, n_cells * 4,\n        Dx_h, n_batch, n_in + n_cells,\n        false, true, 0.0);\n\n      // DW += x_h^T * Dintern\n      affine_raw(\n        x_h, n_batch, n_in + n_cells,\n        Dintern, n_batch, n_cells * 4,\n        Ndarray_DEV_DATA(DW), n_in + n_cells, n_cells * 4,\n        true, false);\n\n      // DX[t], Dh = Dx_h\n      start_dev_kernel(inv_copy_x_h_kernel,\n        (n_batch, n_in, n_cells, Dx_h, data_ptr(DX, t), Ndarray_DEV_DATA(Dh)));\n    }\n    HANDLE_LAST_ERROR();\n\n    device_free(x_h);\n    device_free(intern);\n    device_free(Dx_h);\n    device_free(Dintern);\n  '[source]#
class returnn.native_op.NativeLstm2[source]#

Yet another LSTM kernel. This kernel is about 27% faster than NativeLstm, and also has some more options (like the direction). But it requires time * batch * cells more memory, thus time * batch * cells * 6 in total.
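
A quick accounting for the memory figure above, per time frame and batch entry (float32 values):

    Y: cells    C: cells    H: 4 * cells (cell-in + 3 gates)   ->  6 * cells
    total over the sequence: time * batch * cells * 6

LstmLowMem avoids the H term by recomputing the gate activations in its backward pass; this kernel stores them in H (gradient ignored) so the backward pass can read them directly.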

inputs:
param X:

(time,batch,dim*4)

param W:

recurrent matrix. 2d (dim,dim*4)

param y0:

initial output|hidden state. 2d (batch,dim)

param c0:

initial cell state. 2d (batch,dim)

param i:

index. 2d (time,batch) -> 0 or 1

param start:

where to start. must be >=0, default is usually 0. dtype int, scalar.

param step:

+1 for fwd, -1 for bwd direction. can also be |step|>1 for wider steps. dtype int, scalar. for bwd (<0), will start at T-start-1.

outputs:
param Y:

output. 3d (time,batch,dim)

param C:

cell states. 3d (time,batch,dim). gradient ignored!

param H:

cell-in + gates. 3d (time,batch,dim*4). gradient ignored!

param d:

final cell state. 2d (batch,dim)

in_info: Tuple[Dict[str]] = ({'name': 'X', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'name': 'W', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'c0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'i', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'start', 'ndim': 0, 'shape': ()}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'step', 'ndim': 0, 'shape': ()})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'Y', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (1, 0))}, {'name': 'C', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (1, 0))}, {'name': 'H', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (1, 1))}, {'name': 'd', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 1), (1, 0))})[source]#
classmethod grad_input_map(X, W, y0, c0, i, start, step, Y, C, H, d, DY, DC, DH, Dd)[source]#
c_extra_support_code: Dict[str, str] = {'lstm_bwd_kernel': '\n      DEF_KERNEL\n      void lstm_bwd_kernel(\n        int n_batch, int n_cells, const float* mask,\n        float* h,\n        float* prev_c,\n        float* y,\n        float* c,\n        float* d_y,\n        float* d_h,\n        float* d_c,\n        float* d_x,\n        float* d_x0)\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_cells * n_batch) {\n          int batch_idx = idx / n_cells;\n          int cell_idx = idx % n_cells;\n          int intern_offset = batch_idx * 4 * n_cells + cell_idx;\n          float mask_b = mask[batch_idx];\n          float d_y_b = (d_y[idx] + d_h[idx]) * mask_b;\n          float d_c_b = d_c[idx] * mask_b;\n          float prev_c_b = prev_c[idx];\n\n          // cell-in + input, forget and output gates\n          float cellIn = h[intern_offset];\n          float inpGate = h[intern_offset + n_cells];\n          float fgtGate = h[intern_offset + 2 * n_cells];\n          float outGate = h[intern_offset + 3 * n_cells];\n\n          float c_b = prev_c_b * fgtGate + cellIn * inpGate;\n          float gc = tanhf(c_b);\n          float d_outGate_in = (1.f - outGate) * outGate * gc * d_y_b;\n          float d_c2 = d_c_b + outGate * d_y_b * (1.f - gc * gc);\n          float d_cellIn_in = (1.f - cellIn * cellIn) * inpGate * d_c2;\n          float d_inpGate_in = (1.f - inpGate) * inpGate * cellIn * d_c2;\n          float d_fgtGate_in = (1.f - fgtGate) * fgtGate * prev_c_b * d_c2;\n          d_c[idx] = fgtGate * d_c2 + d_c[idx] * (1.f - mask_b);\n\n          d_x[intern_offset] = d_cellIn_in;\n          d_x[intern_offset + n_cells] = d_inpGate_in;\n          d_x[intern_offset + 2 * n_cells] = d_fgtGate_in;\n          d_x[intern_offset + 3 * n_cells] = d_outGate_in;\n\n          #define set_x0(off) { d_x0[off] = d_x[off] + d_x0[off] * (1.f - mask_b); }\n          set_x0(intern_offset);\n          set_x0(intern_offset + n_cells);\n          set_x0(intern_offset + 2 * n_cells);\n          set_x0(intern_offset + 3 * n_cells);\n          #undef set_x0\n\n          // Reset if used frame, otherwise leave as-is.\n          d_h[idx] *= (1.f - mask_b);\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n      ', 'lstm_kernel': '\n      DEF_KERNEL\n      void lstm_kernel(\n        int n_batch, int n_cells, const float* mask,\n        float* h,\n        float* prev_y,\n        float* prev_c,\n        float* y,\n        float* c,\n        float* y_prev_out)\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_cells * n_batch) {\n          int batch_idx = idx / n_cells;\n          int cell_idx = idx % n_cells;\n          int intern_offset = batch_idx * 4 * n_cells + cell_idx;\n          float prev_c_b = prev_c[idx];\n          float mask_b = mask[batch_idx];\n\n          // cell-in + input, forget and output gates\n          float cellIn = tanhf(h[intern_offset]);\n          float inpGate = 1.f / (1.f + expf(-h[intern_offset + n_cells]));\n          float fgtGate = 1.f / (1.f + expf(-h[intern_offset + 2 * n_cells]));\n          float outGate = 1.f / (1.f + expf(-h[intern_offset + 3 * n_cells]));\n\n          h[intern_offset] = cellIn;\n          h[intern_offset + n_cells] = inpGate;\n          h[intern_offset + 2 * n_cells] = fgtGate;\n          h[intern_offset + 3 * n_cells] = outGate;\n\n          float c_b = (prev_c_b * fgtGate + cellIn * inpGate) * mask_b\n                    + prev_c_b * (1.f - mask_b);\n          c[idx] = 
c_b;\n          float y_b = tanhf(c_b) * outGate * mask_b;\n          y[idx] = y_b;\n          y_prev_out[idx] = y_b + prev_y[idx] * (1.f - mask_b);\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n      '}[source]#
c_fw_code: str = '\n    // X, W, y0, c0, i, start, step = input_names\n    // Y, C, H, d = output_names\n    assert(n_inputs == 7);\n    assert(n_outputs == 4);\n    Ndarray* X = inputs[0];\n    Ndarray* W = inputs[1];\n    Ndarray* y0 = inputs[2];\n    Ndarray* c0 = inputs[3];\n    Ndarray* i = inputs[4];\n    assert_cmp(Ndarray_NDIM(inputs[5]), ==, 0);\n    assert_cmp(Ndarray_NDIM(inputs[6]), ==, 0);\n    int start = Ndarray_DEV_DATA_int32_scalar(inputs[5]);\n    int step = Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n    Ndarray* Y = *outputs[0];\n    Ndarray* C = *outputs[1];\n    Ndarray* H = *outputs[2];\n    Ndarray* d = *outputs[3];\n\n    assert_cmp(Ndarray_NDIM(X), ==, 3);\n    assert_cmp(Ndarray_NDIM(W), ==, 2);\n    assert_cmp(Ndarray_NDIM(y0), ==, 2);\n    assert_cmp(Ndarray_NDIM(c0), ==, 2);\n    assert_cmp(Ndarray_NDIM(i), ==, 2);\n    assert_cmp(Ndarray_NDIM(Y), ==, 3);\n    assert_cmp(Ndarray_NDIM(C), ==, 3);\n    assert_cmp(Ndarray_NDIM(H), ==, 3);\n    assert_cmp(Ndarray_NDIM(d), ==, 2);\n    long T = Ndarray_DIMS(i)[0];\n    int n_batch = Ndarray_DIMS(i)[1];\n    int n_cells = Ndarray_DIMS(y0)[1];\n    assert_cmp(Ndarray_DIMS(X)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(X)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(X)[2], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(W)[0], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(W)[1], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(y0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(y0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(c0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(c0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(Y)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(Y)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Y)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(C)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(C)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(C)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(H)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(H)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(H)[2], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(d)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(d)[1], ==, n_cells);\n\n    if(T == 0) {\n      Ndarray_memcpy(Ndarray_DEV_DATA(d), Ndarray_DEV_DATA(c0), n_batch * n_cells * sizeof(float));\n\n    } else {  // T > 0\n      // It makes the backprop with step<0 easier to implement,\n      // esp. the DW = Y[0..T-2]^T * DX[1..T-1] calculation,\n      // if we can have Y[t] = 0 where mask[t] = 0.\n      // That is why we need to keep track of Y[t-1] explicitly.\n      float* y_prev = (float*) device_malloc(n_batch * n_cells * sizeof(float));\n\n      // H = X\n      Ndarray_memcpy(Ndarray_DEV_DATA(H), Ndarray_DEV_DATA(X), T * n_batch * n_cells * 4 * sizeof(float));\n\n      assert_cmp(T, >, 0);\n      assert_cmp(start, >=, 0);\n      assert_cmp(start, <, T);\n      assert_cmp(step, !=, 0);\n      int end = T - 1;\n      if(step < 0) {\n        end = 0;\n        start = T - start - 1;\n      }\n      int t = start;\n      for(; (step > 0) ? (t <= end) : (t >= end); t += step) {\n        // H[t] += Y[t-1] * W\n        affine_raw(\n          (t != start) ? y_prev : Ndarray_DEV_DATA(y0), n_batch, n_cells,\n          Ndarray_DEV_DATA(W), n_cells, n_cells * 4,\n          data_ptr(H, t), n_batch, n_cells * 4,\n          false, false);\n\n        start_dev_kernel(lstm_kernel, (\n          n_batch,\n          n_cells,\n          Ndarray_DEV_DATA(i) + t * n_batch,\n          data_ptr(H, t),  // inplace\n          (t != start) ? y_prev : Ndarray_DEV_DATA(y0),\n          (t != start) ? 
data_ptr(C, t-step) : Ndarray_DEV_DATA(c0),\n          data_ptr(Y, t),  // out\n          data_ptr(C, t),  // out\n          y_prev  // out\n        ));\n      }\n      HANDLE_LAST_ERROR();\n\n      Ndarray_memcpy(Ndarray_DEV_DATA(d), data_ptr(C, t - step), n_batch * n_cells * sizeof(float));\n\n      device_free(y_prev);\n    }\n  '[source]#
c_bw_code: str = '\n    // X, W, y0, c0, i, start, step,   Y, C, H,   DY, Dd = input_names\n    // DX, DW, Dy0, Dc0 = output_names\n    assert(n_inputs == 12);\n    assert(n_outputs == 4);\n    Ndarray* X = inputs[0];\n    Ndarray* W = inputs[1];\n    Ndarray* y0 = inputs[2];\n    Ndarray* c0 = inputs[3];\n    Ndarray* i = inputs[4];\n    assert_cmp(Ndarray_NDIM(inputs[5]), ==, 0);\n    assert_cmp(Ndarray_NDIM(inputs[6]), ==, 0);\n    int start = Ndarray_DEV_DATA_int32_scalar(inputs[5]);\n    int step = Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n    Ndarray* Y = inputs[7];\n    Ndarray* C = inputs[8];\n    Ndarray* H = inputs[9];\n    Ndarray* DY = inputs[10];\n    Ndarray* Dd = inputs[11];\n    Ndarray* DX = *outputs[0];\n    Ndarray* DW = *outputs[1];\n    Ndarray* Dy0 = *outputs[2];\n    Ndarray* Dc0 = *outputs[3];\n\n    assert_cmp(Ndarray_NDIM(X), ==, 3);\n    assert_cmp(Ndarray_NDIM(W), ==, 2);\n    assert_cmp(Ndarray_NDIM(y0), ==, 2);\n    assert_cmp(Ndarray_NDIM(c0), ==, 2);\n    assert_cmp(Ndarray_NDIM(i), ==, 2);\n    assert_cmp(Ndarray_NDIM(Y), ==, 3);\n    assert_cmp(Ndarray_NDIM(C), ==, 3);\n    assert_cmp(Ndarray_NDIM(H), ==, 3);\n    assert_cmp(Ndarray_NDIM(DY), ==, 3);\n    assert_cmp(Ndarray_NDIM(Dd), ==, 2);\n    assert_cmp(Ndarray_NDIM(DX), ==, 3);\n    assert_cmp(Ndarray_NDIM(DW), ==, 2);\n    assert_cmp(Ndarray_NDIM(Dy0), ==, 2);\n    assert_cmp(Ndarray_NDIM(Dc0), ==, 2);\n    long T = Ndarray_DIMS(i)[0];\n    int n_batch = Ndarray_DIMS(i)[1];\n    int n_cells = Ndarray_DIMS(y0)[1];\n    assert_cmp(Ndarray_DIMS(X)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(X)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(X)[2], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(W)[0], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(W)[1], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(y0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(y0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(c0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(c0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(Y)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(Y)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Y)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(C)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(C)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(C)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(H)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(H)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(H)[2], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(DY)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(DY)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(DY)[2], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(Dd)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Dd)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(DX)[0], ==, T);\n    assert_cmp(Ndarray_DIMS(DX)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(DX)[2], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(DW)[0], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(DW)[1], ==, n_cells * 4);\n    assert_cmp(Ndarray_DIMS(Dy0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Dy0)[1], ==, n_cells);\n    assert_cmp(Ndarray_DIMS(Dc0)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(Dc0)[1], ==, n_cells);\n\n    // We will work inplace on DW.\n    Ndarray_memset(Ndarray_DEV_DATA(DW), 0, n_cells * n_cells * 4 * sizeof(float));\n    // We will work inplace on (Dy0) DY[t], initially 0.\n    Ndarray_memset(Ndarray_DEV_DATA(Dy0), 0, n_batch * n_cells * sizeof(float));\n    // We will work inplace on (Dc0) DC[t], and init it with Dd.\n    Ndarray_memcpy(Ndarray_DEV_DATA(Dc0), Ndarray_DEV_DATA(Dd), n_batch * n_cells * 
sizeof(float));\n\n    if(T == 0) {\n      // just do nothing. at least do not crash\n\n    } else {\n      // Need to keep track of (logical) DX[0], which in practice (masking, step<0)\n      // can be different from data_ptr(DX, start).\n      float* dx0 = (float*) device_malloc(n_batch * n_cells * 4 * sizeof(float));\n      Ndarray_memset(dx0, 0, n_batch * n_cells * 4 * sizeof(float));\n\n      assert_cmp(T, >, 0);\n      assert_cmp(start, >=, 0);\n      assert_cmp(start, <, T);\n      assert_cmp(step, !=, 0);\n      int abs_step = std::abs(step);\n\n      if(abs_step > 1 || start > 0)\n        // Normally the kernel would visit and reset all DX.\n        // But with abs_step>1 or start>0, we will not visit all. Reset now.\n        Ndarray_memset(Ndarray_DEV_DATA(DX), 0, T * n_batch * n_cells * 4 * sizeof(float));\n\n      // e.g.:\n      // step=1, start=0, T=10 -> num_steps=10=T\n      // step=5, start=0, T=10 -> num_steps=2=T/step\n      // step=5, start=0, T=9  -> num_steps=2=(T+step-1)/step\n      // step=5, start=0, T=6  -> num_steps=2=(T+step-1)/step\n      // step=5, start=0, T=5  -> num_steps=1=(T+step-1)/step\n      // step=5, start=4, T=10 -> num_steps=2=(T-start+step-1)/step\n      // step=-5, start=0, T=10 -> num_steps=2=T/abs_step\n      // step=-5, start=0, T=9  -> num_steps=2=(T+abs_step-1)/abs_step\n      // step=-5, start=4, T=10 -> num_steps=2=(T-start+abs_step-1)/abs_step\n      int num_steps = (T - start + abs_step - 1) / abs_step;\n      assert_cmp(num_steps, >, 0);\n      if(step < 0)\n        start = T - start - 1;\n      int end = start + (num_steps - 1) * step;  // inclusive\n      assert_cmp(end, >=, 0);\n      assert_cmp(end, <, T);\n      int t = end;  // go backwards\n      for(; (step > 0) ? (t >= start) : (t <= start); t -= step) {\n        bool right = (step > 0) ? (t - step >= start) : (t - step <= start);\n\n        start_dev_kernel(lstm_bwd_kernel, (\n          n_batch,\n          n_cells,\n          Ndarray_DEV_DATA(i) + t * n_batch,\n          data_ptr(H, t),\n          right ? data_ptr(C, t-step) : Ndarray_DEV_DATA(c0),\n          data_ptr(Y, t),\n          data_ptr(C, t),\n          data_ptr(DY, t),\n          Ndarray_DEV_DATA(Dy0),  // in+out, error from prev frame, excluding DY. reset here, updated below\n          Ndarray_DEV_DATA(Dc0),  // in+out, working inplace. also error from prev frame, initially Dd\n          data_ptr(DX, t),  // out\n          dx0  // out\n        ));\n\n        // (Dy0) DY[t-1] += DX[t] * W^T\n        affine_raw(\n          data_ptr(DX, t), n_batch, n_cells * 4,\n          Ndarray_DEV_DATA(W), n_cells, n_cells * 4,\n          Ndarray_DEV_DATA(Dy0), n_batch, n_cells,\n          false, true);\n      }\n\n      //DW = Y[0..T-2]^T * DX[1..T-1]  (if step==1)\n      if(num_steps > 1) {\n        if(abs_step == 1) {\n          affine_raw(\n            data_ptr(Y, std::min(start, end) + std::max(0, -step)), (num_steps - 1) * n_batch, n_cells,\n            data_ptr(DX, std::min(start, end) + std::max(0, step)), (num_steps - 1) * n_batch, n_cells * 4,\n            Ndarray_DEV_DATA(DW), n_cells, n_cells * 4,\n            true, false, 0.0f, 1.0f);\n        } else {\n          // Unfortunately we cannot do efficient striding. Thus loop again.\n          t = end - step;  // one before\n          for(; (step > 0) ? 
(t >= start) : (t <= start); t -= step) {\n            affine_raw(\n              data_ptr(Y, t), n_batch, n_cells,\n              data_ptr(DX, t + step), n_batch, n_cells * 4,\n              Ndarray_DEV_DATA(DW), n_cells, n_cells * 4,\n              true, false);\n          }\n        }\n      }\n      HANDLE_LAST_ERROR();\n\n      //DW += y0^T * DX[0]\n      affine_raw(\n        Ndarray_DEV_DATA(y0), n_batch, n_cells,\n        dx0, n_batch, n_cells * 4,\n        Ndarray_DEV_DATA(DW), n_cells, n_cells * 4,\n        true, false);\n\n      device_free(dx0);\n    }\n  '[source]#
class returnn.native_op.TwoDLSTM[source]#
inputs:
param X:

{input,output,forget,lambda} gate + cell state. 4d (timeT,timeS,batch,dim*5) // dim*5 or dim*1 ?

param V_h:

recurrent matrix. 2d (dim,dim*5)

param V_v:

recurrent matrix. 2d (dim,dim*5)

param W:

recurrent matrix. 2d (dim,dim*5)

param b:

bias. 2d (batch,dim)

param ptr_storage:

ptr_storage. 1d (1 * 5 * max_diag_size * sizeof(float*) / sizeof(float))

param valid:

used internally to store which cells are valid (have to be computed). 1d (1 * max_diag_size * n_minibatch)

param workmem2:

used internally. 3d (H[0], H[2], H[3])

param sizes:

height (target) x width (source) of the unpadded sentences. 2d (batch, 2)

outputs:
param CompleteY:

output. 4d (timeS,timeT,batch,dim)

param H:

gates and cell state. 4d (timeS,timeT,batch,dim*5) ?

param d:

final cell state. 3d (timeT,batch,dim)
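
A NumPy sketch of the 2D cell update performed by lstm_stable_cell_kernel_batched in c_extra_support_code below: the lambda gate blends the cell state of the predecessor in the target (y) direction with the one in the source (x) direction. Illustrative only; layout as in the kernel (input, forget, lambda, output gate, then cell-in):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def twod_lstm_cell(data, c_prev_y, c_prev_x, valid):
        """data: (batch, n_cells*5) pre-activations; c_prev_y, c_prev_x: (batch, n_cells) cell states
        of the y- and x-direction predecessors (zeros if absent); valid: (batch,) 1.0 inside the
        unpadded image, else 0.0."""
        n = data.shape[1] // 5
        inp = sigmoid(data[:, 0 * n:1 * n])
        fgt = sigmoid(data[:, 1 * n:2 * n])
        lam = sigmoid(data[:, 2 * n:3 * n])
        out = sigmoid(data[:, 3 * n:4 * n])
        v = valid[:, None]
        c = inp * np.tanh(data[:, 4 * n:5 * n])
        c += fgt * lam * c_prev_y + fgt * (1.0 - lam) * c_prev_x
        c *= v
        y = out * np.tanh(c) * v
        return y, c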

in_info: Tuple[Dict[str]] = ({'bw_out_var': {'shape': ((0, 0), (0, 1), (0, 2), (0, 3))}, 'name': 'X', 'ndim': 4, 'need_contiguous': True, 'shape': (None, None, None, None)}, {'name': 'V_h', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'V_v', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'W', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'b', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'ptr_storage_fwd', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'ptr_storage_bwd', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'valid', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'workmem', 'ndim': 5, 'need_contiguous': True, 'shape': (None, None, None, None, None)}, {'gradient': 'disconnected', 'name': 'workmem2', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'sizes', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'DYDummy', 'ndim': 4, 'need_contiguous': True, 'shape': (None, None, None, None)}, {'gradient': 'disconnected', 'name': 'initialState', 'ndim': 4, 'need_contiguous': True, 'shape': (None, None, None, None)}, {'gradient': 'disconnected', 'name': 'initialOutput', 'ndim': 4, 'need_contiguous': True, 'shape': (None, None, None, None)}, {'gradient': 'disconnected', 'name': 'iteration', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'CompleteY', 'ndim': 4, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2), (1, 0))}, {'name': 'H', 'ndim': 4, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2), (3, 1))})[source]#
classmethod grad_input_map(X, V_h, V_v, W, b, ptr_storage_fwd, ptr_storage_bwd, valid, workmem, workmem2, sizes, DYDummy, initialState, initialOutput, iteration, CompleteY, H, DCompleteY, DH)[source]#
classmethod map_layer_inputs_to_op(Zs, Zt, V_h, V_v, W, b, ptr_storage)[source]#
Parameters:

inputs

Returns:

inputs

c_extra_support_code: Dict[str, str] = {'01_repvec': '\n      DEF_KERNEL\n      void repvec(const float * v, int vlen, int nCopies, float * dest)\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < vlen * nCopies)\n        {\n          dest[idx] = v[idx % vlen];\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '02_fillmat': '\n      void fillmat(OpKernelContext* context, const Ndarray * b, Ndarray * dst)\n      {\n        const float * data_b = Ndarray_DEV_DATA(b);\n        float * data_dst = Ndarray_DEV_DATA(dst);\n        Ndarray_DIMS_Type dims_b = Ndarray_HOST_DIMS(b);\n        int dims_dst[2];\n        lastTwoDims(dst, dims_dst);\n        assert(dims_b[0] == dims_dst[1]);\n        start_dev_kernel(repvec, (\n          data_b,\n          dims_dst[1],\n          Ndarray_SIZE(dst)/dims_dst[1],\n          data_dst\n        ));\n      }\n    ', '03_data_ptr': '\n      // if nd is 2 then assume a weight matrix and just return beginning of data\n      // else nd should be 3 and we pick the x part\n      float* data_ptr(const Ndarray* a, int y, int x, int outer_dim=0) {\n          assert(Ndarray_NDIM(a) == 2 || Ndarray_NDIM(a) == 4 || Ndarray_NDIM(a) == 5);\n          if(Ndarray_NDIM(a) == 2)\n              return Ndarray_DEV_DATA(a);\n          else if(Ndarray_NDIM(a) == 4) {\n              Ndarray_DIMS_Type dims = Ndarray_HOST_DIMS(a);\n              return Ndarray_DEV_DATA(a)\n                + y * dims[1] * dims[2] * dims[3]\n                + x * dims[2] * dims[3]; // row-major or minor?\n          }\n          else {\n              Ndarray_DIMS_Type dims = Ndarray_HOST_DIMS(a);\n              return Ndarray_DEV_DATA(a)\n                + outer_dim * dims[1] * dims[2] * dims[3] * dims[4]\n                + y * dims[2] * dims[3] * dims[4]\n                + x * dims[3] * dims[4];\n          }\n      }\n\n      float * data_ptr(Ndarray * a, int y, int x, int outer_dim=0)\n      {\n        const Ndarray * ca = a;\n        return const_cast<float *>(data_ptr(ca, y, x, outer_dim));\n      }\n    ', '04_affine_y_x_batched_onedir': "\n      // ys and xs: base indices, offset by y_A, x_A (-1,0,1)\n      void affine_y_x_batched_onedir(OpKernelContext* context, int y_A, int x_A,\n        const Ndarray * A1,\n        const Ndarray * B1,\n        Ndarray * C1,\n        const std::vector<int>& ys, const std::vector<int>& xs, Ndarray * ptr_storage, int height, int width,\n        cudaStream_t stream = 0, bool transpose_A=false, bool transpose_B=false)\n      {\n        const int batch_size = ys.size();\n        if(batch_size == 0)\n        {\n          return;\n        }\n        std::vector<const float*> ABC_ptrs(3 * 1 * batch_size); //content layout: 3x1xbatch_size (3: A,B,C, 1: dirs)\n\n        for(int i = 0; i < batch_size; ++i)\n        {\n          //A\n          //y not flipped, x not flipped\n          ABC_ptrs[0 * 1 * batch_size + 0 * batch_size + i] = data_ptr(A1, y_A + ys[i], x_A + xs[i]);\n\n          //B\n          //index doesent matter here, as B is only 2dimensional\n          ABC_ptrs[1 * 1 * batch_size + 0 * batch_size + i] = data_ptr(B1, 0, 0);\n\n          //we write the result (C) in the same destination (y,x) as the source (A), so we don't need to flip later\n          //C\n          //y not flipped, x not flipped\n          ABC_ptrs[2 * 1 * batch_size + 0 * batch_size + i] = data_ptr(C1, ys[i], xs[i]);\n        }\n        const float ** ptr_storage_data = reinterpret_cast<const float**>(&(ABC_ptrs[0]));\n        const 
float ** A_ptrs_data = (const float**) ptr_storage_data + 0 * 1 * batch_size;\n        const float ** B_ptrs_data = (const float**) ptr_storage_data + 1 * 1 * batch_size;\n        const float ** C_ptrs_data = ptr_storage_data + 2 * 1 * batch_size;\n\n        int A_dim[2], B_dim[2];\n        lastTwoDims(A1, A_dim);\n        lastTwoDims(B1, B_dim);\n        int ldB = B_dim[1];\n        int ldA = A_dim[1];\n        char transA = transpose_A ? 'T' : 'N';\n        char transB = transpose_B ? 'T' : 'N';\n        if (transpose_A)\n        {\n          std::swap(A_dim[0], A_dim[1]);\n        }\n        if (transpose_B)\n        {\n          std::swap(B_dim[0], B_dim[1]);\n        }\n\n        const float alpha = 1;\n        const float beta = 1;\n\n        Ndarray_sgemm_batched(\n          transB, transA, B_dim[1], A_dim[0], A_dim[1], &alpha,\n          B_ptrs_data, ldB, A_ptrs_data, ldA, &beta,\n          C_ptrs_data, B_dim[1], 1 * batch_size, batch_size == 1);\n      }\n    ", '05_lstm_stable_cell_kernel_batched': '\n      DEF_KERNEL\n      void lstm_stable_cell_kernel_batched(float ** datas, const float ** old_state_ys, const float ** old_state_xs,\n       float ** outputs, const float ** valids, int n_outer_batch, int n_cells, int n_minibatch)\n      {\n        //layout (for every outer batch):\n        //data[0*n_cells..1*n_cells-1] : input gate\n        //data[1*n_cells..2*n_cells-1] : forget gate\n        //data[2*n_cells..3*n_cells-1] : lambda gate\n        //data[3*n_cells..4*n_cells-1] : output gate\n        //data[5*n_cells..6*n_cells-1] : cell state\n        //output[0*n_cells..1*n_cells-1]: cell output\n        //valids: either 1.0 or 0.0, indicating if the current (y,x) position\n        //  is still inside the image in this minibatch\n        //repeated for every mini-batch\n\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_outer_batch * n_cells * n_minibatch)\n        {\n          int size_per_outer_batch = n_cells * n_minibatch;\n          int outer_batch_idx = idx / size_per_outer_batch;\n          float * data = datas[outer_batch_idx];\n          const float * old_state_y = old_state_ys[outer_batch_idx];\n          const float * old_state_x = old_state_xs[outer_batch_idx];\n          float * output = outputs[outer_batch_idx];\n          const float * valid = valids[outer_batch_idx];\n\n          int inner_idx = idx % size_per_outer_batch;\n          int minibatch_idx = inner_idx / n_cells;\n          int batch_offset = minibatch_idx * 5 * n_cells;\n          int cell_offset = inner_idx % n_cells;\n          int start = batch_offset + cell_offset;\n\n          float valid_batch = valid[minibatch_idx];\n\n          //input, forget and output gates\n          float inpGate = 1.f / (1.f + expf(-data[start]));\n          float fgtGate = 1.f / (1.f + expf(-data[start + n_cells]));\n          float lambdaGate = 1.f / (1.f + expf(-data[start + 2 * n_cells]));\n          float outGate = 1.f / (1.f + expf(-data[start + 3 * n_cells]));\n          float state = inpGate * tanhf(data[start + 4 * n_cells]);\n          if (old_state_y)\n          {\n            state += fgtGate * lambdaGate * old_state_y[start];\n          }\n          if (old_state_x)\n          {\n            state += fgtGate * (1.0f - lambdaGate) * old_state_x[start];\n          }\n          state *= valid_batch;\n\n          //cell output\n          output[inner_idx] = outGate * tanhf(state) * valid_batch;\n\n          data[start] = inpGate;\n          data[start + n_cells] = fgtGate;\n      
    data[start + 2 * n_cells] = lambdaGate;\n          data[start + 3 * n_cells] = outGate;\n          data[start + 4 * n_cells] = state;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '06_do_lstm_batched_onedir': '\n      // H, CompleteY, ys, xs, ptr_storage\n      void do_lstm_batched_onedir(\n       OpKernelContext* context, Ndarray* H, Ndarray* initialState, float iteration, Ndarray* completeOut,\n       const std::vector<int>& ys, const std::vector<int>& xs,\n       Ndarray* ptr_storage, Ndarray* valid_storage, Ndarray* sizes)\n      {\n        int n_outer_batch = ys.size();\n        Ndarray_DIMS_Type H_dims = Ndarray_HOST_DIMS(H);\n        int height = H_dims[0];\n        int width = H_dims[1];\n        int n_minibatch = H_dims[2];\n        int n_cells = H_dims[3] / 5;\n        assert(H_dims[3] % 5 == 0); //4 gates + cell\n\n        std::vector<float*> ptrs(1 * 5 * n_outer_batch); //1 dirs * 5 arrays\n        std::vector<float> valid(1 * n_minibatch * n_outer_batch, 1.0f);\n\n        float* h_sizes; // the sizes array is stored on the GPU, we have to copy it to the CPU\n        int dsize =\n          (n_outer_batch) * (n_minibatch) * sizeof(float) * 2; // (*2), because we have 2 (height, width) numbers\n        h_sizes = (float*)malloc(dsize);\n        HANDLE_ERROR(cudaMemcpy(h_sizes, Ndarray_DEV_DATA(sizes), dsize, cudaMemcpyDeviceToHost));\n\n        for(int i = 0; i < n_outer_batch; ++i)\n        {\n          int y = ys[i];\n          int x = xs[i];\n\n          //fill valid\n          for(int n = 0; n < n_minibatch; ++n) // iterates through all examples in the current batch\n          {\n            float img_height = *(h_sizes + 2*n);\n            float img_width = *(h_sizes + 2*n +1);\n\n            valid[i * 1 * n_minibatch + 0 * n_minibatch + n] = float(y < img_height && x < img_width);\n          }\n\n          //y not flipped, x not flipped\n          float * data_H = data_ptr(H, y, x);\n\n          //y not flipped, x not flipped\n          float * data_old_state_y;\n          data_old_state_y = y > 0 ? data_ptr(H, y - 1, x) + 4 * n_cells : data_ptr(initialState, 0, x) + 4 * n_cells;\n\n          //y not flipped, x not flipped\n          float * data_old_state_x = x > 0 ? 
data_ptr(H, y, x - 1) + 4 * n_cells : 0;\n\n          //y not flipped, x not flipped\n          float * data_out = data_ptr(completeOut, y, x);\n\n          float * valid = Ndarray_DEV_DATA(valid_storage) + i * 1 * n_minibatch + 0 * n_minibatch;\n\n          ptrs[0 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_H;\n          ptrs[1 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_old_state_y;\n          ptrs[2 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_old_state_x;\n          ptrs[3 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_out;\n          ptrs[4 * 1 * n_outer_batch + 0 * n_outer_batch + i] = valid;\n        }\n\n        free(h_sizes);\n\n        HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(valid_storage), valid.data(),\n          valid.size() * sizeof(float), cudaMemcpyHostToDevice));\n        HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(ptr_storage), ptrs.data(),\n          ptrs.size() * sizeof(float*), cudaMemcpyHostToDevice));\n        float ** ptr_storage_data = reinterpret_cast<float**>(Ndarray_DEV_DATA(ptr_storage));\n        float ** data_Hs = ptr_storage_data + 0 * 1 * n_outer_batch;\n        const float ** data_old_state_ys = (const float**) ptr_storage_data + 1 * 1 * n_outer_batch;\n        const float ** data_old_state_xs = (const float**) ptr_storage_data + 2 * 1 * n_outer_batch;\n        float ** data_outs = ptr_storage_data + 3 * 1 * n_outer_batch;\n        const float ** data_valids = (const float**) (ptr_storage_data + 4 * 1 * n_outer_batch);\n\n        start_dev_kernel(lstm_stable_cell_kernel_batched, (\n          data_Hs,\n          data_old_state_ys,\n          data_old_state_xs,\n          data_outs,\n          data_valids,\n          1 * n_outer_batch,\n          n_cells,\n          n_minibatch\n        ));\n      }\n    ', '07_lstm_bwd_stable_cell_kernel_batched': '\n      DEF_KERNEL\n      void lstm_bwd_stable_cell_kernel_batched(float ** deltas, const float ** epsilons,\n        const float ** next_epsilon_ys, const float ** next_epsilon_xs, float ** epsilon_ys, float ** epsilon_xs,\n        const float ** last_state_ys, const float ** last_state_xs, const float ** Ys, const float ** valids,\n        int n_outer_batch, int n_cells, int n_minibatch)\n      {\n        //layout (for every outer batch):\n        //delta[0*n_cells..1*n_cells-1] : input gate\n        //delta[1*n_cells..2*n_cells-1] : forget gate\n        //delta[2*n_cells..3*n_cells-1] : lambda gate\n        //delta[3*n_cells..4*n_cells-1] : output gate\n        //delta[4*n_cells..5*n_cells-1] : cell state\n        //epsilon[0*n_cells..1*n_cells-1]: cell output derivative\n        //next_epsilon_y[0*n_cells..1*n_cells-1]: cell state derivative * forget_gate * lambda_gate (of next timestep)\n        //next_epsilon_x[0*n_cells..1*n_cells-1]:\n        //  cell state derivative * forget_gate * (-1*)lambda_gate (of next timestep)\n        //epsilon_y[0*n_cells..1*n_cells-1]:\n        //  cell state derivative * forget_gate * lambda_gate (of current timestep, as output)\n        //epsilon_x[0*n_cells..1*n_cells-1]:\n        //  cell state derivative * forget_gate * (1-lambda_gate) (of current timestep, as output)\n        //valids: either 1.0 or 0.0, indicating if the current (y,x) position\n        //  is still inside the image in this minibatch\n        //repeated for every mini-batch\n\n        float near_zero = 0.00000000001f;\n\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while (idx < n_outer_batch * n_cells * n_minibatch)\n        {\n          int 
size_per_outer_batch = n_cells * n_minibatch;\n          int outer_batch_idx = idx / size_per_outer_batch;\n          const float * valid = valids[outer_batch_idx];\n\n          float * delta = deltas[outer_batch_idx];\n          const float * epsilon = epsilons[outer_batch_idx];\n          const float * next_epsilon_y = next_epsilon_ys[outer_batch_idx];\n          const float * next_epsilon_x = next_epsilon_xs[outer_batch_idx];\n          float * epsilon_y = epsilon_ys[outer_batch_idx];\n          float * epsilon_x = epsilon_xs[outer_batch_idx];\n          const float * last_state_y = last_state_ys[outer_batch_idx];\n          const float * last_state_x = last_state_xs[outer_batch_idx];\n          const float * Y = Ys[outer_batch_idx];\n\n          int inner_idx = idx % size_per_outer_batch;\n          int minibatch_idx = inner_idx / n_cells;\n          int batch_offset = minibatch_idx * 5 * n_cells;\n          int cell_offset = inner_idx % n_cells;\n          int start = batch_offset + cell_offset;\n          float valid_batch = valid[minibatch_idx];\n\n          float inpGate = delta[start];\n          float fgtGate = delta[start + n_cells];\n          float lambdaGate = delta[start + 2 * n_cells];\n          float outGate = delta[start + 3 * n_cells];\n          float state = delta[start + 4 * n_cells];\n          float lastState_y = last_state_y ? last_state_y[start] : 0.f;\n          float lastState_x = last_state_x ? last_state_x[start] : 0.f;\n          float eps = epsilon[inner_idx];\n\n          //avoid division by 0\n          float gc = 0.f; //g(c(t))\n          float gzc = 0.f; //g(z_c(t))\n          if (outGate < -near_zero || outGate > near_zero)\n          {\n            gc = Y[inner_idx] / outGate;\n          }\n\n          if (inpGate < -near_zero || inpGate > near_zero)\n          {\n            gzc = (state - fgtGate * lambdaGate * lastState_y - fgtGate * (1.0f - lambdaGate) * lastState_x) / inpGate;\n          }\n\n          //delta_output\n          delta[start + 3 * n_cells] = outGate * (1.f - outGate) * gc * eps * valid_batch;\n\n          //epsilon_c\n          float epsilon_c = (1.f - (gc * gc)) * outGate * eps;\n          if (next_epsilon_y)\n          {\n            epsilon_c += next_epsilon_y[inner_idx];\n          }\n          if (next_epsilon_x)\n          {\n            epsilon_c += next_epsilon_x[inner_idx];\n          }\n\n          //TODO: clip epsilon_c?\n          //epsilon_c = max(epsilon_c, -10.f);\n          //epsilon_c = min(epsilon_c, 10.f);\n\n          epsilon_y[inner_idx] = epsilon_c * fgtGate * lambdaGate * valid_batch;\n          epsilon_x[inner_idx] = epsilon_c * fgtGate * (1.0f - lambdaGate) * valid_batch;\n\n          //delta_cell\n          delta[start + 4 * n_cells] = inpGate * (1.f - (gzc * gzc)) * epsilon_c * valid_batch;\n\n          //delta_forget\n          delta[start + n_cells] = fgtGate * (1.f - fgtGate) * epsilon_c *\n                                   (lastState_y * lambdaGate + lastState_x * (1.0f - lambdaGate)) * valid_batch;\n\n          //delta_lambda\n          delta[start + 2 * n_cells] = fgtGate * lambdaGate * (1.f - lambdaGate) * epsilon_c\n                                       * (lastState_y - lastState_x) * valid_batch;\n\n          //delta_input\n          delta[start] = inpGate * (1.f - inpGate) * gzc * epsilon_c * valid_batch;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '08_do_lstm_bwd_batched_onedir': '\n      //epsilon are the derivates w.r.t. 
Z, delta stores the gate and cell activations\n      //  and will store the derivatives later\n      void do_lstm_bwd_batched_onedir(OpKernelContext* context, Ndarray * delta1, Ndarray * epsilon1,\n       const Ndarray* CompleteY, Ndarray * workmem1,\n       int height, int width, const std::vector<int>& ys, const std::vector<int>& xs,\n       Ndarray * ptr_storage, Ndarray * valid_storage, Ndarray*  sizes, int iteration, cudaStream_t stream=0)\n      {\n        int n_outer_batch = ys.size();\n        int dims[2];\n        lastTwoDims(delta1, dims);\n        assert(dims[1] % 5 == 0); //4 gates + cell\n        int n_cells = dims[1] / 5;\n        int n_minibatch = dims[0];\n\n        std::vector<const float*> ptrs(1 * 10 * n_outer_batch); //1 dirs * 10 arrays\n        std::vector<float> valid(1 * n_minibatch * n_outer_batch, 1.0f);\n\n        float* h_sizes; // the sizes array is stored on the GPU, we have to copy it to the CPU\n        int dsize =\n          (n_outer_batch) * (n_minibatch) * sizeof(float) * 2; // (*2), because we have 2 (height, width) numbers\n        h_sizes = (float*)malloc(dsize);\n        HANDLE_ERROR(cudaMemcpy(h_sizes, Ndarray_DEV_DATA(sizes), dsize, cudaMemcpyDeviceToHost));\n\n        for(int i = 0; i < n_outer_batch; ++i)\n        {\n          int y = ys[i];\n          int x = xs[i];\n          //fill valid\n          for(int n = 0; n < n_minibatch; ++n)\n          {\n            //these are the sizes of a single image in the batch, while height/width are the maximum sizes in the batch\n            float img_height = *(h_sizes + 2*n);\n            float img_width = *(h_sizes + 2*n +1);\n            valid[i * 1 * n_minibatch + 0 * n_minibatch + n] = float(y < img_height && x < img_width);\n          }\n\n          bool botBorder = (y == height-1);\n          bool rightBorder = (x == width-1);\n          int yp1 = y + 1;\n          int xp1 = x + 1;\n          int ym1 = y - 1;\n          int xm1 = x - 1;\n\n          float * data_delta1 = data_ptr(delta1, y, x);\n          const float * data_epsilon1 = data_ptr(epsilon1, y, x);\n          const float * data_next_epsilon_y1 = botBorder ? 0 : data_ptr(workmem1, (iteration-1)%2, x, 0);\n          const float * data_next_epsilon_x1 = rightBorder ? 0 : data_ptr(workmem1, (iteration-1)%2, xp1, 1);\n          float * data_epsilon_y1 = data_ptr(workmem1, iteration%2, x, 0);\n          float * data_epsilon_x1 = data_ptr(workmem1, iteration%2, x, 1);\n          const float * data_last_state_y1 = y > 0 ? data_ptr(delta1, ym1, x) + 4 * n_cells : 0;\n          const float * data_last_state_x1 = x > 0 ? 
data_ptr(delta1, y, xm1) + 4 * n_cells : 0;\n          const float * data_Y1 = data_ptr(CompleteY, y, x);\n          float * valid1 = Ndarray_DEV_DATA(valid_storage) + i * 1 * n_minibatch + 0 * n_minibatch;\n\n          ptrs[0 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_delta1;\n          ptrs[1 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_epsilon1;\n          ptrs[2 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_next_epsilon_y1;\n          ptrs[3 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_next_epsilon_x1;\n          ptrs[4 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_epsilon_y1;\n          ptrs[5 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_epsilon_x1;\n          ptrs[6 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_last_state_y1;\n          ptrs[7 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_last_state_x1;\n          ptrs[8 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_Y1;\n          ptrs[9 * 1 * n_outer_batch + 0 * n_outer_batch + i] = valid1;\n        }\n\n        free(h_sizes);\n\n        HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(valid_storage), valid.data(),\n          valid.size() * sizeof(float), cudaMemcpyHostToDevice));\n        HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(ptr_storage), ptrs.data(),\n          ptrs.size() * sizeof(float*), cudaMemcpyHostToDevice));\n        float ** ptr_storage_data = reinterpret_cast<float**>(Ndarray_DEV_DATA(ptr_storage));\n        float ** data_deltas = ptr_storage_data + 0 * 1 * n_outer_batch;\n        const float ** data_epsilons = (const float**) ptr_storage_data + 1 * 1 * n_outer_batch;\n        const float ** data_next_epsilon_ys = (const float**) ptr_storage_data + 2 * 1 * n_outer_batch;\n        const float ** data_next_epsilon_xs = (const float**) ptr_storage_data + 3 * 1 * n_outer_batch;\n        float ** data_epsilon_ys = ptr_storage_data + 4 * 1 * n_outer_batch;\n        float ** data_epsilon_xs = ptr_storage_data + 5 * 1 * n_outer_batch;\n        const float ** data_last_state_ys = (const float**) ptr_storage_data + 6 * 1 * n_outer_batch;\n        const float ** data_last_state_xs = (const float**) ptr_storage_data + 7 * 1 * n_outer_batch;\n        const float ** data_Ys = (const float**) ptr_storage_data + 8 * 1 * n_outer_batch;\n        const float ** data_valids = (const float**) (ptr_storage_data + 9 * 1 * n_outer_batch);\n\n        start_dev_kernel(lstm_bwd_stable_cell_kernel_batched, (\n          data_deltas,\n          data_epsilons,\n          data_next_epsilon_ys,\n          data_next_epsilon_xs,\n          data_epsilon_ys,\n          data_epsilon_xs,\n          data_last_state_ys,\n          data_last_state_xs,\n          data_Ys,\n          data_valids,\n          1 * n_outer_batch,\n          n_cells,\n          n_minibatch\n        ));\n      }\n    '}[source]#
c_fw_code: str = '\n    // X*, V_h, V_v, W, b, ptr_storage_fwd, ptr_storage_bwd, valid, workmem, sizes, DYDummy,\n    //   initialState, initialOutput, iteration = input_names (*: inplace)\n    // CompleteY, H = output_names\n\n    assert(n_inputs == 15);\n    assert(n_outputs == 2);\n\n    Ndarray* X = inputs[0];\n    Ndarray* V_h = inputs[1];\n    Ndarray* V_v = inputs[2];\n    Ndarray* W = inputs[3];\n    Ndarray* b = inputs[4];\n    Ndarray* ptr_storage_fwd = inputs[5];\n    Ndarray* ptr_storage_bwd = inputs[6]; // not used in fwd\n    Ndarray* valid = inputs[7];\n    Ndarray* workmem = inputs[8]; // not used in fwd\n    Ndarray* workmem2 = inputs[9]; // not used in fwd\n    Ndarray* sizes = inputs[10];\n    Ndarray* DYDummy = inputs[11]; // not used in fwd\n    Ndarray* initialState = inputs[12];\n    Ndarray* initialOutput = inputs[13];\n    Ndarray* iteration = inputs[14];\n\n    assert(sizeof(float) == 4 && "ptr_storage has wrong size if sizeof(float) != 4");\n    assert(sizeof(float*) == 8 && "ptr_storage has wrong size if sizeof(float*) != 8");\n\n    Ndarray* CompleteY = *outputs[0];\n    Ndarray* H = *outputs[1];\n\n    Ndarray_DIMS_Type X_dim = Ndarray_DIMS(X);\n    Ndarray_DIMS_Type W_dim = Ndarray_DIMS(W);\n    Ndarray_DIMS_Type V_dim = Ndarray_DIMS(V_h);\n    assert(W_dim[1] %% 5 == 0 && "W has wrong shape");\n    assert(5 * V_dim[0] == V_dim[1] && "V has wrong shape");\n    assert(W_dim[1] == V_dim[1]);\n    assert(W_dim[0] == X_dim[3]);\n    const long long Y_dim[] = {X_dim[0], X_dim[1], X_dim[2], W_dim[1] / 5};\n    const long long H_dim[] = {X_dim[0], X_dim[1], X_dim[2], W_dim[1]};\n    const long long height = X_dim[0];\n    const long long width = X_dim[1];\n    const long long n_minibatch = X_dim[2];\n    const long long max_diag_size = std::min(height, width);\n    const long long n_diags = width + height - 1;\n\n    //H = XW (+ b, currently always 0)\n    fillmat(context, b, H);\n    affine_global(X, W, H);\n\n    // The iteration is stored on the GPU, but we need it on the CPU to controll the programm flow (use explicitly\n    // provided previous state/output on first iteration). 
Maybe this could be optimized by storing the tensor\n    // directly on the CPU?\n    // We only look at the first value of the tensor with shape (batch,), as every entry has the same value by design\n    float h_iteration;\n    HANDLE_ERROR(cudaMemcpy(&h_iteration, Ndarray_DEV_DATA(iteration), 1*sizeof(float), cudaMemcpyDeviceToHost));\n\n    for(long long diag = 0; diag < n_diags; ++diag)\n    {\n      int diag_size = min(diag+1, min((long long) abs(n_diags-diag), min(width, height)));\n      int y_high = min(diag, height-1);\n      int x_low = max(diag-height+1,(long long) 0);\n      std::vector<int> ys_h, xs_h, ys_v, xs_v, ys, xs;\n      for(int idx = 0; idx < diag_size; ++idx)\n      {\n        int y = y_high - idx;\n        int x = x_low + idx;\n        if(x > 0)\n        {\n          ys_h.push_back(y);\n          xs_h.push_back(x);\n        }\n        if(y > 0 || h_iteration >= 1) {\n          ys_v.push_back(y);\n          xs_v.push_back(x);\n        }\n        ys.push_back(y);\n        xs.push_back(x);\n      }\n\n      affine_y_x_batched_onedir(context, 0, -1,\n        CompleteY, V_h, H, ys_h, xs_h, ptr_storage_fwd, height, width);\n\n      // If it\'s not the first iteration, we need to use the explicitly provided initial output\n      if(h_iteration >= 1) {\n        assert(ys_v.size() == 1); // Otherwise, the target length would be != 1, we don\'t support that yet.\n        affine_y_x_batched_onedir(context, 0, 0,\n          initialOutput, V_v, H, ys_v, xs_v, ptr_storage_fwd, height, width);\n      }\n      else {\n        affine_y_x_batched_onedir(context, -1, 0,\n          CompleteY, V_v, H, ys_v, xs_v, ptr_storage_fwd, height, width);\n      }\n\n      do_lstm_batched_onedir(context, H, initialState, h_iteration, CompleteY, ys, xs, ptr_storage_fwd, valid, sizes);\n    }\n    '[source]#
c_bw_code: str = "\n    // X, V_h, V_v, W, b, ptr_storage_fwd, ptr_storage_bwd, valid, workmem, workmem2, sizes, DYDummy, initialState,\n    //   initialOutput, iteration, CompleteY, H, DCompleteY, DH = inputs\n    // DX, DV_h, DV_v, DW, Db = outputs\n\n    assert(n_inputs == 19);\n    assert(n_outputs == 5);\n\n    Ndarray* X = inputs[0];\n    Ndarray* V_h = inputs[1];\n    Ndarray* V_v = inputs[2];\n    Ndarray* W = inputs[3];\n    Ndarray* b = inputs[4];\n    Ndarray* ptr_storage_fwd = inputs[5]; // not used in bwd\n    Ndarray* ptr_storage_bwd = inputs[6];\n    Ndarray* valid_storage = inputs[7];\n    Ndarray* workmem = inputs[8];\n    Ndarray* workmem2 = inputs[9];\n    Ndarray* sizes = inputs[10];\n    Ndarray* DYDummy = inputs[11];\n    Ndarray* initialState = inputs[12];\n    Ndarray* initialOutput = inputs[13];\n    Ndarray* iteration = inputs[14]; // not used in bwd (only for asserting it's == 0)\n    Ndarray* CompleteY = inputs[15];\n    Ndarray* H = inputs[16];\n    Ndarray* DCompleteY = inputs[17];\n    Ndarray* DH = inputs[18];\n\n    Ndarray* DX = *outputs[0];\n    Ndarray* DV_h = *outputs[1];\n    Ndarray* DV_v = *outputs[2];\n    Ndarray* DW = *outputs[3];\n    Ndarray* Db = *outputs[4];\n\n    Ndarray_DIMS_Type X_dim = Ndarray_HOST_DIMS(X);\n    Ndarray_DIMS_Type Y_dim = Ndarray_HOST_DIMS(CompleteY);\n    Ndarray_DIMS_Type Vh_dim = Ndarray_HOST_DIMS(V_h);\n    const int height = X_dim[0];\n    const int width = X_dim[1];\n    const int n_minibatch = X_dim[2];\n    const int n_diags = width + height - 1;\n    const int max_diag_size = std::min(Y_dim[0], Y_dim[1]);\n\n    Ndarray * delta1 = H;\n    Ndarray * epsilon = DYDummy;\n\n    int size = X_dim[0] * X_dim[1] * X_dim[2] * Vh_dim[0] * sizeof(float);\n    HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(epsilon), Ndarray_DEV_DATA(DCompleteY), size, cudaMemcpyDeviceToDevice));\n\n    for(int diag = n_diags-1; diag >= 0; --diag)\n    {\n      int diag_size = std::min(diag+1, std::min(std::abs(n_diags-diag), std::min(width, height)));\n      int y_high = std::min(diag, height-1);\n      int x_low = std::max(diag-height+1,0);\n      std::vector<int> ys_h, xs_h, ys_v, xs_v, ys, xs;\n      for(int idx = 0; idx < diag_size; ++idx)\n      {\n        int y = y_high - idx;\n        int x = x_low + idx;\n        bool rightBorder = (x == X_dim[1]-1);\n        if(!rightBorder)\n        {\n          ys_h.push_back(y);\n          xs_h.push_back(x);\n        }\n        bool botBorder = (y == X_dim[0]-1);\n        if(!botBorder)\n        {\n          ys_v.push_back(y);\n          xs_v.push_back(x);\n        }\n        ys.push_back(y);\n        xs.push_back(x);\n      }\n\n      affine_y_x_batched_onedir(context, 0, 1, delta1, V_h,\n        epsilon, ys_h, xs_h, ptr_storage_bwd, height, width, 0, false, true);\n      affine_y_x_batched_onedir(context, 1, 0, delta1, V_v,\n        epsilon, ys_v, xs_v, ptr_storage_bwd, height, width, 0, false, true);\n\n      do_lstm_bwd_batched_onedir(\n        context, delta1, epsilon, CompleteY, workmem,\n        X_dim[0], X_dim[2], ys, xs, ptr_storage_bwd, valid_storage, sizes, diag+1);\n    }\n\n    //DW = X^T * delta\n    affine_global(X, delta1, DW, true, false, 0, 0.0f);\n    //important! mind the order, first use X, then update DX, which might be aligned to X\n    //DX = delta * W^T\n    affine_global(delta1, W, DX, false, true, 0, 0.0f);\n\n    // Currently, the bias is not trained\n    //Db = (1 ... 
1) * delta\n\n    //copy left/right part to workmem2 and set to 0\n    // (could be done more efficient, but profiling shows, it's not worth it)\n    Ndarray_DIMS_Type H_dim = Ndarray_HOST_DIMS(H);\n    const int block_size = H_dim[2] * H_dim[3];\n    for(int y = 0; y < Y_dim[0]; ++y)\n    {\n      float * workmem2_1_data_ptr = Ndarray_DEV_DATA(workmem2) + y * block_size;\n      float * delta1_data_ptr = data_ptr(delta1, y, 0);\n      HANDLE_ERROR(cudaMemcpy(\n        workmem2_1_data_ptr, delta1_data_ptr, block_size * sizeof(float), cudaMemcpyDeviceToDevice));\n      HANDLE_ERROR(cudaMemset(delta1_data_ptr, 0, sizeof(float) * H_dim[2] * H_dim[3]));\n    }\n\n    //DV_h = Y[0..end-1]^T * delta[1..end]\n    affine_global(CompleteY, delta1, DV_h, true, false, 1, 0.0f);\n\n    //copy left/right part back\n    for(int y = 0; y < Y_dim[0]; ++y)\n    {\n      float * workmem2_1_data_ptr = Ndarray_DEV_DATA(workmem2) + y * block_size;\n      float * delta1_data_ptr = data_ptr(delta1, y, 0);\n      HANDLE_ERROR(cudaMemcpy(\n        delta1_data_ptr, workmem2_1_data_ptr, block_size * sizeof(float), cudaMemcpyDeviceToDevice));\n    }\n\n    //DV_v = Y[0..end-1]^T * delta[1..end]\n    affine_global(CompleteY, delta1, DV_v, true, false, Y_dim[1], 0.0f);\n  "[source]#
cpu_support = False[source]#
code_version: Union[Tuple[int], int] = ()[source]#
class returnn.native_op.Chunking[source]#

Given an input in 3d (n_time,n_batch,n_dim), we chunk up the time dimension into chunks of size chunk_size, starting a new chunk every chunk_step frames. This results in a 3d output (chunk_size, n_batch * n_chunks, n_dim) where n_chunks = floor( max(n_time - chunk_size + chunk_step - 1, 0) / chunk_step ) + 1. Examples:

n_time=1, chunk_size=50, chunk_step=10 -> n_chunks=1
n_time=49, chunk_size=50, chunk_step=10 -> n_chunks=1
n_time=50, chunk_size=50, chunk_step=10 -> n_chunks=1
n_time=51, chunk_size=50, chunk_step=10 -> n_chunks=2
n_time=60, chunk_size=50, chunk_step=10 -> n_chunks=2
n_time=61, chunk_size=50, chunk_step=10 -> n_chunks=3
n_time=99, chunk_size=50, chunk_step=10 -> n_chunks=6
n_time=100, chunk_size=50, chunk_step=10 -> n_chunks=6
n_time=101, chunk_size=50, chunk_step=10 -> n_chunks=7
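For illustration, a tiny Python sketch of the n_chunks formula above (just the closed form restated as code; the function name num_chunks is ours, not part of the module):

def num_chunks(n_time, chunk_size, chunk_step):
    # floor(max(n_time - chunk_size + chunk_step - 1, 0) / chunk_step) + 1
    return max(n_time - chunk_size + chunk_step - 1, 0) // chunk_step + 1

assert num_chunks(1, 50, 10) == 1
assert num_chunks(51, 50, 10) == 2
assert num_chunks(99, 50, 10) == 6
assert num_chunks(101, 50, 10) == 7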

in_info: Tuple[Dict[str]] = ({'name': 'input', 'ndim': 3, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'output_buffer', 'ndim': 3, 'shape': (None, None, None), 'want_inplace': 0}, {'gradient': 'disconnected', 'name': 'oindex_buffer', 'ndim': 2, 'shape': (None, None), 'want_inplace': 1}, {'gradient': 'disconnected', 'name': 'chunk_params', 'ndim': 1, 'need_contiguous': True, 'shape': (2,)})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'shape': ((2, 0), (2, 1), (2, 2))}, {'name': 'oindex', 'ndim': 2, 'shape': ((3, 0), (3, 1))})[source]#
c_extra_support_code: Dict[str, str] = {'copy_kernel': '\n    DEF_KERNEL\n    void copy_kernel(\n      float* chunk_params,\n      float* input, long in_dim0, long in_dim1, long in_dim2, long in_stride0, long in_stride1, long in_stride2,\n      float* index, long idx_stride0, long idx_stride1,\n      float* output, long out_dim0, long out_dim1, long out_stride0, long out_stride1, long out_stride2,\n      float* oindex, long oidx_stride0, long oidx_stride1\n    ) {\n      assert_cmp(out_dim1 % in_dim1, ==, 0);\n      const long n_chunks = out_dim1 / in_dim1;\n      assert_cmp(n_chunks, >, 0);\n      const long chunk_size = out_dim0;\n      assert_cmp(long(chunk_params[0]), ==, chunk_size);\n      const long chunk_step = long(chunk_params[1]);\n      assert_cmp(chunk_step, >, 0);\n      assert_cmp(chunk_step * (n_chunks - 1) + chunk_size, >=, in_dim0);\n      assert_cmp(chunk_step * (n_chunks - 1), <, in_dim0);\n\n      // Iterate over output (chunked) x/y coordinates.\n      // In an inner loop, we will loop over z.\n      const long max_idx = out_dim0 * out_dim1;\n      for(\n        long idx = threadIdx.x + blockDim.x * blockIdx.x;\n        idx < max_idx;\n        idx += gridDim.x * blockDim.x)\n      {\n        long out_x = idx % out_dim0;  // time\n        long out_y = idx / out_dim0;  // batch\n\n        long chunk_idx = out_y % n_chunks;\n        long in_y =      out_y / n_chunks;\n\n        long in_x = chunk_step * chunk_idx + out_x;\n\n        if(in_x < in_dim0 && index[in_x * idx_stride0 + in_y * idx_stride1] > 0.1) {\n          for(long z = 0; z < in_dim2; ++z)\n            output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] =\n              input[in_x * in_stride0 + in_y * in_stride1 + z * in_stride2];\n          oindex[out_x * oidx_stride0 + out_y * oidx_stride1] = 1;\n        }\n        else {\n          for(long z = 0; z < in_dim2; ++z)\n            output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] = 0;\n          oindex[out_x * oidx_stride0 + out_y * oidx_stride1] = 0;\n        }\n      }\n    }\n    '}[source]#
c_fw_code: str = '\n    assert_cmp(n_inputs, ==, 5);\n    assert_cmp(n_outputs, ==, 2);\n    Ndarray* input = inputs[0];\n    Ndarray* index = inputs[1];\n    Ndarray* chunk_params = inputs[4];\n    Ndarray* output = *outputs[0];\n    Ndarray* oindex = *outputs[1];\n\n    assert_cmp(Ndarray_NDIM(input), ==, 3);\n    assert_cmp(Ndarray_NDIM(index), ==, 2);\n    assert_cmp(Ndarray_DIMS(input)[0], ==, Ndarray_DIMS(index)[0]);\n    assert_cmp(Ndarray_DIMS(input)[1], ==, Ndarray_DIMS(index)[1]);\n    assert_cmp(Ndarray_NDIM(chunk_params), ==, 1);\n    assert_cmp(Ndarray_DIMS(chunk_params)[0], ==, 2);\n    assert_cmp(Ndarray_NDIM(output), ==, 3);\n    assert_cmp(Ndarray_NDIM(oindex), ==, 2);\n    assert_cmp(Ndarray_DIMS(output)[0], ==, Ndarray_DIMS(oindex)[0]);\n    assert_cmp(Ndarray_DIMS(output)[1], ==, Ndarray_DIMS(oindex)[1]);\n    assert_cmp(Ndarray_DIMS(output)[2], ==, Ndarray_DIMS(input)[2]);\n\n    start_dev_kernel(copy_kernel, (\n      Ndarray_DEV_DATA(chunk_params),\n      Ndarray_DEV_DATA(input),\n        Ndarray_DIMS(input)[0],\n        Ndarray_DIMS(input)[1],\n        Ndarray_DIMS(input)[2],\n        Ndarray_STRIDE(input, 0),\n        Ndarray_STRIDE(input, 1),\n        Ndarray_STRIDE(input, 2),\n      Ndarray_DEV_DATA(index),\n        Ndarray_STRIDE(index, 0),\n        Ndarray_STRIDE(index, 1),\n      Ndarray_DEV_DATA(output),\n        Ndarray_DIMS(output)[0],\n        Ndarray_DIMS(output)[1],\n        Ndarray_STRIDE(output, 0),\n        Ndarray_STRIDE(output, 1),\n        Ndarray_STRIDE(output, 2),\n      Ndarray_DEV_DATA(oindex),\n        Ndarray_STRIDE(oindex, 0),\n        Ndarray_STRIDE(oindex, 1)\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
code_version: Union[Tuple[int], int] = ()[source]#
static naive_chunk_start_frames(n_time, chunk_size, chunk_step)[source]#

This is just for documentation / demonstration. It is also used by the test code.
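As a hedged sketch of what such a naive computation could look like (not necessarily the verbatim implementation), consistent with the n_chunks formula of Chunking above:

def naive_chunk_start_frames(n_time, chunk_size, chunk_step):
    # Start a new chunk every chunk_step frames; the last chunk is the first one
    # that reaches (or passes) the end of the sequence.
    t, start_frames = 0, []
    while True:
        start_frames.append(t)
        if t + chunk_size >= n_time:
            break
        t += chunk_step
    return start_frames

assert len(naive_chunk_start_frames(61, 50, 10)) == 3  # same count as n_chunks above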

class returnn.native_op.UnChunking[source]#

This reverses the output of Chunking, i.e. it undoes the chunking of the time dimension. We get a 3d input (chunk_size, n_batch * n_chunks, n_dim) and return a 3d output (n_time, n_batch, n_dim), where the chunks are of size chunk_size and start every chunk_step frames. Because the chunks overlap, we have to combine the overlapping frames somehow. We do that with uniform weighting, i.e. we take the mean over all chunks that cover a frame.
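As a reference for the semantics only (not the CUDA kernel, and ignoring the index masks), a minimal NumPy sketch that rebuilds the unchunked sequence by averaging all chunks covering a frame; it assumes the chunk layout produced by Chunking's copy_kernel above, where the n_chunks chunks of one batch entry are stored next to each other:

import numpy as np

def naive_unchunk(chunked, n_time, n_batch, chunk_step):
    # chunked: (chunk_size, n_batch * n_chunks, n_dim)
    chunk_size, n_bc, n_dim = chunked.shape
    n_chunks = n_bc // n_batch
    out = np.zeros((n_time, n_batch, n_dim), dtype=chunked.dtype)
    counts = np.zeros((n_time, n_batch, 1), dtype=chunked.dtype)
    for c in range(n_chunks):
        start = c * chunk_step
        end = min(start + chunk_size, n_time)
        out[start:end] += chunked[:end - start, c::n_chunks]  # chunk c of every batch entry
        counts[start:end] += 1
    return out / np.maximum(counts, 1)  # mean over all overlapping chunks per frame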

in_info: Tuple[Dict[str]] = ({'name': 'input', 'ndim': 3, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'output_buffer', 'ndim': 3, 'shape': (None, None, None), 'want_inplace': 0}, {'gradient': 'disconnected', 'name': 'oindex_buffer', 'ndim': 2, 'shape': (None, None), 'want_inplace': 1}, {'gradient': 'disconnected', 'name': 'ofactors_buffer', 'ndim': 2, 'shape': (None, None), 'want_inplace': 2}, {'gradient': 'disconnected', 'name': 'chunk_params', 'ndim': 1, 'need_contiguous': True, 'shape': (2,)})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'shape': ((2, 0), (2, 1), (2, 2))}, {'name': 'oindex', 'ndim': 2, 'shape': ((3, 0), (3, 1))}, {'name': 'ofactors', 'ndim': 2, 'shape': ((4, 0), (4, 1))})[source]#
c_extra_support_code: Dict[str, str] = {'unchunk_kernel': '\n    DEF_KERNEL\n    void unchunk_kernel(\n      float* chunk_params,\n      float* input, long in_dim0, long in_dim1, long in_dim2, long in_stride0, long in_stride1, long in_stride2,\n      float* index, long idx_stride0, long idx_stride1,\n      float* output, long out_dim0, long out_dim1, long out_stride0, long out_stride1, long out_stride2,\n      float* oindex, long oidx_stride0, long oidx_stride1,\n      float* ofactors, long ofac_stride0, long ofac_stride1\n    ) {\n      assert_cmp(in_dim1 % out_dim1, ==, 0);\n      const long n_chunks = in_dim1 / out_dim1;\n      assert_cmp(n_chunks, >, 0);\n      const long chunk_size = in_dim0;\n      assert_cmp(long(chunk_params[0]), ==, chunk_size);\n      const long chunk_step = long(chunk_params[1]);\n      assert_cmp(chunk_step, >, 0);\n      assert_cmp(chunk_step * (n_chunks - 1) + chunk_size, >=, out_dim0);\n      assert_cmp(chunk_step * (n_chunks - 1), <, out_dim0);\n\n      // Iterate over output (unchunked) x/y coordinates.\n      // In an inner loop, we will loop over z.\n      const long max_idx = out_dim0 * out_dim1;\n      for(\n        long idx = threadIdx.x + blockDim.x * blockIdx.x;\n        idx < max_idx;\n        idx += gridDim.x * blockDim.x)\n      {\n        long out_x = idx % out_dim0;  // time\n        long out_y = idx / out_dim0;  // batch\n\n        float c = 0;\n        for(long z = 0; z < in_dim2; ++z)\n          output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] = 0;\n\n        // in_x = out_x - chunk_step * chunk_idx,\n        // thus in_x < 0           when chunk_idx * chunk_step >  out_x,\n        // and  in_x >= chunk_size when chunk_idx * chunk_step <= out_x - chunk_size,\n        // thus we need chunk_idx <= out_x / chunk_step,\n        // and          chunk_idx > (out_x - chunk_size) / chunk_step.\n        // Examples:\n        //   out_x=0,  chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,1\n        //   out_x=3,  chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,1\n        //   out_x=4,  chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,2\n        //   out_x=7,  chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,2\n        //   out_x=8,  chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,3\n        //   out_x=9,  chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,3\n        //   out_x=10, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=1,3\n        //   out_x=11, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=1,3\n        //   out_x=12, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=1,4\n        //   out_x=13, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=1,4\n        //   out_x=14, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=2,4\n        long chunk_idx_start = (out_x - chunk_size + chunk_step) / chunk_step;\n        if(chunk_idx_start < 0) chunk_idx_start = 0;\n        long chunk_idx_end = out_x / chunk_step + 1;\n        if(chunk_idx_end > n_chunks) chunk_idx_end = n_chunks;\n        assert_cmp(chunk_idx_start, <, chunk_idx_end);\n        for(long chunk_idx = chunk_idx_start; chunk_idx < chunk_idx_end; ++chunk_idx) {\n          long in_y = out_y * n_chunks + chunk_idx;\n          long in_x = out_x - chunk_step * chunk_idx;\n          assert_cmp(in_x, >=, 0);\n          assert_cmp(in_x, <, chunk_size);\n          if(index[in_x * idx_stride0 + in_y * idx_stride1] > 0.1) {\n            c += 1;\n            for(long z = 0; z < in_dim2; ++z)\n              output[out_x * out_stride0 + out_y * 
out_stride1 + z * out_stride2] +=\n                input[in_x * in_stride0 + in_y * in_stride1 + z * in_stride2];\n          }\n        }\n\n        if(c > 0.1) {\n          for(long z = 0; z < in_dim2; ++z)\n            output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] /= c;\n          oindex[out_x * oidx_stride0 + out_y * oidx_stride1] = 1;\n          ofactors[out_x * ofac_stride0 + out_y * ofac_stride1] = 1.0 / c;\n        } else {\n          oindex[out_x * oidx_stride0 + out_y * oidx_stride1] = 0;\n          ofactors[out_x * ofac_stride0 + out_y * ofac_stride1] = 1.0;\n        }\n      }\n    }\n    '}[source]#
c_fw_code: str = '\n    assert_cmp(n_inputs, ==, 6);\n    assert_cmp(n_outputs, ==, 3);\n    Ndarray* input = inputs[0];\n    Ndarray* index = inputs[1];\n    Ndarray* chunk_params = inputs[5];\n    Ndarray* output = *outputs[0];\n    Ndarray* oindex = *outputs[1];\n    Ndarray* ofactors = *outputs[2];\n\n    assert_cmp(Ndarray_NDIM(input), ==, 3);\n    assert_cmp(Ndarray_NDIM(index), ==, 2);\n    assert_cmp(Ndarray_DIMS(input)[0], ==, Ndarray_DIMS(index)[0]);\n    assert_cmp(Ndarray_DIMS(input)[1], ==, Ndarray_DIMS(index)[1]);\n    assert_cmp(Ndarray_NDIM(chunk_params), ==, 1);\n    assert_cmp(Ndarray_DIMS(chunk_params)[0], ==, 2);\n    assert_cmp(Ndarray_NDIM(output), ==, 3);\n    assert_cmp(Ndarray_NDIM(oindex), ==, 2);\n    assert_cmp(Ndarray_NDIM(ofactors), ==, 2);\n    assert_cmp(Ndarray_DIMS(output)[0], ==, Ndarray_DIMS(oindex)[0]);\n    assert_cmp(Ndarray_DIMS(output)[1], ==, Ndarray_DIMS(oindex)[1]);\n    assert_cmp(Ndarray_DIMS(output)[2], ==, Ndarray_DIMS(input)[2]);\n    assert_cmp(Ndarray_DIMS(oindex)[0], ==, Ndarray_DIMS(ofactors)[0]);\n    assert_cmp(Ndarray_DIMS(oindex)[1], ==, Ndarray_DIMS(ofactors)[1]);\n\n    start_dev_kernel(unchunk_kernel, (\n      Ndarray_DEV_DATA(chunk_params),\n      Ndarray_DEV_DATA(input),\n        Ndarray_DIMS(input)[0],\n        Ndarray_DIMS(input)[1],\n        Ndarray_DIMS(input)[2],\n        Ndarray_STRIDE(input, 0),\n        Ndarray_STRIDE(input, 1),\n        Ndarray_STRIDE(input, 2),\n      Ndarray_DEV_DATA(index),\n        Ndarray_STRIDE(index, 0),\n        Ndarray_STRIDE(index, 1),\n      Ndarray_DEV_DATA(output),\n        Ndarray_DIMS(output)[0],\n        Ndarray_DIMS(output)[1],\n        Ndarray_STRIDE(output, 0),\n        Ndarray_STRIDE(output, 1),\n        Ndarray_STRIDE(output, 2),\n      Ndarray_DEV_DATA(oindex),\n        Ndarray_STRIDE(oindex, 0),\n        Ndarray_STRIDE(oindex, 1),\n      Ndarray_DEV_DATA(ofactors),\n        Ndarray_STRIDE(ofactors, 0),\n        Ndarray_STRIDE(ofactors, 1)\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
code_version: Union[Tuple[int], int] = ()[source]#
class returnn.native_op.SubtensorBatchedIndex[source]#
Suppose you have:

idx: 2d (n_time, n_batch), with values in [0..n_dim-1]
x: 3d (n_time, n_batch, n_dim)

Then, this op will calculate:

y = x[…, idx[…]]: 2d (n_time, n_batch), i.e. y[t, b] = x[t, b, idx[t, b]].
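In plain NumPy, the forward computation roughly corresponds to the following sketch (the inplace and gradient handling of the native op is left out):

import numpy as np

def batched_index_select(x, idx):
    # x: (n_time, n_batch, n_dim); idx: (n_time, n_batch), values in [0, n_dim)
    # y[t, b] = x[t, b, idx[t, b]]
    return np.take_along_axis(x, idx[..., None].astype(np.int64), axis=2)[..., 0]

x = np.random.rand(5, 3, 4).astype("float32")
idx = np.random.randint(0, 4, size=(5, 3))
assert batched_index_select(x, idx).shape == (5, 3)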

in_info: Tuple[Dict[str]] = ({'bw_in_var': {'want_inplace': 0}, 'name': 'x', 'ndim': 3, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'idx', 'ndim': 2, 'shape': (None, None)})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'y', 'ndim': 2, 'shape': ((0, 0), (0, 1))},)[source]#
classmethod grad_input_map(x, idx, y, DY)[source]#

Map.

c_extra_support_code: Dict[str, str] = {'select_bw_kernel': '\n    DEF_KERNEL\n    void select_bw_kernel(\n      float* Dx, long Dx_dim0, long Dx_dim1, long Dx_dim2, long Dx_stride0, long Dx_stride1, long Dx_stride2,\n      float* index, long idx_stride0, long idx_stride1,\n      float* Dy, long Dy_stride0, long Dy_stride1\n    ) {\n      const long max_idx = Dx_dim0 * Dx_dim1;\n      for(\n        long idx = threadIdx.x + blockDim.x * blockIdx.x;\n        idx < max_idx;\n        idx += gridDim.x * blockDim.x)\n      {\n        long d0 = idx % Dx_dim0;\n        long d1 = idx / Dx_dim0;\n        long d2 = long(index[d0 * idx_stride0 + d1 * idx_stride1]);\n        if(d2 < 0) d2 = 0;\n        if(d2 >= Dx_dim2) d2 = Dx_dim2 - 1;\n        Dx[d0 * Dx_stride0 + d1 * Dx_stride1 + d2 * Dx_stride2] = Dy[d0 * Dy_stride0 + d1 * Dy_stride1];\n      }\n    }\n    ', 'select_kernel': '\n    DEF_KERNEL\n    void select_kernel(\n      float* x, long x_dim0, long x_dim1, long x_dim2, long x_stride0, long x_stride1, long x_stride2,\n      float* index, long idx_stride0, long idx_stride1,\n      float* y, long y_stride0, long y_stride1\n    ) {\n      const long max_idx = x_dim0 * x_dim1;\n      for(\n        long idx = threadIdx.x + blockDim.x * blockIdx.x;\n        idx < max_idx;\n        idx += gridDim.x * blockDim.x)\n      {\n        long d0 = idx % x_dim0;\n        long d1 = idx / x_dim0;\n        long d2 = long(index[d0 * idx_stride0 + d1 * idx_stride1]);\n        if(d2 < 0) d2 = 0;\n        if(d2 >= x_dim2) d2 = x_dim2 - 1;\n        y[d0 * y_stride0 + d1 * y_stride1] = x[d0 * x_stride0 + d1 * x_stride1 + d2 * x_stride2];\n      }\n    }\n    '}[source]#
c_fw_code: str = '\n    assert_cmp(n_inputs, ==, 2);\n    assert_cmp(n_outputs, ==, 1);\n    Ndarray* x = inputs[0];\n    Ndarray* idx = inputs[1];\n    Ndarray* y = *outputs[0];\n\n    assert_cmp(Ndarray_NDIM(x), ==, 3);\n    assert_cmp(Ndarray_NDIM(idx), ==, 2);\n    assert_cmp(Ndarray_DIMS(x)[0], ==, Ndarray_DIMS(idx)[0]);\n    assert_cmp(Ndarray_DIMS(x)[1], ==, Ndarray_DIMS(idx)[1]);\n    assert_cmp(Ndarray_NDIM(y), ==, 2);\n    assert_cmp(Ndarray_DIMS(y)[0], ==, Ndarray_DIMS(idx)[0]);\n    assert_cmp(Ndarray_DIMS(y)[1], ==, Ndarray_DIMS(idx)[1]);\n\n    start_dev_kernel(select_kernel, (\n      Ndarray_DEV_DATA(x),\n        Ndarray_DIMS(x)[0],\n        Ndarray_DIMS(x)[1],\n        Ndarray_DIMS(x)[2],\n        Ndarray_STRIDE(x, 0),\n        Ndarray_STRIDE(x, 1),\n        Ndarray_STRIDE(x, 2),\n      Ndarray_DEV_DATA(idx),\n        Ndarray_STRIDE(idx, 0),\n        Ndarray_STRIDE(idx, 1),\n      Ndarray_DEV_DATA(y),\n        Ndarray_STRIDE(y, 0),\n        Ndarray_STRIDE(y, 1)\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
c_bw_code: str = '\n    assert_cmp(n_inputs, ==, 3);\n    assert_cmp(n_outputs, ==, 1);\n    Ndarray* x = inputs[0];\n    Ndarray* idx = inputs[1];\n    Ndarray* Dy = inputs[2];\n    Ndarray* Dx = *outputs[0];  // inplace on x\n\n    assert_cmp(Ndarray_NDIM(x), ==, 3);\n    assert_cmp(Ndarray_NDIM(idx), ==, 2);\n    assert_cmp(Ndarray_DIMS(x)[0], ==, Ndarray_DIMS(idx)[0]);\n    assert_cmp(Ndarray_DIMS(x)[1], ==, Ndarray_DIMS(idx)[1]);\n    assert_cmp(Ndarray_NDIM(Dy), ==, 2);\n    assert_cmp(Ndarray_DIMS(Dy)[0], ==, Ndarray_DIMS(idx)[0]);\n    assert_cmp(Ndarray_DIMS(Dy)[1], ==, Ndarray_DIMS(idx)[1]);\n    assert_cmp(Ndarray_NDIM(Dx), ==, 3);\n    assert_cmp(Ndarray_DIMS(Dx)[0], ==, Ndarray_DIMS(x)[0]);\n    assert_cmp(Ndarray_DIMS(Dx)[1], ==, Ndarray_DIMS(x)[1]);\n    assert_cmp(Ndarray_DIMS(Dx)[2], ==, Ndarray_DIMS(x)[2]);\n\n    Ndarray_set_zero(Dx);\n    start_dev_kernel(select_bw_kernel, (\n      Ndarray_DEV_DATA(Dx),\n        Ndarray_DIMS(Dx)[0],\n        Ndarray_DIMS(Dx)[1],\n        Ndarray_DIMS(Dx)[2],\n        Ndarray_STRIDE(Dx, 0),\n        Ndarray_STRIDE(Dx, 1),\n        Ndarray_STRIDE(Dx, 2),\n      Ndarray_DEV_DATA(idx),\n        Ndarray_STRIDE(idx, 0),\n        Ndarray_STRIDE(idx, 1),\n      Ndarray_DEV_DATA(Dy),\n        Ndarray_STRIDE(Dy, 0),\n        Ndarray_STRIDE(Dy, 1)\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
class returnn.native_op.SparseToDense[source]#

Expects a sparse matrix in COOrdinate format, where W[s0[i,b],b,s1[i,b]] = weight[i,b] for all i and all batches b. Will return W (time,batch,dim).
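A NumPy sketch of the same scatter, for reference; like the native op it applies the mask and accumulates into the given initial buffer (passed here explicitly as initial_w):

import numpy as np

def sparse_to_dense(initial_w, s0, s1, weight, mask):
    # initial_w: (n_time, n_batch, n_dim); s0, s1, weight, mask: (n_sparse, n_batch)
    w = initial_w.copy()
    n_sparse, n_batch = s0.shape
    for i in range(n_sparse):
        for b in range(n_batch):
            if mask[i, b] < 0.1:  # same threshold as the assign_kernel
                continue
            w[int(s0[i, b]), b, int(s1[i, b])] += weight[i, b]
    return w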

in_info: Tuple[Dict[str]] = ({'name': '_initial_W', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None), 'want_inplace': 0}, {'name': 's0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 's1', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'weight', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'mask', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'W', 'ndim': 3, 'shape': ((0, 0), (0, 1), (0, 2))},)[source]#
c_extra_support_code: Dict[str, str] = {'assign_kernel': '\n    DEF_KERNEL\n    void assign_kernel(\n      float* out, float* s0, float* s1, float* w, float* mask,\n      long n_sparse_idx, long n_time, long n_batch, long n_dim)\n    {\n      long max_idx = n_batch * n_sparse_idx;\n      for(\n        long idx = threadIdx.x + blockDim.x * blockIdx.x;\n        idx < max_idx;\n        idx += gridDim.x * blockDim.x)\n      {\n        if(mask[idx] < 0.1) continue;\n        long batch = idx % n_batch;\n        long t = (long) s0[idx];\n        long j = (long) s1[idx];\n        float y = w[idx];\n        if(t < 0 || t >= n_time) continue;  // error somehow?\n        if(j < 0 || j >= n_dim) continue;  // error somehow?\n        long out_idx = t * n_batch * n_dim + batch * n_dim + j;\n        out[out_idx] += y;\n      }\n    }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 5);\n    assert(n_outputs == 1);\n    Ndarray* s0 = inputs[1];\n    Ndarray* s1 = inputs[2];\n    Ndarray* weight = inputs[3];\n    Ndarray* mask = inputs[4];\n    Ndarray* out_W = *outputs[0];\n\n    assert(Ndarray_NDIM(s0) == 2);\n    assert(Ndarray_NDIM(s1) == 2);\n    assert(Ndarray_NDIM(weight) == 2);\n    assert(Ndarray_NDIM(mask) == 2);\n    assert(Ndarray_NDIM(out_W) == 3);\n    int n_sparse_idx = Ndarray_DIMS(s0)[0];\n    assert(n_sparse_idx == Ndarray_DIMS(s1)[0]);\n    assert(n_sparse_idx == Ndarray_DIMS(weight)[0]);\n    assert(n_sparse_idx == Ndarray_DIMS(mask)[0]);\n    int n_batch = Ndarray_DIMS(s0)[1];\n    assert(n_batch == Ndarray_DIMS(s1)[1]);\n    assert(n_batch == Ndarray_DIMS(weight)[1]);\n    assert(n_batch == Ndarray_DIMS(mask)[1]);\n    assert(n_batch == Ndarray_DIMS(out_W)[1]);\n    int n_time = Ndarray_DIMS(out_W)[0];\n    int n_dim = Ndarray_DIMS(out_W)[2];\n\n    start_dev_kernel(assign_kernel, (\n      Ndarray_DEV_DATA(out_W),\n      Ndarray_DEV_DATA(s0),\n      Ndarray_DEV_DATA(s1),\n      Ndarray_DEV_DATA(weight),\n      Ndarray_DEV_DATA(mask),\n      n_sparse_idx, n_time, n_batch, n_dim\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
class returnn.native_op.MaxAndArgmaxSparse[source]#

Expects a sparse matrix in COOrdinate format, where W[s0[i,b],s1[i,b],b] = weight[i,b] for all i and all batches b. It will return the max and argmax of each W[:,:,b] over the second axis.
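For the semantics, a small NumPy sketch that mirrors the doit_kernel; out_max and out_arg are the preallocated (inplace) output buffers, just as in the native op:

import numpy as np

def max_and_argmax_sparse(s0, s1, weight, mask, out_max, out_arg):
    # s0, s1, weight, mask: (n_in_time, n_batch); out_max, out_arg: (n_out_time, n_batch)
    n_in_time, n_batch = s0.shape
    for b in range(n_batch):
        for i in range(n_in_time):
            if mask[i, b] < 0.1:
                continue
            t, j, w = int(s0[i, b]), int(s1[i, b]), weight[i, b]
            if w > out_max[t, b]:
                out_max[t, b] = w
                out_arg[t, b] = float(j)
    return out_max, out_arg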

in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 's0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 's1', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'weight', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'mask', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': '_out_max', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None), 'want_inplace': 0}, {'gradient': 'disconnected', 'name': '_out_arg', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None), 'want_inplace': 1})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'out_max', 'ndim': 2, 'shape': ((4, 0), (4, 1))}, {'name': 'out_arg', 'ndim': 2, 'shape': ((5, 0), (5, 1))})[source]#
c_extra_support_code: Dict[str, str] = {'doit_kernel': '\n    DEF_KERNEL\n    void doit_kernel(\n        long n_batch, long n_in_time, long n_out_time,\n        float* s0, float* s1, float* weight, float* mask,\n        float* out_max, float* out_arg) {\n      long batch_idx = threadIdx.x + blockDim.x * blockIdx.x;\n      while(batch_idx < n_batch) {\n        for(long i = 0; i < n_in_time; ++i) {\n          long idx = i * n_batch + batch_idx;\n          if(mask[idx] < 0.1) continue;\n          long t = (long) s0[idx];\n          long j = (long) s1[idx];\n          float w = weight[idx];\n          if(t < 0 || t >= n_out_time) continue;  // error somehow?\n          long out_idx = t * n_batch + batch_idx;\n          if(w > out_max[out_idx]) {\n            out_max[out_idx] = w;\n            out_arg[out_idx] = (float) j;\n          }\n        }\n        batch_idx += gridDim.x * blockDim.x;\n      }\n    }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 6);\n    assert(n_outputs == 2);\n    Ndarray* s0 = inputs[0];\n    Ndarray* s1 = inputs[1];\n    Ndarray* weight = inputs[2];\n    Ndarray* mask = inputs[3];\n    Ndarray* out_max = *outputs[0];\n    Ndarray* out_arg = *outputs[1];\n\n    assert(Ndarray_NDIM(s0) == 2);\n    assert(Ndarray_NDIM(s1) == 2);\n    assert(Ndarray_NDIM(weight) == 2);\n    assert(Ndarray_NDIM(mask) == 2);\n    assert(Ndarray_NDIM(out_max) == 2);\n    assert(Ndarray_NDIM(out_arg) == 2);\n    int n_in_time = Ndarray_DIMS(s0)[0];\n    assert(n_in_time == Ndarray_DIMS(s1)[0]);\n    assert(n_in_time == Ndarray_DIMS(weight)[0]);\n    assert(n_in_time == Ndarray_DIMS(mask)[0]);\n    int n_batch = Ndarray_DIMS(s0)[1];\n    assert(n_batch == Ndarray_DIMS(s1)[1]);\n    assert(n_batch == Ndarray_DIMS(weight)[1]);\n    assert(n_batch == Ndarray_DIMS(mask)[1]);\n    assert(n_batch == Ndarray_DIMS(out_arg)[1]);\n    assert(n_batch == Ndarray_DIMS(out_max)[1]);\n    int n_out_time = Ndarray_DIMS(out_arg)[0];\n    assert(n_out_time == Ndarray_DIMS(out_max)[0]);\n    assert(out_max != out_arg);  // earlier bug in NativeOp\n\n    start_dev_kernel(doit_kernel, (\n      n_batch, n_in_time, n_out_time,\n      Ndarray_DEV_DATA(s0),\n      Ndarray_DEV_DATA(s1),\n      Ndarray_DEV_DATA(weight),\n      Ndarray_DEV_DATA(mask),\n      Ndarray_DEV_DATA(out_max),\n      Ndarray_DEV_DATA(out_arg)\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
code_version: Union[Tuple[int], int] = ()[source]#
class returnn.native_op.CrossEntropySoftmaxAndGradientZSparse[source]#

y_target is given in sparse COOrdinate format. We will calculate CE[t,b] = -sum_i y_target[t,b,i] * log(softmax(z[t,b])[i]) for every time frame t and batch b, and grad(CE[t,b], z[t,b]) = softmax(z[t,b]) - y_target[t,b]. We also support an index mask for z, i.e. a mask over the valid [t,b] positions.
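For reference, a dense NumPy sketch of the same computation (with y_target given densely instead of in sparse COO form, and without the index masks):

import numpy as np

def ce_softmax_and_grad_dense(z, y_target):
    # z, y_target: (n_time, n_batch, n_dim)
    z_max = z.max(axis=2, keepdims=True)  # subtract the max for numerical stability
    sm = np.exp(z - z_max)
    sm /= sm.sum(axis=2, keepdims=True)  # softmax(z[t, b])
    ce = -(y_target * np.log(np.maximum(sm, 1e-30))).sum(axis=2)  # (n_time, n_batch)
    grad_z = sm - y_target  # grad of CE w.r.t. z
    return ce, grad_z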

in_info: Tuple[Dict[str]] = ({'name': 'z', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'name': 'z_mask', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y_target_t', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y_target_i', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y_target_w', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y_target_mask', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'out_ce', 'ndim': 2, 'shape': ((0, 0), (0, 1))}, {'name': 'out_grad_z', 'ndim': 3, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': '_out_max_z', 'ndim': 2, 'shape': ((0, 0), (0, 1))})[source]#
c_extra_support_code: Dict[str, str] = {'ce_sm_grad_kernel': '\n    DEF_KERNEL\n    void ce_sm_grad_kernel(\n      float* out_ce, float* out_grad_z,\n      float* z, float* max_z, float* z_mask,\n      float* s0, float* s1, float* w, float* s_mask,\n      long n_time, long n_batch, long n_dim, long n_sparse_index)\n    {\n      long max_idx = n_batch * n_sparse_index;\n      for(\n        long idx = threadIdx.x + blockDim.x * blockIdx.x;\n        idx < max_idx;\n        idx += gridDim.x * blockDim.x)\n      {\n        if(s_mask[idx] < 0.1) continue;\n        long batch = idx % n_batch;\n        long t = (long) s0[idx];\n        long j = (long) s1[idx];\n        float y_target = w[idx];\n        if(t < 0 || t >= n_time) continue;  // error somehow?\n        if(j < 0 || j >= n_dim) continue;  // error somehow?\n        long out_ce_idx = t * n_batch + batch;\n        long out_y_idx = t * n_batch * n_dim + batch * n_dim + j;\n        // This assumes that out_grad_z is still softmax(z).\n        // This also assumes that every [t,j] is only represented once in the sparse data.\n        out_ce[out_ce_idx] -= y_target * log(fmax(out_grad_z[out_y_idx], 1e-30f));\n        out_grad_z[out_y_idx] -= y_target;\n      }\n    }\n    ', 'max_kernel': '\n    DEF_KERNEL\n    void max_kernel(float* out, float* v, float* mask, long stride, long max_idx) {\n      for(\n        long idx = threadIdx.x + blockDim.x * blockIdx.x;\n        idx < max_idx;\n        idx += gridDim.x * blockDim.x)\n      {\n        if(mask[idx] < 0.1)\n          continue;\n        long start = idx * stride;\n        float last_max = v[start];\n        out[idx] = last_max;\n        for(long i = 1; i < stride; ++i) {\n          float cur = v[start + i];\n          if(cur > last_max) {\n            last_max = cur;\n            out[idx] = cur;\n          }\n        }\n      }\n    }\n    ', 'softmax_kernel': '\n    DEF_KERNEL\n    void softmax_kernel(\n      float* out_softmax,\n      float* z, float* max_z, float* mask,\n      long stride, long max_idx)\n    {\n      for(\n        long idx = threadIdx.x + blockDim.x * blockIdx.x;\n        idx < max_idx;\n        idx += gridDim.x * blockDim.x)\n      {\n        long start = idx * stride;\n        float s = 0;\n        for(long i = 0; i < stride; ++i) {\n          s += exp(z[start + i] - max_z[idx]);\n        }\n        if(s < 1e-16) s = 1e-16;\n        for(long i = 0; i < stride; ++i) {\n          float y = exp(z[start + i] - max_z[idx]) / s;\n          out_softmax[start + i] = (mask[idx] > 0.5) ? y : 0;\n        }\n      }\n    }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 6);\n    assert(n_outputs == 3);\n    Ndarray* z = inputs[0];\n    Ndarray* z_mask = inputs[1];\n    Ndarray* s0 = inputs[2];\n    Ndarray* s1 = inputs[3];\n    Ndarray* w = inputs[4];\n    Ndarray* s_mask = inputs[5];\n    Ndarray* out_ce = *outputs[0];\n    Ndarray* out_grad_z = *outputs[1];\n    Ndarray* out_max_z = *outputs[2];\n\n    assert(Ndarray_NDIM(z) == 3);\n    assert(Ndarray_NDIM(z_mask) == 2);\n    assert(Ndarray_NDIM(out_ce) == 2);\n    assert(Ndarray_NDIM(out_grad_z) == 3);\n    assert(Ndarray_NDIM(out_max_z) == 2);\n    assert(Ndarray_NDIM(s0) == 2);\n    assert(Ndarray_NDIM(s1) == 2);\n    assert(Ndarray_NDIM(w) == 2);\n    assert(Ndarray_NDIM(out_ce) == 2);\n    int n_time = Ndarray_DIMS(z)[0];\n    int n_batch = Ndarray_DIMS(z)[1];\n    int n_dim = Ndarray_DIMS(z)[2];\n    assert(n_time == Ndarray_DIMS(z_mask)[0]);\n    assert(n_time == Ndarray_DIMS(out_ce)[0]);\n    assert(n_time == Ndarray_DIMS(out_grad_z)[0]);\n    assert(n_time == Ndarray_DIMS(out_max_z)[0]);\n    assert(n_batch == Ndarray_DIMS(z_mask)[1]);\n    assert(n_batch == Ndarray_DIMS(out_ce)[1]);\n    assert(n_batch == Ndarray_DIMS(out_grad_z)[1]);\n    assert(n_batch == Ndarray_DIMS(out_max_z)[1]);\n    assert(n_batch == Ndarray_DIMS(s0)[1]);\n    assert(n_batch == Ndarray_DIMS(s1)[1]);\n    assert(n_batch == Ndarray_DIMS(w)[1]);\n    assert(n_batch == Ndarray_DIMS(s_mask)[1]);\n    assert(n_dim == Ndarray_DIMS(out_grad_z)[2]);\n    int n_sparse_index = Ndarray_DIMS(s0)[0];\n    assert(n_sparse_index == Ndarray_DIMS(s1)[0]);\n    assert(n_sparse_index == Ndarray_DIMS(w)[0]);\n    assert(n_sparse_index == Ndarray_DIMS(s_mask)[0]);\n\n    start_dev_kernel(max_kernel, (\n      Ndarray_DEV_DATA(out_max_z), Ndarray_DEV_DATA(z), Ndarray_DEV_DATA(z_mask),\n      n_dim, n_time * n_batch\n    ));\n    HANDLE_LAST_ERROR();\n    Ndarray_set_zero(out_ce);\n    start_dev_kernel(softmax_kernel, (\n      Ndarray_DEV_DATA(out_grad_z),\n      Ndarray_DEV_DATA(z), Ndarray_DEV_DATA(out_max_z), Ndarray_DEV_DATA(z_mask),\n      n_dim, n_time * n_batch\n    ));\n    HANDLE_LAST_ERROR();\n    start_dev_kernel(ce_sm_grad_kernel, (\n      Ndarray_DEV_DATA(out_ce), Ndarray_DEV_DATA(out_grad_z),\n      Ndarray_DEV_DATA(z), Ndarray_DEV_DATA(out_max_z), Ndarray_DEV_DATA(z_mask),\n      Ndarray_DEV_DATA(s0), Ndarray_DEV_DATA(s1), Ndarray_DEV_DATA(w), Ndarray_DEV_DATA(s_mask),\n      n_time, n_batch, n_dim, n_sparse_index\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
class returnn.native_op.FastBaumWelchOp[source]#
inputs:
param am_scores:

scores in -log space. 3d (time,batch,dim)

param edges:

edges of the graph (from,to,emission_idx,sequence_idx)

param weights:

weights of the edges

outputs:
param output:

Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores
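All scores here live in -log space, so summing two probabilities corresponds to the prob_add device helper defined in c_extra_support_code below. A NumPy equivalent, for reference:

import numpy as np

def neg_log_add(a, b):
    # a, b are -log probabilities; returns -log(exp(-a) + exp(-b)), computed stably.
    return -np.logaddexp(-a, -b)

assert np.isclose(np.exp(-neg_log_add(1.0, 2.0)), np.exp(-1.0) + np.exp(-2.0))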

in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 'am_scores', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'start_end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (2, None)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'gradient': 'disconnected', 'name': 'state_buffer', 'ndim': 2, 'need_contiguous': True, 'shape': (2, None)})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': 'sums', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))})[source]#
c_extra_support_code: Dict[str, str] = {'001_set_start_states': '\n    DEF_KERNEL\n    void set_start_states(float* states, unsigned* start_states) {\n      unsigned state_idx = start_states[blockIdx.x * blockDim.x + threadIdx.x];\n      states[state_idx] = 0.0;\n    }\n  ', '010_fill_array': '\n    DEF_KERNEL\n    void fill_array(float* array, float value, unsigned size) {\n      unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n      if (idx < size) {\n        array[idx] = value;\n      }\n    }\n  ', '011_remove_inf': '\n  DEF_KERNEL\n  void remove_inf(float* array, unsigned size) {\n    unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < size) {\n      array[idx] = fminf(array[idx], 1e32);\n    }\n  }\n  ', '012_prob_add': '\n    DEV_FUNC\n    float prob_add(float a, float b) {\n      float diff = a - b;\n      if (isnan(diff)) {\n        return INF_F;\n      }\n      else {\n        return -log1pf(expf(-fabsf(diff))) + fminf(a, b);\n      }\n    }\n  ', '013_atomic_prob_add': '\n    DEV_FUNC\n    void atomic_prob_add(float* a, float b) {\n      int* addr = (int*)a;\n      int old   = float_as_int(*a);\n      int assumed;\n      do {\n        assumed = old;\n        old     = elem_atomic_cas(addr, assumed, float_as_int(prob_add(int_as_float(old), b)));\n      } while (old != assumed);\n    }\n  ', '020_dump_to_file': "\n    template<typename T>\n    void dump_to_file_1d(T* d_mem, unsigned n_d1, std::string const& path) {\n      std::vector<T> buffer(n_d1);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        T val = buffer[i1];\n        if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n          output << i1 << ' ' << val << '\\n';\n        }\n      }\n    }\n\n    template<typename T>\n    void dump_to_file_2d(T* d_mem, unsigned n_d1, unsigned n_d2, std::string const& path) {\n      std::vector<T> buffer(n_d1 * n_d2);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n          T val = buffer[i1 * n_d2 + i2];\n          if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n            output << i1 << ' ' << i2 << ' ' << val << '\\n';\n          }\n        }\n      }\n    }\n\n    template<typename T>\n    void dump_to_file_3d(T* d_mem, unsigned n_d1, unsigned n_d2, unsigned n_d3, std::string const& path) {\n      std::vector<T> buffer(n_d1 * n_d2 * n_d3);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n          for (size_t i3 = 0ul; i3 < n_d3; i3++) {\n            T val = buffer[i1 * n_d2 * n_d3 + i2 * n_d3 + i3];\n            if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n              output << i1 << ' ' << i2 << ' ' << i3 << ' ' << val << '\\n';\n            }\n          }\n        }\n      }\n    }\n  ", '100_init_bwd_state_buffer': '\n      DEF_KERNEL\n      void init_bwd_state_buffer(\n          float* states, unsigned* end_states, unsigned t, unsigned max_t, float* index, unsigned index_stride) {\n        unsigned 
idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (index[t * index_stride + idx] == 1.0 && (t == max_t || index[(t + 1) * index_stride + idx] == 0.0)) {\n          unsigned state_idx = end_states[idx];\n          states[state_idx] = 0.0;\n        }\n      }\n    ', '101_next_frame': '\n      DEF_KERNEL\n      void next_frame(bool fwd, unsigned num_edges, unsigned  num_emissions,\n                      unsigned* sequence_idxs, unsigned* from_buffer, unsigned* to_buffer, float* weight_buffer,\n                      unsigned* emission_idxs,\n                      float* prev_frame, float* next_frame, float* am_scores, float* edge_buffer) {\n        unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_edges) {\n          return;\n        }\n\n        unsigned from     = from_buffer  [idx];\n        float    prev_val = prev_frame[from];\n        if (isinf(prev_val)) {\n          edge_buffer[idx] = INF_F;\n          return;\n        }\n\n        unsigned to           = to_buffer    [idx];\n        unsigned emission_idx = emission_idxs[idx];\n        float    edge_weight  = weight_buffer[idx];\n        unsigned sequence_idx = sequence_idxs[idx];\n\n        float val = prev_val + edge_weight + am_scores[sequence_idx * num_emissions + emission_idx];\n\n        if (fwd) {\n          edge_buffer[idx] += val;\n        }\n        else {\n          edge_buffer[idx] += prev_val;\n        }\n        atomic_prob_add(next_frame + to, val);\n      }\n    ', '102_normalize': '\n      DEF_KERNEL\n      void normalize(float* buffer, unsigned* sequence_idxs, unsigned num_edges, unsigned num_seqs, float* sum_output) {\n        DEF_SHARED(float, sum);\n\n        buffer += blockIdx.x * num_edges;\n\n        for (unsigned s = 0u; s < num_seqs; s++) {\n          sum[s] = INF_F;\n        }\n\n        for (unsigned e = 0u; e < num_edges; e++) {\n          unsigned s = sequence_idxs[e];\n          sum[s] = prob_add(sum[s], buffer[e]);\n        }\n\n        for (unsigned s = 0ul; s < num_seqs; s++) {\n          if (isinf(sum[s])) {\n            // if the frame is empty (happens due to batching of seqs with unequal length), set it to 0\n            sum_output[blockIdx.x * num_seqs + s] = 0.0;\n          }\n          else {\n            sum_output[blockIdx.x * num_seqs + s] = sum[s];\n          }\n        }\n\n        for (unsigned e = 0u; e < num_edges; e++) {\n          unsigned s = sequence_idxs[e];\n          buffer[e] -= sum[s];\n        }\n      }\n    ', '103_compute_result': '\n      DEF_KERNEL\n      void compute_result(float* edge_buffer, float* out, unsigned* emission_idxs, unsigned* sequence_idxs,\n                          unsigned frame_stride, unsigned seq_stride,\n                          unsigned num_frames, unsigned num_seqs, unsigned num_edges) {\n        unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_frames * num_edges) {\n          return;\n        }\n\n        unsigned e_idx        = idx % num_edges;\n        unsigned frame        = idx / num_edges;\n        unsigned emission_idx = emission_idxs[e_idx];\n        unsigned seq_idx      = sequence_idxs[e_idx];\n        float    score        = edge_buffer[idx];\n\n        atomic_prob_add(out + frame * frame_stride + seq_idx * seq_stride + emission_idx, score);\n      }\n    ', '110_write_alignment_to_file': '\n      void write_alignment_to_file(float* d_state_buffer, float* d_index, unsigned index_stride,\n                                   unsigned* d_start_states, unsigned* d_end_states,\n  
                                 float pruning, unsigned n_frames, unsigned n_seqs, unsigned n_states,\n                                   unsigned batch_idx) {\n        std::vector<float>    state_buffer((n_frames + 1u) * n_states);\n        std::vector<float>    index       (n_frames * index_stride);\n        std::vector<unsigned> start_states(n_seqs);\n        std::vector<unsigned> end_states  (n_seqs);\n\n        //HANDLE_ERROR(cudaMemcpy(\n        //  state_buffer.data(), d_state_buffer, state_buffer.size() * sizeof(float), cudaMemcpyDeviceToHost));\n        //HANDLE_ERROR(cudaMemcpy(\n        //  index.data(),        d_index,        index.size()        * sizeof(float), cudaMemcpyDeviceToHost));\n        //HANDLE_ERROR(cudaMemcpy(\n        //  start_states.data(), d_start_states, start_states.size() * sizeof(float), cudaMemcpyDeviceToHost));\n        //HANDLE_ERROR(cudaMemcpy(\n        //  end_states.data(),   d_end_states,   end_states.size()   * sizeof(float), cudaMemcpyDeviceToHost));\n\n        for (unsigned seq = 0u; seq < n_seqs; seq++) {\n          std::stringstream filename;\n          filename << "alignment.dump." << batch_idx << \'.\' << seq;\n          std::ofstream out(filename.str().c_str(), std::ios::out | std::ios::trunc);\n          for (unsigned t = 0u; t < n_frames; t++) {\n            if (t > 0u && index[seq * index_stride + t] <= 0.0) {\n              break;\n            }\n            float sum = std::numeric_limits<float>::infinity();\n            for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {\n              const float val = state_buffer[t * n_states + s];\n              float diff = val - sum;\n              if (!isnan(diff)) {\n                sum = -log1p(exp(-abs(diff))) + fminf(sum, val);\n              }\n            }\n            for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {\n              const float val = state_buffer[t * n_states + s] - sum;\n              if (val <= pruning) {\n                out << t << \' \' << (s - start_states[seq]) << \' \' << val << \'\\n\';\n              }\n            }\n          }\n        }\n      }\n    ', '111_write_output_to_file': '\n      void write_output_to_file(float* d_out, float* d_index, unsigned index_stride,\n                                float pruning, unsigned n_frames, unsigned n_seqs, unsigned n_emissions,\n                                unsigned batch_idx) {\n        std::vector<float> buffer(n_frames * n_seqs * n_emissions);\n        std::vector<float> index (n_frames * index_stride);\n\n        //HANDLE_ERROR(cudaMemcpy(buffer.data(), d_out,   buffer.size() * sizeof(float), cudaMemcpyDeviceToHost));\n        //HANDLE_ERROR(cudaMemcpy(index.data(),  d_index, index.size()  * sizeof(float), cudaMemcpyDeviceToHost));\n\n        for (unsigned seq = 0u; seq < n_seqs; seq++) {\n          std::stringstream filename;\n          filename << "target.dump." << batch_idx << \'.\' << seq;\n          std::ofstream out(filename.str().c_str(), std::ios::out | std::ios::trunc);\n          for (unsigned t = 0u; t < n_frames; t++) {\n            if (t > 0u && index[seq * index_stride + t] <= 0.0) {\n              break;\n            }\n            for (unsigned e = 0u; e < n_emissions; e++) {\n              const float val = buffer[t * n_seqs * n_emissions + seq * n_emissions + e];\n              if (val <= pruning) {\n                out << t << \' \' << e << \' \' << val << \'\\n\';\n              }\n            }\n          }\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    // am_scores, edges, weights, start_end_states, index, state_buffer* = input_names (*: inplace)\n    // output = output_names\n    assert(n_inputs  == 6);\n    assert(n_outputs == 2);\n    Ndarray* am_scores        = inputs[0];\n    Ndarray* edges            = inputs[1];\n    Ndarray* weights          = inputs[2];\n    Ndarray* start_end_states = inputs[3];\n    Ndarray* index            = inputs[4];\n    Ndarray* state_buffer     = inputs[5];\n    Ndarray* out              = *outputs[0];\n    Ndarray* sum_output       = *outputs[1];\n\n    /*\n    debug_print(context, am_scores, "am_scores");\n    debug_print(context, edges, "edges");\n    debug_print(context, weights, "weights");\n    debug_print(context, start_end_states, "start_end_states");\n    debug_print(context, index, "index");\n    debug_print(context, state_buffer, "state_buffer");\n    */\n\n    assert_cmp(Ndarray_DIMS(am_scores)[0], ==, Ndarray_DIMS(out)[0]);\n    assert_cmp(Ndarray_DIMS(am_scores)[1], ==, Ndarray_DIMS(out)[1]);\n    assert_cmp(Ndarray_DIMS(am_scores)[2], ==, Ndarray_DIMS(out)[2]);\n    assert_cmp(Ndarray_DIMS(am_scores)[1], ==, Ndarray_DIMS(start_end_states)[1]);\n\n    assert_cmp(Ndarray_DIMS(sum_output)[0], ==, Ndarray_DIMS(am_scores)[0]);\n    assert_cmp(Ndarray_DIMS(sum_output)[1], ==, Ndarray_DIMS(am_scores)[1]);\n\n    bool            dump_alignment = false;\n    bool            dump_output    = false;\n    unsigned        dump_every = 40u;\n    static unsigned batch_idx  = 0u;\n    float           pruning    = 10.f;\n\n    unsigned* d_from = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n      + 0 * Ndarray_STRIDE(edges, 0));\n    unsigned* d_to = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n      + 1 * Ndarray_STRIDE(edges, 0));\n    unsigned* d_emission_idxs = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n      + 2 * Ndarray_STRIDE(edges, 0));\n    unsigned* d_sequence_idxs = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n      + 3 * Ndarray_STRIDE(edges, 0));\n    float*    d_weights = Ndarray_DEV_DATA(weights);\n    float*    d_am_scores = Ndarray_DEV_DATA(am_scores);\n    unsigned* d_start_states = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(start_end_states)\n      + 0 * Ndarray_STRIDE(start_end_states, 0));\n    unsigned* d_end_states = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(start_end_states)\n      + 1 * Ndarray_STRIDE(start_end_states, 0));\n    float*    d_index             = Ndarray_DEV_DATA(index);\n    float*    d_state_buffer_prev = Ndarray_DEV_DATA(state_buffer) + 0 * Ndarray_STRIDE(state_buffer, 0);\n    float*    d_state_buffer_next = Ndarray_DEV_DATA(state_buffer) + 1 * Ndarray_STRIDE(state_buffer, 0);\n    float*    d_out               = Ndarray_DEV_DATA(out);\n    float*    d_sum_output        = Ndarray_DEV_DATA(sum_output);\n\n    unsigned n_frames    = Ndarray_DIMS(am_scores)[0];\n    unsigned n_seqs      = Ndarray_DIMS(am_scores)[1];\n    unsigned n_emissions = Ndarray_DIMS(am_scores)[2];\n    unsigned n_states    = Ndarray_DIMS(state_buffer)[1];\n    unsigned n_edges     = Ndarray_DIMS(edges)[1];\n    unsigned n_threads   = 1024u;\n    unsigned n_blocks    = (n_edges + n_threads - 1) / n_threads;\n\n    unsigned frame_stride    = Ndarray_STRIDE(am_scores, 0);\n    unsigned sequence_stride = Ndarray_STRIDE(am_scores, 1);\n    unsigned index_stride    = Ndarray_STRIDE(index, 0);\n\n    assert_cmp(n_frames, >, 0);\n    assert_cmp(n_states, >, 0);\n    //std::cerr << "n_frames: "    << 
n_frames    << std::endl;\n    //std::cerr << "n_seqs: "      << n_seqs      << std::endl;\n    //std::cerr << "n_emissions: " << n_emissions << std::endl;\n    //std::cerr << "n_states: "    << n_states    << std::endl;\n    //std::cerr << "n_edges: "     << n_edges     << std::endl;\n    //std::cerr << "n_threads: "   << n_threads   << std::endl;\n    //std::cerr << "n_blocks: "    << n_blocks    << std::endl;\n\n    //std::cerr << "frame_stride: "     << frame_stride    << std::endl;\n    //std::cerr << "sequnence_stride: " << sequence_stride << std::endl;\n    //std::cerr << "index_stride: "     << index_stride    << std::endl;\n\n    // initialize edge buffer\n    float* d_edge_buffer = reinterpret_cast<float*>(device_malloc(n_edges * n_frames * sizeof(float)));\n    if(!d_edge_buffer) { HANDLE_LAST_ERROR(); abort(); }  // error should have been set in device_malloc\n    unsigned n_fill_blocks = (n_edges * n_frames + n_threads - 1u) / n_threads;\n    start_dev_kernel2(fill_array, n_fill_blocks, n_threads, 0, (d_edge_buffer, 0.0, n_edges * n_frames));\n    HANDLE_LAST_ERROR();\n\n    // initialize the state buffer\n    n_fill_blocks = (n_states + n_threads - 1u) / n_threads;\n    start_dev_kernel2(\n      fill_array, n_fill_blocks, n_threads, 0,\n      (d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states));\n    HANDLE_LAST_ERROR();\n    start_dev_kernel2(set_start_states, 1, n_seqs, 0, (d_state_buffer_prev, d_start_states));\n    HANDLE_LAST_ERROR();\n\n    // initialize full state buffer (only used to dump the alignment)\n    float* d_state_buffer_all = NULL;\n    if (dump_alignment && batch_idx %% dump_every == 0) {\n      d_state_buffer_all = reinterpret_cast<float*>(device_malloc(n_states * (n_frames + 1u) * sizeof(float)));\n      if(!d_state_buffer_all) { HANDLE_LAST_ERROR(); abort(); }  // error should have been set in device_malloc\n      Ndarray_memcpy(d_state_buffer_all, d_state_buffer_prev, n_states * sizeof(float));\n      HANDLE_LAST_ERROR();\n    }\n\n    // fwd pass\n    for (unsigned t = 0u; t < n_frames; t++) {\n      start_dev_kernel2(\n        fill_array, n_fill_blocks, n_threads, 0,\n        (d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states));\n      HANDLE_LAST_ERROR();\n      start_dev_kernel2(next_frame, n_blocks, n_threads, 0,\n        (true, n_edges, sequence_stride,\n         d_sequence_idxs, d_from, d_to, d_weights, d_emission_idxs,\n         d_state_buffer_prev, d_state_buffer_next, d_am_scores + t * frame_stride, d_edge_buffer + t * n_edges));\n      HANDLE_LAST_ERROR();\n      if (dump_alignment && batch_idx %% dump_every == 0) {\n        Ndarray_memcpy(d_state_buffer_all + (t + 1u) * n_states, d_state_buffer_next, n_states * sizeof(float));\n        HANDLE_LAST_ERROR();\n      }\n      std::swap(d_state_buffer_prev, d_state_buffer_next);\n    }\n\n    // bwd pass\n    start_dev_kernel2(\n      fill_array, n_fill_blocks, n_threads, 0,\n      (d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states));\n    HANDLE_LAST_ERROR();\n    for (unsigned t = n_frames; t > 0; t--) {\n      start_dev_kernel2(init_bwd_state_buffer, 1, n_seqs, 0,\n        (d_state_buffer_prev, d_end_states, t - 1, n_frames - 1, d_index, index_stride));\n      HANDLE_LAST_ERROR();\n      if (dump_alignment && batch_idx %% dump_every == 0) {\n        float alpha = 1.0f;\n        //HANDLE_ERROR(cublasSaxpy(\n        //  handle, n_states, &alpha, d_state_buffer_prev, 1, d_state_buffer_all + t * n_states, 1));\n      }\n      
start_dev_kernel2(\n        fill_array, n_fill_blocks, n_threads, 0,\n        (d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states));\n      HANDLE_LAST_ERROR();\n      start_dev_kernel2(next_frame, n_blocks, n_threads, 0,\n        (false, n_edges, sequence_stride,\n         d_sequence_idxs, d_to, d_from, d_weights, d_emission_idxs,\n         d_state_buffer_prev, d_state_buffer_next, d_am_scores + (t - 1) * frame_stride,\n         d_edge_buffer + (t - 1) * n_edges));\n      HANDLE_LAST_ERROR();\n      std::swap(d_state_buffer_prev, d_state_buffer_next);\n    }\n    if (dump_alignment && batch_idx %% dump_every == 0) {\n      float alpha = 1.0f;\n      //HANDLE_ERROR(cublasSaxpy(handle, n_states, &alpha, d_state_buffer_prev, 1, d_state_buffer_all, 1));\n    }\n\n    // normalize at each time frame\n    start_dev_kernel2(normalize, n_frames, 1, n_seqs * sizeof(float),\n      (d_edge_buffer, d_sequence_idxs, n_edges, n_seqs, d_sum_output));\n    HANDLE_LAST_ERROR();\n\n    // dump alignment\n    if (dump_alignment && batch_idx %% dump_every == 0) {\n      write_alignment_to_file(d_state_buffer_all, d_index, index_stride, d_start_states, d_end_states,\n                              pruning, n_frames, n_seqs, n_states, batch_idx);\n    }\n\n    n_fill_blocks = (n_frames * n_seqs * n_emissions + n_threads - 1u) / n_threads;\n    start_dev_kernel2(\n      fill_array, n_fill_blocks, n_threads, 0,\n      (d_out, std::numeric_limits<float>::infinity(), n_frames * n_seqs * n_emissions));\n    HANDLE_LAST_ERROR();\n\n    frame_stride    = Ndarray_STRIDE(out, 0);\n    sequence_stride = Ndarray_STRIDE(out, 1);\n    n_blocks        = (n_frames * n_edges + n_threads - 1u) / n_threads;\n    start_dev_kernel2(compute_result, n_blocks, n_threads, 0,\n      (d_edge_buffer, d_out, d_emission_idxs, d_sequence_idxs,\n       frame_stride, sequence_stride, n_frames, n_seqs, n_edges));\n    HANDLE_LAST_ERROR();\n\n    #if TENSORFLOW\n    // Certain TensorFlow code doesn\'t like inf, even if it is just the CheckNumerics,\n    // which is helpful for debugging.\n    // We replace it by a very high number, so that tf.exp(-out) will still result in 0.0.\n    n_blocks = (n_frames * n_seqs * n_emissions + n_threads - 1u) / n_threads;\n    start_dev_kernel2(remove_inf, n_blocks, n_threads, 0, (d_out, n_frames * n_seqs * n_emissions));\n    //debug_print(context, out, "out");\n    #endif\n    if (dump_output && batch_idx %% dump_every == 0) {\n      write_output_to_file(d_out, d_index, index_stride, pruning, n_frames, n_seqs, n_emissions, batch_idx);\n    }\n\n    device_free(d_edge_buffer);\n    if (d_state_buffer_all != NULL) {\n      device_free(d_state_buffer_all);\n    }\n    batch_idx++;\n  '[source]#
c_bw_code: str = None[source]#
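For orientation, the forward half of the c_fw_code above can be read as a plain loop over frames and edges: propagate state scores in -log space with prob_add and record each edge's forward score. A rough CPU sketch in NumPy (illustrative only; the helper name fbw_forward is made up, the backward pass, per-frame normalization and all device handling are omitted, and -np.logaddexp(-a, -b) stands in for the prob_add kernel):

    import numpy as np

    def fbw_forward(am_scores, edges, weights, start_states):
        # am_scores: (T, B, D) in -log space; edges: (4, E) with rows
        # (from, to, emission_idx, sequence_idx); weights: (E,) in -log space.
        T = am_scores.shape[0]
        E = edges.shape[1]
        n_states = int(edges[:2].max()) + 1
        prev = np.full(n_states, np.inf, dtype="float32")
        prev[start_states] = 0.0                         # cf. set_start_states
        edge_scores = np.full((T, E), np.inf, dtype="float32")
        for t in range(T):
            nxt = np.full(n_states, np.inf, dtype="float32")
            for e in range(E):
                frm, to, em, seq = edges[:, e]
                if np.isinf(prev[frm]):                  # source state unreachable
                    continue
                val = prev[frm] + weights[e] + am_scores[t, seq, em]
                edge_scores[t, e] = val                  # forward score of this edge
                nxt[to] = -np.logaddexp(-nxt[to], -val)  # prob_add / atomic_prob_add
            prev = nxt
        return edge_scores

In the actual op the same edge buffer is then filled a second time by the backward pass and normalized per frame and sequence, which yields the Baum-Welch soft alignment written out by compute_result.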
class returnn.native_op.MultiEndFastBaumWelchOp[source]#
inputs:
param am_scores:

scores in -log space. 3d (time,batch,dim)

param edges:

edges of the graph (from,to,emission_idx,sequence_idx)

param weights:

weights of the edges

outputs:
param output:

Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores
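
For orientation, a toy NumPy layout of the inputs, with shapes following the in_info below and the column meanings taken from the c_fw_code and the init_bwd_state_buffer kernel further down; the concrete two-state automaton is made up purely for illustration:

    import numpy as np

    # Made-up sizes: 2 frames, 1 sequence, 2 emissions, 2 states, 2 edges.
    n_time, n_batch, n_emissions = 2, 1, 2
    n_states, n_edges = 2, 2

    # Acoustic scores in -log space, e.g. negative log-softmax outputs.
    am_scores = -np.log(np.full((n_time, n_batch, n_emissions), 0.5, dtype="float32"))

    # One column per edge; rows are (from, to, emission_idx, sequence_idx).
    edges = np.array([[0, 0],     # from state
                      [0, 1],     # to state
                      [0, 1],     # emission index
                      [0, 0]],    # sequence index
                     dtype="int32")
    weights = np.zeros(n_edges, dtype="float32")        # edge weights, -log space

    start_states = np.array([0], dtype="int32")         # one start state per sequence
    # Rows of end_states are (sequence_idx, state_idx); a sequence may have
    # several end states, each with its own -log weight in end_state_weights.
    end_states = np.array([[0, 1]], dtype="int32")
    end_state_weights = np.zeros(1, dtype="float32")

    index = np.ones((n_time, n_batch), dtype="float32")      # 1.0 marks valid frames
    state_buffer = np.zeros((2, n_states), dtype="float32")  # scratch, used inplace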

in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 'am_scores', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'start_states', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (None, 2)}, {'gradient': 'disconnected', 'name': 'end_state_weights', 'ndim': 1, 'need_contiguous': True, 'shape': ((4, 0),)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'gradient': 'disconnected', 'name': 'state_buffer', 'ndim': 2, 'need_contiguous': True, 'shape': (2, None)})[source]#
out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': 'sums', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))})[source]#
c_extra_support_code: Dict[str, str] = {'001_set_start_states': '\n    DEF_KERNEL\n    void set_start_states(float* states, unsigned* start_states) {\n      unsigned state_idx = start_states[blockIdx.x * blockDim.x + threadIdx.x];\n      states[state_idx] = 0.0;\n    }\n  ', '010_fill_array': '\n    DEF_KERNEL\n    void fill_array(float* array, float value, unsigned size) {\n      unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n      if (idx < size) {\n        array[idx] = value;\n      }\n    }\n  ', '011_remove_inf': '\n  DEF_KERNEL\n  void remove_inf(float* array, unsigned size) {\n    unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < size) {\n      array[idx] = fminf(array[idx], 1e32);\n    }\n  }\n  ', '012_prob_add': '\n    DEV_FUNC\n    float prob_add(float a, float b) {\n      float diff = a - b;\n      if (isnan(diff)) {\n        return INF_F;\n      }\n      else {\n        return -log1pf(expf(-fabsf(diff))) + fminf(a, b);\n      }\n    }\n  ', '013_atomic_prob_add': '\n    DEV_FUNC\n    void atomic_prob_add(float* a, float b) {\n      int* addr = (int*)a;\n      int old   = float_as_int(*a);\n      int assumed;\n      do {\n        assumed = old;\n        old     = elem_atomic_cas(addr, assumed, float_as_int(prob_add(int_as_float(old), b)));\n      } while (old != assumed);\n    }\n  ', '020_dump_to_file': "\n    template<typename T>\n    void dump_to_file_1d(T* d_mem, unsigned n_d1, std::string const& path) {\n      std::vector<T> buffer(n_d1);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        T val = buffer[i1];\n        if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n          output << i1 << ' ' << val << '\\n';\n        }\n      }\n    }\n\n    template<typename T>\n    void dump_to_file_2d(T* d_mem, unsigned n_d1, unsigned n_d2, std::string const& path) {\n      std::vector<T> buffer(n_d1 * n_d2);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n          T val = buffer[i1 * n_d2 + i2];\n          if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n            output << i1 << ' ' << i2 << ' ' << val << '\\n';\n          }\n        }\n      }\n    }\n\n    template<typename T>\n    void dump_to_file_3d(T* d_mem, unsigned n_d1, unsigned n_d2, unsigned n_d3, std::string const& path) {\n      std::vector<T> buffer(n_d1 * n_d2 * n_d3);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n          for (size_t i3 = 0ul; i3 < n_d3; i3++) {\n            T val = buffer[i1 * n_d2 * n_d3 + i2 * n_d3 + i3];\n            if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n              output << i1 << ' ' << i2 << ' ' << i3 << ' ' << val << '\\n';\n            }\n          }\n        }\n      }\n    }\n  ", '100_init_bwd_state_buffer': '\n      __global__\n      void init_bwd_state_buffer(unsigned t, unsigned max_t, unsigned num_endstates, unsigned index_stride,\n                                 float* states, unsigned 
const* end_states, float const* end_state_weights,\n                                 float const* index) {\n        unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_endstates) {\n          return;\n        }\n\n        unsigned seq_idx = end_states[idx * 2u + 0u];\n        if (index[t * index_stride + seq_idx] == 1.0\n            && (t == max_t || index[(t + 1) * index_stride + seq_idx] == 0.0)) {\n          unsigned state_idx = end_states[idx * 2u + 1u];\n          float    weight    = end_state_weights[idx];\n          states[state_idx] = weight;\n        }\n      }\n    ', '101_next_frame': '\n      DEF_KERNEL\n      void next_frame(bool fwd, unsigned num_edges, unsigned  num_emissions,\n                      unsigned* sequence_idxs, unsigned* from_buffer, unsigned* to_buffer, float* weight_buffer,\n                      unsigned* emission_idxs,\n                      float* prev_frame, float* next_frame, float* am_scores, float* edge_buffer) {\n        unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_edges) {\n          return;\n        }\n\n        unsigned from     = from_buffer  [idx];\n        float    prev_val = prev_frame[from];\n        if (isinf(prev_val)) {\n          edge_buffer[idx] = INF_F;\n          return;\n        }\n\n        unsigned to           = to_buffer    [idx];\n        unsigned emission_idx = emission_idxs[idx];\n        float    edge_weight  = weight_buffer[idx];\n        unsigned sequence_idx = sequence_idxs[idx];\n\n        float val = prev_val + edge_weight + am_scores[sequence_idx * num_emissions + emission_idx];\n\n        if (fwd) {\n          edge_buffer[idx] += val;\n        }\n        else {\n          edge_buffer[idx] += prev_val;\n        }\n        atomic_prob_add(next_frame + to, val);\n      }\n    ', '102_normalize': '\n      DEF_KERNEL\n      void normalize(float* buffer, unsigned* sequence_idxs, unsigned num_edges, unsigned num_seqs, float* sum_output) {\n        DEF_SHARED(float, sum);\n\n        buffer += blockIdx.x * num_edges;\n\n        for (unsigned s = 0u; s < num_seqs; s++) {\n          sum[s] = INF_F;\n        }\n\n        for (unsigned e = 0u; e < num_edges; e++) {\n          unsigned s = sequence_idxs[e];\n          sum[s] = prob_add(sum[s], buffer[e]);\n        }\n\n        for (unsigned s = 0ul; s < num_seqs; s++) {\n          if (isinf(sum[s])) {\n            // if the frame is empty (happens due to batching of seqs with unequal length), set it to 0\n            sum_output[blockIdx.x * num_seqs + s] = 0.0;\n          }\n          else {\n            sum_output[blockIdx.x * num_seqs + s] = sum[s];\n          }\n        }\n\n        for (unsigned e = 0u; e < num_edges; e++) {\n          unsigned s = sequence_idxs[e];\n          buffer[e] -= sum[s];\n        }\n      }\n    ', '103_compute_result': '\n      DEF_KERNEL\n      void compute_result(float* edge_buffer, float* out, unsigned* emission_idxs, unsigned* sequence_idxs,\n                          unsigned frame_stride, unsigned seq_stride,\n                          unsigned num_frames, unsigned num_seqs, unsigned num_edges) {\n        unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_frames * num_edges) {\n          return;\n        }\n\n        unsigned e_idx        = idx % num_edges;\n        unsigned frame        = idx / num_edges;\n        unsigned emission_idx = emission_idxs[e_idx];\n        unsigned seq_idx      = sequence_idxs[e_idx];\n        float    score        = 
edge_buffer[idx];\n\n        atomic_prob_add(out + frame * frame_stride + seq_idx * seq_stride + emission_idx, score);\n      }\n    ', '110_write_alignment_to_file': '\n      void write_alignment_to_file(float* d_state_buffer, float* d_index, unsigned index_stride,\n                                   unsigned* d_start_states, unsigned* d_end_states,\n                                   float pruning, unsigned n_frames, unsigned n_seqs, unsigned n_states,\n                                   unsigned batch_idx) {\n        std::vector<float>    state_buffer((n_frames + 1u) * n_states);\n        std::vector<float>    index       (n_frames * index_stride);\n        std::vector<unsigned> start_states(n_seqs);\n        std::vector<unsigned> end_states  (n_seqs);\n\n        //HANDLE_ERROR(cudaMemcpy(\n        //  state_buffer.data(), d_state_buffer, state_buffer.size() * sizeof(float), cudaMemcpyDeviceToHost));\n        //HANDLE_ERROR(cudaMemcpy(\n        //  index.data(),        d_index,        index.size()        * sizeof(float), cudaMemcpyDeviceToHost));\n        //HANDLE_ERROR(cudaMemcpy(\n        //  start_states.data(), d_start_states, start_states.size() * sizeof(float), cudaMemcpyDeviceToHost));\n        //HANDLE_ERROR(cudaMemcpy(\n        //  end_states.data(),   d_end_states,   end_states.size()   * sizeof(float), cudaMemcpyDeviceToHost));\n\n        for (unsigned seq = 0u; seq < n_seqs; seq++) {\n          std::stringstream filename;\n          filename << "alignment.dump." << batch_idx << \'.\' << seq;\n          std::ofstream out(filename.str().c_str(), std::ios::out | std::ios::trunc);\n          for (unsigned t = 0u; t < n_frames; t++) {\n            if (t > 0u && index[seq * index_stride + t] <= 0.0) {\n              break;\n            }\n            float sum = std::numeric_limits<float>::infinity();\n            for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {\n              const float val = state_buffer[t * n_states + s];\n              float diff = val - sum;\n              if (!isnan(diff)) {\n                sum = -log1p(exp(-abs(diff))) + fminf(sum, val);\n              }\n            }\n            for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {\n              const float val = state_buffer[t * n_states + s] - sum;\n              if (val <= pruning) {\n                out << t << \' \' << (s - start_states[seq]) << \' \' << val << \'\\n\';\n              }\n            }\n          }\n        }\n      }\n    ', '111_write_output_to_file': '\n      void write_output_to_file(float* d_out, float* d_index, unsigned index_stride,\n                                float pruning, unsigned n_frames, unsigned n_seqs, unsigned n_emissions,\n                                unsigned batch_idx) {\n        std::vector<float> buffer(n_frames * n_seqs * n_emissions);\n        std::vector<float> index (n_frames * index_stride);\n\n        //HANDLE_ERROR(cudaMemcpy(buffer.data(), d_out,   buffer.size() * sizeof(float), cudaMemcpyDeviceToHost));\n        //HANDLE_ERROR(cudaMemcpy(index.data(),  d_index, index.size()  * sizeof(float), cudaMemcpyDeviceToHost));\n\n        for (unsigned seq = 0u; seq < n_seqs; seq++) {\n          std::stringstream filename;\n          filename << "target.dump." 
<< batch_idx << \'.\' << seq;\n          std::ofstream out(filename.str().c_str(), std::ios::out | std::ios::trunc);\n          for (unsigned t = 0u; t < n_frames; t++) {\n            if (t > 0u && index[seq * index_stride + t] <= 0.0) {\n              break;\n            }\n            for (unsigned e = 0u; e < n_emissions; e++) {\n              const float val = buffer[t * n_seqs * n_emissions + seq * n_emissions + e];\n              if (val <= pruning) {\n                out << t << \' \' << e << \' \' << val << \'\\n\';\n              }\n            }\n          }\n        }\n      }\n    '}[source]#
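All probability arithmetic in the support code above happens in -log space: prob_add(a, b) computes -log(exp(-a) + exp(-b)) in a numerically stable form, and atomic_prob_add is its compare-and-swap based atomic variant. A plain-Python rendering of the same formula (illustrative; inf encodes probability zero):

    import math

    def prob_add(a, b):
        # -log-space addition: -log(exp(-a) + exp(-b)), stable variant.
        diff = a - b
        if math.isnan(diff):              # both inputs inf: adding two zero probabilities
            return math.inf
        return -math.log1p(math.exp(-abs(diff))) + min(a, b)

    # Adding two probabilities of 0.25 gives 0.5:
    assert abs(prob_add(-math.log(0.25), -math.log(0.25)) + math.log(0.5)) < 1e-12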
c_fw_code: str = '\n    // am_scores, edges, weights, start_states, end_states, end_state_weights,\n    //   index, state_buffer* = input_names (*: inplace)\n    // output = output_names\n    assert(n_inputs  == 8);\n    assert(n_outputs == 2);\n    Ndarray* am_scores         = inputs[0];\n    Ndarray* edges             = inputs[1];\n    Ndarray* weights           = inputs[2];\n    Ndarray* start_states      = inputs[3];\n    Ndarray* end_states        = inputs[4];\n    Ndarray* end_state_weights = inputs[5];\n    Ndarray* index             = inputs[6];\n    Ndarray* state_buffer      = inputs[7];\n    Ndarray* out               = *outputs[0];\n    Ndarray* sum_output        = *outputs[1];\n\n    assert(Ndarray_DIMS(am_scores)[0] == Ndarray_DIMS(out)[0]);\n    assert(Ndarray_DIMS(am_scores)[1] == Ndarray_DIMS(out)[1]);\n    assert(Ndarray_DIMS(am_scores)[2] == Ndarray_DIMS(out)[2]);\n//    assert(Ndarray_DIMS(am_scores)[1] == Ndarray_DIMS(end_states)[0]);\n\n    assert(Ndarray_DIMS(sum_output)[0] == Ndarray_DIMS(am_scores)[0]);\n    assert(Ndarray_DIMS(sum_output)[1] == Ndarray_DIMS(am_scores)[1]);\n\n    bool            dump_alignment = false;\n    bool            dump_output    = false;\n    unsigned        dump_every = 40u;\n    static unsigned batch_idx  = 0u;\n    float           pruning    = 10.f;\n\n    unsigned* d_from = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n      + 0 * Ndarray_STRIDE(edges, 0));\n    unsigned* d_to = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n      + 1 * Ndarray_STRIDE(edges, 0));\n    unsigned* d_emission_idxs = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n      + 2 * Ndarray_STRIDE(edges, 0));\n    unsigned* d_sequence_idxs = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n      + 3 * Ndarray_STRIDE(edges, 0));\n    float*    d_weights           = Ndarray_DEV_DATA(weights);\n    float*    d_am_scores         = Ndarray_DEV_DATA(am_scores);\n    unsigned* d_start_states      = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(start_states));\n    unsigned* d_end_states        = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(end_states));\n    float*    d_end_state_weights = Ndarray_DEV_DATA(end_state_weights);\n    float*    d_index             = Ndarray_DEV_DATA(index);\n    float*    d_state_buffer_prev = Ndarray_DEV_DATA(state_buffer) + 0 * Ndarray_STRIDE(state_buffer, 0);\n    float*    d_state_buffer_next = Ndarray_DEV_DATA(state_buffer) + 1 * Ndarray_STRIDE(state_buffer, 0);\n    float*    d_out               = Ndarray_DEV_DATA(out);\n    float*    d_sum_output        = Ndarray_DEV_DATA(sum_output);\n\n    unsigned n_frames       = Ndarray_DIMS(am_scores)[0];\n    unsigned n_seqs         = Ndarray_DIMS(am_scores)[1];\n    unsigned n_emissions    = Ndarray_DIMS(am_scores)[2];\n    unsigned n_states       = Ndarray_DIMS(state_buffer)[1];\n    unsigned n_edges        = Ndarray_DIMS(edges)[1];\n    unsigned n_start_states = Ndarray_DIMS(start_states)[0];\n    unsigned n_end_states   = Ndarray_DIMS(end_states)[0];\n    unsigned n_threads      = 1024u;\n    unsigned n_blocks       = (n_edges + n_threads - 1) / n_threads;\n\n    unsigned frame_stride    = Ndarray_STRIDE(am_scores, 0);\n    unsigned sequence_stride = Ndarray_STRIDE(am_scores, 1);\n    unsigned index_stride    = Ndarray_STRIDE(index, 0);\n\n    assert(n_frames > 0);\n\n//    std::cerr << "n_frames: "       << n_frames       << std::endl;\n//    std::cerr << "n_seqs: "         << n_seqs         << std::endl;\n//    std::cerr << 
"n_emissions: "    << n_emissions    << std::endl;\n//    std::cerr << "n_states: "       << n_states       << std::endl;\n//    std::cerr << "n_edges: "        << n_edges        << std::endl;\n//    std::cerr << "n_start_states: " << n_start_states << std::endl;\n//    std::cerr << "n_end_states: "   << n_end_states   << std::endl;\n//    std::cerr << "n_threads: "      << n_threads      << std::endl;\n//    std::cerr << "n_blocks: "       << n_blocks       << std::endl;\n\n//    std::cerr << "frame_stride: "     << frame_stride    << std::endl;\n//    std::cerr << "sequence_stride: "  << sequence_stride << std::endl;\n//    std::cerr << "index_stride: "     << index_stride    << std::endl;\n\n    // initialize edge buffer\n    float* d_edge_buffer = reinterpret_cast<float*>(device_malloc(n_edges * n_frames * sizeof(float)));\n//    cudaDeviceSynchronize();\n//    HANDLE_LAST_ERROR();\n    unsigned n_fill_blocks = (n_edges * n_frames + n_threads - 1u) / n_threads;\n    fill_array<<<n_fill_blocks, n_threads>>>(d_edge_buffer, 0.0, n_edges * n_frames);\n//    cudaDeviceSynchronize();\n//    HANDLE_LAST_ERROR();\n\n    // initialize the state buffer\n    n_fill_blocks = (n_states + n_threads - 1u) / n_threads;\n    fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states);\n//    cudaDeviceSynchronize();\n//    HANDLE_LAST_ERROR();\n    set_start_states<<<1, n_start_states>>>(d_state_buffer_prev, d_start_states);\n//    cudaDeviceSynchronize();\n//    HANDLE_LAST_ERROR();\n\n    // initialize full state buffer (only used to dump the alignment)\n    float* d_state_buffer_all = NULL;\n    if (dump_alignment and batch_idx %% dump_every == 0) {\n      d_state_buffer_all = reinterpret_cast<float*>(device_malloc(n_states * (n_frames + 1u) * sizeof(float)));\n//      cudaDeviceSynchronize();\n//      HANDLE_LAST_ERROR();\n      cudaMemcpy(d_state_buffer_all, d_state_buffer_prev, n_states * sizeof(float), cudaMemcpyDeviceToDevice);\n//      HANDLE_LAST_ERROR();\n    }\n\n    // fwd pass\n    for (unsigned t = 0u; t < n_frames; t++) {\n      fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states);\n//      cudaDeviceSynchronize();\n//      HANDLE_LAST_ERROR();\n//      std::cerr << "frame " << t << std::endl;\n      next_frame<<<n_blocks, n_threads>>>(true, n_edges, sequence_stride,\n                                          d_sequence_idxs, d_from, d_to, d_weights, d_emission_idxs,\n                                          d_state_buffer_prev, d_state_buffer_next, d_am_scores + t * frame_stride,\n                                          d_edge_buffer + t * n_edges);\n//      cudaDeviceSynchronize();\n//      HANDLE_LAST_ERROR();\n      if (dump_alignment and batch_idx %% dump_every == 0) {\n        cudaMemcpy(\n          d_state_buffer_all + (t + 1u) * n_states, d_state_buffer_next, n_states * sizeof(float),\n          cudaMemcpyDeviceToDevice);\n      }\n      std::swap(d_state_buffer_prev, d_state_buffer_next);\n    }\n\n    // bwd pass\n    const unsigned n_end_state_blocks = (n_end_states + n_threads - 1u) / n_threads;\n    const unsigned n_end_state_threads = min(n_threads, n_end_states);\n    fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states);\n//    cudaDeviceSynchronize();\n//    HANDLE_LAST_ERROR();\n    for (unsigned t = n_frames; t > 0; t--) {\n      init_bwd_state_buffer<<<n_end_state_blocks, n_end_state_threads>>>(\n  
      t - 1, n_frames - 1, n_end_states, index_stride,\n        d_state_buffer_prev, d_end_states, d_end_state_weights,  d_index);\n//      cudaDeviceSynchronize();\n//      HANDLE_LAST_ERROR();\n      if (dump_alignment and batch_idx %% dump_every == 0) {\n        float alpha = 1.0f;\n//        HANDLE_ERROR(cublasSaxpy(\n//          handle, n_states, &alpha, d_state_buffer_prev, 1, d_state_buffer_all + t * n_states, 1));\n      }\n      fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states);\n//      cudaDeviceSynchronize();\n//      HANDLE_LAST_ERROR();\n      next_frame<<<n_blocks, n_threads>>>(false, n_edges, sequence_stride,\n                                          d_sequence_idxs, d_to, d_from, d_weights, d_emission_idxs,\n                                          d_state_buffer_prev, d_state_buffer_next,\n                                          d_am_scores + (t - 1) * frame_stride,\n                                          d_edge_buffer + (t - 1) * n_edges);\n//      cudaDeviceSynchronize();\n//      HANDLE_LAST_ERROR();\n      std::swap(d_state_buffer_prev, d_state_buffer_next);\n    }\n    if (dump_alignment and batch_idx %% dump_every == 0) {\n      float alpha = 1.0f;\n//      HANDLE_ERROR(cublasSaxpy(handle, n_states, &alpha, d_state_buffer_prev, 1, d_state_buffer_all, 1));\n    }\n\n    // normalize at each time frame\n    normalize<<<n_frames, 1, n_seqs * sizeof(float)>>>(d_edge_buffer, d_sequence_idxs, n_edges, n_seqs, d_sum_output);\n//    cudaDeviceSynchronize();\n//    HANDLE_LAST_ERROR();\n\n    // dump alignment\n    if (dump_alignment and batch_idx %% dump_every == 0) {\n      write_alignment_to_file(d_state_buffer_all, d_index, index_stride, d_start_states, d_end_states,\n                              pruning, n_frames, n_seqs, n_states, batch_idx);\n    }\n\n    n_fill_blocks = (n_frames * n_seqs * n_emissions + n_threads - 1u) / n_threads;\n    fill_array<<<n_fill_blocks, n_threads>>>(\n      d_out, std::numeric_limits<float>::infinity(), n_frames * n_seqs * n_emissions);\n//    cudaDeviceSynchronize();\n//    HANDLE_LAST_ERROR();\n\n    frame_stride    = Ndarray_STRIDE(out, 0);\n    sequence_stride = Ndarray_STRIDE(out, 1);\n    n_blocks        = (n_frames * n_edges + n_threads - 1u) / n_threads;\n    compute_result<<<n_blocks, n_threads>>>(d_edge_buffer, d_out, d_emission_idxs, d_sequence_idxs,\n                                            frame_stride, sequence_stride, n_frames, n_seqs, n_edges);\n//    cudaDeviceSynchronize();\n//    HANDLE_LAST_ERROR();\n\n    #if TENSORFLOW\n    // Certain TensorFlow code doesn\'t like inf, even if it is just the CheckNumerics,\n    // which is helpful for debugging.\n    // We replace it by a very high number, so that tf.exp(-out) will still result in 0.0.\n    n_blocks = (n_frames * n_seqs * n_emissions + n_threads - 1u) / n_threads;\n    remove_inf<<<n_blocks, n_threads>>>(d_out, n_frames * n_seqs * n_emissions);\n    //debug_print(context, out, "out");\n    #endif\n    if (dump_output and batch_idx %% dump_every == 0) {\n      write_output_to_file(d_out, d_index, index_stride, pruning, n_frames, n_seqs, n_emissions, batch_idx);\n    }\n\n    device_free(d_edge_buffer);\n    if (d_state_buffer_all != NULL) {\n      device_free(d_state_buffer_all);\n    }\n    batch_idx++;\n  '[source]#
c_bw_code: str = None[source]#
cpu_support = False[source]#
class returnn.native_op.SegmentFastBaumWelchOp(segmentwise_normalization=False, dump_targets_interval=None, new_batch_idxs_format=False)[source]#

Segmental Baum-Welch…

out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': 'normalization_factors', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'name': 'posterior_weigths', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))})[source]#
cpu_support = False[source]#
c_extra_support_code: Dict[str, str] = {'001_set_start_states': '\n    DEF_KERNEL\n    void set_start_states(float* states, unsigned* start_states) {\n      unsigned state_idx = start_states[blockIdx.x * blockDim.x + threadIdx.x];\n      states[state_idx] = 0.0;\n    }\n  ', '010_fill_array': '\n    DEF_KERNEL\n    void fill_array(float* array, float value, unsigned size) {\n      unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n      if (idx < size) {\n        array[idx] = value;\n      }\n    }\n  ', '011_remove_inf': '\n  DEF_KERNEL\n  void remove_inf(float* array, unsigned size) {\n    unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < size) {\n      array[idx] = fminf(array[idx], 1e32);\n    }\n  }\n  ', '012_prob_add': '\n    DEV_FUNC\n    float prob_add(float a, float b) {\n      float diff = a - b;\n      if (isnan(diff)) {\n        return INF_F;\n      }\n      else {\n        return -log1pf(expf(-fabsf(diff))) + fminf(a, b);\n      }\n    }\n  ', '013_atomic_prob_add': '\n    DEV_FUNC\n    void atomic_prob_add(float* a, float b) {\n      int* addr = (int*)a;\n      int old   = float_as_int(*a);\n      int assumed;\n      do {\n        assumed = old;\n        old     = elem_atomic_cas(addr, assumed, float_as_int(prob_add(int_as_float(old), b)));\n      } while (old != assumed);\n    }\n  ', '020_dump_to_file': "\n    template<typename T>\n    void dump_to_file_1d(T* d_mem, unsigned n_d1, std::string const& path) {\n      std::vector<T> buffer(n_d1);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        T val = buffer[i1];\n        if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n          output << i1 << ' ' << val << '\\n';\n        }\n      }\n    }\n\n    template<typename T>\n    void dump_to_file_2d(T* d_mem, unsigned n_d1, unsigned n_d2, std::string const& path) {\n      std::vector<T> buffer(n_d1 * n_d2);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n          T val = buffer[i1 * n_d2 + i2];\n          if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n            output << i1 << ' ' << i2 << ' ' << val << '\\n';\n          }\n        }\n      }\n    }\n\n    template<typename T>\n    void dump_to_file_3d(T* d_mem, unsigned n_d1, unsigned n_d2, unsigned n_d3, std::string const& path) {\n      std::vector<T> buffer(n_d1 * n_d2 * n_d3);\n      //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n      std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n      for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n        for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n          for (size_t i3 = 0ul; i3 < n_d3; i3++) {\n            T val = buffer[i1 * n_d2 * n_d3 + i2 * n_d3 + i3];\n            if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n              output << i1 << ' ' << i2 << ' ' << i3 << ' ' << val << '\\n';\n            }\n          }\n        }\n      }\n    }\n  ", '100_get_batch_idx': '\n      __device__\n      int get_batch_idx(int const* batch_idxs, unsigned num_seqs, unsigned t, unsigned seq_idx) {\n        if (NEW_BATCH_IDX_FORMAT) {\n          int res = batch_idxs[seq_idx] + 
t;\n          if (res >= batch_idxs[seq_idx + 1]) {\n            return -1;\n          }\n          return res;\n        }\n        else {\n          return batch_idxs[t * num_seqs + seq_idx];\n        }\n      }\n    ', '101_init_bwd_state_buffer': '\n      __global__\n      void init_bwd_state_buffer(unsigned t, unsigned num_batches, unsigned num_seqs,\n                                 int* batch_idxs, float* index, float* states, unsigned* end_states) {\n        unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        int batch_idx = get_batch_idx(batch_idxs, num_seqs, t, idx);\n        if (batch_idx < 0) {\n          return;\n        }\n        float* batch_first_frame = index + batch_idx;\n        //if (*batch_first_frame != 0.0 && (t == max_t || *(batch_first_frame + 1) == 0.0)) {\n        if (batch_first_frame[0] != 0.0 && batch_first_frame[num_batches] == 0.0) {\n          unsigned state_idx = end_states[idx];\n          states[state_idx] = 0.0;\n        }\n      }\n    ', '102_next_frame_fwd': '\n      __global__\n      void next_frame_fwd(unsigned time, unsigned num_states, unsigned num_edges, unsigned num_emissions,\n                          unsigned num_seg_frames,\n                          unsigned num_tot_frames, unsigned num_seqs, unsigned num_am_score_scales,\n                          unsigned const* sequence_idxs, unsigned const* from_buffer, unsigned const* to_buffer,\n                          float const* weight_buffer,\n                          unsigned const* emission_idxs, unsigned const* lenmod_idxs, int const* batch_idxs,\n                          float const* am_scores, float const* length_models, float const* am_score_scales,\n                          float const* epoch,\n                          float* state_buffer, float* edge_buffer) {\n        const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_edges) {\n          return;\n        }\n\n        const unsigned num_ringbuffer_frames = num_seg_frames + 1;\n        const unsigned max_seg_frames        = min(num_seg_frames, num_tot_frames - time);\n\n        const unsigned prev_frame_idx   = time % num_ringbuffer_frames;\n        const unsigned prev_frame_start = prev_frame_idx * num_states;\n\n        const unsigned from     = from_buffer [idx];\n        const float    prev_val = state_buffer[prev_frame_start + from];\n        if (isinf(prev_val)) {\n          return;\n        }\n\n        const unsigned sequence_idx = sequence_idxs[idx];\n        const int      batch_idx    = get_batch_idx(batch_idxs, num_seqs, time, sequence_idx);\n        if (batch_idx == -1) {\n          return;\n        }\n\n        const unsigned amss_idx       = min(static_cast<unsigned>(*epoch), num_am_score_scales - 1);\n        const float    am_score_scale = am_score_scales[amss_idx];\n\n        const unsigned to             = to_buffer    [idx];\n        const unsigned emission_idx   = emission_idxs[idx];\n        const unsigned lenmod_idx     = lenmod_idxs  [idx];\n        const float    edge_weight    = weight_buffer[idx];\n        const float    prev_plus_edge = prev_val + edge_weight;\n\n        float const* am_buffer_in    = am_scores     + batch_idx  * num_seg_frames * num_emissions + emission_idx;\n        float const* length_scores   = length_models + lenmod_idx * num_seg_frames;\n        float*       edge_buffer_out = edge_buffer   + idx;\n\n        for (unsigned i = 0u; i < max_seg_frames; i++) {\n          const float val = prev_plus_edge + am_score_scale * am_buffer_in[i * 
num_emissions] + length_scores[i];\n          edge_buffer_out[i * num_edges] = val;\n          const unsigned next_frame = (prev_frame_idx + 1 + i) % num_ringbuffer_frames;\n          atomic_prob_add(state_buffer + (next_frame * num_states + to), val);\n        }\n      }\n    ', '103_next_frame_bwd': '\n      __global__\n      void next_frame_bwd(unsigned time, unsigned num_states, unsigned num_edges, unsigned num_emissions,\n                          unsigned num_seg_frames,\n                          unsigned num_tot_frames, unsigned num_seqs, unsigned num_am_score_scales,\n                          unsigned const* sequence_idxs, unsigned const* from_buffer, unsigned const* to_buffer,\n                          float const* weight_buffer,\n                          unsigned const* emission_idxs, unsigned const* lenmod_idxs, int const* batch_idxs,\n                          float const* am_scores, float const* length_models, float const* am_score_scales,\n                          float const* epoch,\n                          float* state_buffer, float* edge_buffer) {\n        const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_edges) {\n          return;\n        }\n\n        const unsigned num_ringbuffer_frames = num_seg_frames + 1;\n        const unsigned max_seg_frames        = min(num_seg_frames, num_tot_frames - time);\n\n        const unsigned sequence_idx = sequence_idxs[idx];\n        const int      batch_idx    = get_batch_idx(batch_idxs, num_seqs, time, sequence_idx);\n        if (batch_idx == -1) {\n          return;\n        }\n\n        const unsigned amss_idx       = min(static_cast<unsigned>(*epoch), num_am_score_scales - 1);\n        const float    am_score_scale = am_score_scales[amss_idx];\n\n        const unsigned from           = from_buffer  [idx];\n        const unsigned to             = to_buffer    [idx];\n        const unsigned emission_idx   = emission_idxs[idx];\n        const unsigned lenmod_idx     = lenmod_idxs  [idx];\n        const float    edge_weight    = weight_buffer[idx];\n        const unsigned next_frame_idx = time % num_ringbuffer_frames;\n\n        float const*   am_buffer_in    = am_scores     + batch_idx  * num_seg_frames * num_emissions + emission_idx;\n        float const*   length_scores   = length_models + lenmod_idx * num_seg_frames;\n        float*         edge_buffer_out = edge_buffer   + idx;\n\n        float acc_val = CUDART_INF_F;\n\n        for (unsigned i = 0u; i < max_seg_frames; i++) {\n          const unsigned prev_frame_idx = (next_frame_idx + i + 1) % num_ringbuffer_frames;\n          const float    prev_val       = state_buffer[prev_frame_idx * num_states + from];\n          if (isinf(prev_val)) {\n            edge_buffer_out[i * num_edges] = CUDART_INF_F;\n          }\n          else {\n            const float val =\n              prev_val + edge_weight + am_score_scale * am_buffer_in[i * num_emissions] + length_scores[i];\n            edge_buffer_out[i * num_edges] += prev_val;\n            acc_val = prob_add(acc_val, val);\n          }\n        }\n\n        atomic_prob_add(state_buffer + next_frame_idx * num_states + to, acc_val);\n      }\n    ', '104_compute_framewise_sum': '\n      __global__\n      void compute_framewise_sum(unsigned num_tot_frames, unsigned num_seqs, unsigned num_seg_frames,\n                                 unsigned num_batches, unsigned num_edges,\n                                 unsigned const* sequence_idxs, int const* batch_idxs, float const* index,\n              
                   float const* edge_buffer,\n                                 float* output_buffer) {\n        extern __shared__ float sum[];\n\n        const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_tot_frames * num_seg_frames) {\n          return;\n        }\n\n        float* sum_buffer = sum + threadIdx.x * num_seqs;\n        edge_buffer += idx * num_edges;\n\n        for (unsigned s = 0u; s < num_seqs; s++) {\n          sum_buffer[s] = CUDART_INF_F;\n        }\n\n        for (unsigned i = 0; i < num_edges; i++) {\n          const unsigned seq_idx = sequence_idxs[i];\n          sum_buffer[seq_idx] = prob_add(sum_buffer[seq_idx], edge_buffer[i]);\n        }\n\n        const unsigned time     = idx / num_seg_frames;\n        const unsigned seg_size = idx % num_seg_frames;\n        for (unsigned s = 0u; s < num_seqs; s++) {\n          const int batch_idx = get_batch_idx(batch_idxs, num_seqs, time, s);\n          if (batch_idx >= 0) {\n            const unsigned output_idx = seg_size * num_batches + batch_idx;\n            if (isinf(sum_buffer[s]) or index[output_idx] == 0.0) {\n              output_buffer[output_idx] = 0.0;\n            }\n            else {\n              output_buffer[output_idx] = sum_buffer[s];\n            }\n          }\n        }\n      }\n    ', '105_merge_framewise_sums': '\n      __global__\n      void merge_framewise_sum(unsigned num_seg_frames, unsigned num_batches, float const* index, float* sum_buffer) {\n        const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_batches) {\n          return;\n        }\n\n        sum_buffer += idx;\n        index += idx;\n\n        float sum = sum_buffer[0];\n        for (unsigned s = 1; s < num_seg_frames; s++) {\n          if (index[s * num_batches] != 0.0f) {\n            sum = prob_add(sum, sum_buffer[s * num_batches]);\n          }\n        }\n\n        for (unsigned s = 0; s < num_seg_frames; s++) {\n          if (index[s * num_batches] != 0.0f) {\n            sum_buffer[s * num_batches] = sum;\n          }\n        }\n      }\n    ', '106_compute_targets': '\n      __global__\n      void compute_targets(unsigned num_tot_frames, unsigned num_seg_frames, unsigned num_edges, unsigned num_batches,\n                           unsigned num_seqs, unsigned num_emissions,\n                           unsigned const* sequence_idxs, unsigned const* emission_idxs, int const* batch_idxs,\n                           float const* index,\n                           float const* edge_buffer, float const* normalization_buffer, float* output_buffer) {\n        const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_tot_frames * num_seg_frames * num_edges) {\n          return;\n        }\n\n        const unsigned edge_idx  = idx % num_edges;\n        const unsigned time      = idx / (num_edges * num_seg_frames);\n        const unsigned seq_idx   = sequence_idxs[edge_idx];\n        const int      batch_idx = get_batch_idx(batch_idxs, num_seqs, time, seq_idx);\n\n        if (batch_idx < 0) {\n          return;\n        }\n\n        const unsigned seg_length = (idx / num_edges) % num_seg_frames;\n\n        if (index[seg_length * num_batches + batch_idx] == 0.0) {\n          return;\n        }\n\n        const unsigned emission_idx  = emission_idxs[edge_idx];\n        const float    normalization = normalization_buffer[seg_length * num_batches + batch_idx];\n\n        atomic_prob_add(\n          output_buffer + seg_length * num_batches * 
num_emissions + batch_idx * num_emissions + emission_idx,\n          edge_buffer[idx] - normalization);\n      }\n    ', '107_compute_posterior_weights': '\n    __global__\n    void compute_posterior_weights(unsigned num_tot_frames, unsigned num_seg_frames, unsigned num_seqs,\n                                   unsigned num_batches,\n                                   float const* state_buffer, unsigned const* start_states, int const* batch_idxs,\n                                   float const* index, float const* normalization_factors, float* posterior_weigths) {\n        const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n        if (idx >= num_tot_frames * num_seqs) {\n          return;\n        }\n\n        const unsigned time    = idx / num_seqs;\n        const unsigned seq_idx = idx % num_seqs;\n\n        const int batch_idx = get_batch_idx(batch_idxs, num_seqs, time, seq_idx);\n        if (batch_idx < 0) {\n          return;\n        }\n\n        const float seq_sum = state_buffer[start_states[seq_idx]];\n        for (unsigned s = 0u; s < num_seg_frames; s++) {\n          const unsigned i = s * num_batches + batch_idx;\n          if (index[i] == 0.0) {\n            return;\n          }\n          posterior_weigths[i] = exp(-(normalization_factors[i] - seq_sum));\n        }\n    }\n    '}[source]#
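The get_batch_idx helper above handles the two batch_idxs layouts selected by the new_batch_idxs_format constructor argument: in the new format, batch_idxs appears to hold one frame offset per sequence plus a terminating entry (length n_seqs + 1, as also read back in the c_fw_code below), while in the old format it is a dense (n_tot_frames, n_seqs) matrix of batch indices with negative entries marking padding. A Python rendering with a made-up example:

    def get_batch_idx(batch_idxs, num_seqs, t, seq_idx, new_format):
        # Mirrors the 100_get_batch_idx device function; a negative result
        # means the (frame, sequence) pair is padding.
        if new_format:
            res = batch_idxs[seq_idx] + t
            return res if res < batch_idxs[seq_idx + 1] else -1
        return batch_idxs[t * num_seqs + seq_idx]

    # New format: two sequences with 3 and 2 frames -> offsets [0, 3, 5].
    offsets = [0, 3, 5]
    assert get_batch_idx(offsets, 2, 2, 0, True) == 2    # frame 2 of sequence 0
    assert get_batch_idx(offsets, 2, 2, 1, True) == -1   # sequence 1 is shorter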
in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 'am_scores', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'batch_idxs', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': ((2, 1),)}, {'gradient': 'disconnected', 'name': 'length_models', 'ndim': 2, 'need_contiguous': True, 'shape': (None, (0, 0))}, {'gradient': 'disconnected', 'name': 'start_end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (2, None)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'gradient': 'disconnected', 'name': 'am_score_scales', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'epoch', 'ndim': 0, 'need_contiguous': True, 'shape': ()})[source]#
c_fw_code: str = '\n    // inputs:  am_scores, batch_idxs, edges, weights, length_models, start_end_states, index, am_score_scales, epoch\n    // outputs: output, normalization_factors, posterior_weigths\n    assert(n_inputs  == 9);\n    assert(n_outputs == 3);\n    Ndarray* ary_am_scores         = inputs[0];\n    Ndarray* ary_batch_idxs        = inputs[1];\n    Ndarray* ary_edges             = inputs[2];\n    Ndarray* ary_weights           = inputs[3];\n    Ndarray* ary_start_end_states  = inputs[4];\n    Ndarray* ary_length_models     = inputs[5];\n    Ndarray* ary_index             = inputs[6];\n    Ndarray* ary_am_score_scales   = inputs[7];\n    Ndarray* ary_epoch             = inputs[8];\n    Ndarray* ary_out               = *outputs[0];\n    Ndarray* ary_norm_factors      = *outputs[1];\n    Ndarray* ary_posterior_weights = *outputs[2];\n\n    assert(Ndarray_DIMS(ary_edges)[1] == Ndarray_DIMS(ary_weights)[0]);\n\n    static unsigned iter = 0u; // used for debug output\n\n    float*    d_am_scores         = Ndarray_DEV_DATA(ary_am_scores);\n    int*      d_batch_idxs        = reinterpret_cast<int*>(Ndarray_DEV_DATA(ary_batch_idxs));\n    unsigned* d_from              =\n      reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 0 * Ndarray_STRIDE(ary_edges, 0));\n    unsigned* d_to                =\n      reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 1 * Ndarray_STRIDE(ary_edges, 0));\n    unsigned* d_emission_idxs     =\n      reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 2 * Ndarray_STRIDE(ary_edges, 0));\n    unsigned* d_lenmod_idxs       =\n      reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 3 * Ndarray_STRIDE(ary_edges, 0));\n    unsigned* d_sequence_idxs     =\n      reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 4 * Ndarray_STRIDE(ary_edges, 0));\n    float*    d_weights           = Ndarray_DEV_DATA(ary_weights);\n    float*    d_length_models     = Ndarray_DEV_DATA(ary_length_models);\n    unsigned* d_start_states      =\n      reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_start_end_states) + 0 * Ndarray_STRIDE(ary_start_end_states, 0));\n    unsigned* d_end_states        =\n      reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_start_end_states) + 1 * Ndarray_STRIDE(ary_start_end_states, 0));\n    float*    d_index             = Ndarray_DEV_DATA(ary_index);\n    float*    d_am_score_scales   = Ndarray_DEV_DATA(ary_am_score_scales);\n    float*    d_epoch             = Ndarray_DEV_DATA(ary_epoch);\n    float*    d_out               = Ndarray_DEV_DATA(ary_out);\n    float*    d_norm_factors      = Ndarray_DEV_DATA(ary_norm_factors);\n    float*    d_posterior_weights = Ndarray_DEV_DATA(ary_posterior_weights);\n\n    std::vector<int> seq_lengths;\n    if (NEW_BATCH_IDX_FORMAT) {\n      seq_lengths.resize(Ndarray_DIMS(ary_batch_idxs)[0]);\n      HANDLE_ERROR(cudaMemcpy(\n        seq_lengths.data(), d_batch_idxs, seq_lengths.size() * sizeof(int), cudaMemcpyDeviceToHost));\n    }\n\n    const unsigned n_seg_frames      = Ndarray_DIMS(ary_am_scores)[0];\n    const unsigned n_batches         = Ndarray_DIMS(ary_am_scores)[1];\n    const unsigned n_emissions       = Ndarray_DIMS(ary_am_scores)[2];\n    const unsigned n_seqs            =\n      NEW_BATCH_IDX_FORMAT ? (Ndarray_DIMS(ary_batch_idxs)[0] - 1) : Ndarray_DIMS(ary_batch_idxs)[1];\n    const unsigned n_tot_frames      =\n      NEW_BATCH_IDX_FORMAT ? 
seq_lengths.back()                     : Ndarray_DIMS(ary_batch_idxs)[0];\n    const unsigned n_edges           = Ndarray_DIMS(ary_edges)[1];\n    const unsigned n_length_models   = Ndarray_DIMS(ary_length_models)[1];\n    const unsigned n_am_score_scales = Ndarray_DIMS(ary_am_score_scales)[0];\n    const unsigned n_threads         = 1024u;\n    unsigned       n_blocks          = (n_edges + n_threads - 1) / n_threads;\n\n    unsigned tmp;\n    HANDLE_ERROR(cudaMemcpy(&tmp, d_end_states + n_seqs - 1, sizeof(float), cudaMemcpyDeviceToHost));\n\n    const unsigned n_states = tmp + 1;\n\n    /*std::cerr << "seg frames: "    << n_seg_frames    << std::endl;\n    std::cerr << "batches: "       << n_batches       << std::endl;\n    std::cerr << "emissions: "     << n_emissions     << std::endl;\n    std::cerr << "tot frames: "    << n_tot_frames    << std::endl;\n    std::cerr << "seqs: "          << n_seqs          << std::endl;\n    std::cerr << "edges: "         << n_edges         << std::endl;\n    std::cerr << "length models: " << n_length_models << std::endl;\n    std::cerr << "threads: "       << n_threads       << std::endl;\n    std::cerr << "blocks: "        << n_blocks        << std::endl;\n    std::cerr << "num states: "    << n_states        << std::endl;*/\n\n    // initialize edge buffer\n    const unsigned edge_buffer_size = n_tot_frames * n_seg_frames * n_edges;\n    float* d_edge_buffer  = reinterpret_cast<float*>(device_malloc(edge_buffer_size * sizeof(float)));\n    HANDLE_LAST_ERROR();\n    unsigned n_fill_blocks = (edge_buffer_size + n_threads - 1u) / n_threads;\n    fill_array<<<n_fill_blocks, n_threads>>>(d_edge_buffer, std::numeric_limits<float>::infinity(), edge_buffer_size);\n    HANDLE_LAST_ERROR();\n\n    // initialize the state buffer\n    const unsigned n_ringbuffer_frames = n_seg_frames + 1;\n    float* d_state_buffer = reinterpret_cast<float*>(device_malloc(n_states * n_ringbuffer_frames * sizeof(float)));\n    HANDLE_LAST_ERROR();\n    n_fill_blocks = (n_states * n_ringbuffer_frames + n_threads - 1u) / n_threads;\n    fill_array<<<n_fill_blocks, n_threads>>>(\n      d_state_buffer, std::numeric_limits<float>::infinity(), n_states * n_ringbuffer_frames);\n    HANDLE_LAST_ERROR();\n\n    // initialize sum buffer and posterior weigths\n    n_fill_blocks = (n_batches * n_seg_frames + n_threads - 1u) / n_threads;\n    fill_array<<<n_fill_blocks, n_threads>>>(d_norm_factors, 0.0f, n_batches * n_seg_frames);\n    HANDLE_LAST_ERROR();\n    fill_array<<<n_fill_blocks, n_threads>>>(d_posterior_weights, 0.0f, n_batches * n_seg_frames);\n    HANDLE_LAST_ERROR();\n\n    set_start_states<<<1, n_seqs>>>(d_state_buffer, d_start_states);\n    HANDLE_LAST_ERROR();\n\n    // fwd pass\n    for (unsigned t = 0u; t < n_tot_frames; t++) {\n      //std::cerr << "fwd t: " << t << " " << n_tot_frames << std::endl;\n      float* d_state_buffer_prev = d_state_buffer + ((t - 1) %% n_ringbuffer_frames) * n_states;\n      fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states);\n      HANDLE_LAST_ERROR();\n      next_frame_fwd<<<n_blocks, n_threads>>>(t, n_states, n_edges, n_emissions, n_seg_frames, n_tot_frames, n_seqs,\n                                              n_am_score_scales,\n                                              d_sequence_idxs, d_from, d_to, d_weights, d_emission_idxs, d_lenmod_idxs,\n                                              d_batch_idxs,\n                                              d_am_scores, 
d_length_models, d_am_score_scales, d_epoch,\n                                              d_state_buffer, d_edge_buffer + t * n_seg_frames * n_edges);\n      HANDLE_LAST_ERROR();\n\n      //std::stringstream ss;\n      //ss << "dump/fwd_state_buffer." << t << ".dump";\n      //dump_to_file_2d(d_state_buffer, n_ringbuffer_frames, n_states, ss.str());\n    }\n\n    //dump_to_file_3d(d_edge_buffer, n_tot_frames, n_seg_frames, n_edges, "dump/fwd_edges.dump");\n\n    // bwd pass\n    n_fill_blocks = (n_states * n_ringbuffer_frames + n_threads - 1u) / n_threads;\n    fill_array<<<n_fill_blocks, n_threads>>>(\n      d_state_buffer, std::numeric_limits<float>::infinity(), n_states * n_ringbuffer_frames);\n    HANDLE_LAST_ERROR();\n    n_fill_blocks = (n_states + n_threads - 1u) / n_threads;\n    for (unsigned t = n_tot_frames; t > 0; t--) {\n      //std::cerr <<\n      //"bwd t: " << t << " " << n_tot_frames << " buffer next: " << ((t-1) %% n_ringbuffer_frames) << std::endl;\n      float* d_state_buffer_next = d_state_buffer + ((t - 1) %% n_ringbuffer_frames) * n_states;\n      float* d_state_buffer_prev = d_state_buffer + ( t      %% n_ringbuffer_frames) * n_states;\n      fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states);\n      HANDLE_LAST_ERROR();\n      init_bwd_state_buffer<<<1, n_seqs>>>(\n        t - 1, n_batches, n_seqs, d_batch_idxs, d_index, d_state_buffer_prev, d_end_states);\n      HANDLE_LAST_ERROR();\n      next_frame_bwd<<<n_blocks, n_threads>>>(\n        t - 1, n_states, n_edges, n_emissions, n_seg_frames, n_tot_frames, n_seqs, n_am_score_scales,\n        d_sequence_idxs, d_to, d_from, d_weights, d_emission_idxs, d_lenmod_idxs, d_batch_idxs,\n        d_am_scores, d_length_models, d_am_score_scales, d_epoch,\n        d_state_buffer, d_edge_buffer + (t - 1) * n_seg_frames * n_edges);\n      HANDLE_LAST_ERROR();\n\n      //std::stringstream ss;\n      //ss << "dump/bwd_state_buffer." 
<< t << ".dump";\n      //dump_to_file_2d(d_state_buffer, n_ringbuffer_frames, n_states, ss.str());\n    }\n\n    n_blocks = (n_tot_frames * n_seg_frames + n_threads - 1) / n_threads;\n    compute_framewise_sum<<<n_blocks, n_threads, n_threads * n_seqs * sizeof(float)>>>(\n      n_tot_frames, n_seqs, n_seg_frames, n_batches, n_edges,\n      d_sequence_idxs, d_batch_idxs,\n      d_index, d_edge_buffer, d_norm_factors);\n    HANDLE_LAST_ERROR();\n\n    //dump_to_file_2d(d_norm_factors, n_seg_frames, n_batches, "dump/norm_factors_1.dump");\n\n    if (segmentwise_normalization) {\n      n_blocks = (n_batches + n_threads - 1) / n_threads;\n      merge_framewise_sum<<<n_blocks, n_threads>>>(n_seg_frames, n_batches, d_index, d_norm_factors);\n      HANDLE_LAST_ERROR();\n    }\n\n    //dump_to_file_2d(d_norm_factors, n_seg_frames, n_batches, "dump/norm_factors_2.dump");\n\n    n_blocks = (n_tot_frames * n_seqs + n_threads - 1) / n_threads;\n    compute_posterior_weights<<<n_blocks, n_threads>>>(n_tot_frames, n_seg_frames, n_seqs, n_batches, d_state_buffer,\n                                                       d_start_states, d_batch_idxs, d_index, d_norm_factors,\n                                                       d_posterior_weights);\n    HANDLE_LAST_ERROR();\n\n    n_fill_blocks = (n_batches * n_seg_frames * n_emissions + n_threads - 1u) / n_threads;\n    fill_array<<<n_fill_blocks, n_threads>>>(\n      d_out, std::numeric_limits<float>::infinity(), n_batches * n_seg_frames * n_emissions);\n    HANDLE_LAST_ERROR();\n\n    n_blocks = (n_tot_frames * n_seg_frames * n_edges + n_threads - 1) / n_threads;\n    compute_targets<<<n_blocks, n_threads>>>(n_tot_frames, n_seg_frames, n_edges, n_batches, n_seqs, n_emissions,\n                                             d_sequence_idxs, d_emission_idxs, d_batch_idxs, d_index, d_edge_buffer,\n                                             d_norm_factors, d_out);\n    HANDLE_LAST_ERROR();\n\n    //dump_to_file_1d(d_weights,       n_edges, "dump/edge_weights.dump");\n    //dump_to_file_1d(d_sequence_idxs, n_edges, "dump/sequence_idxs.dump");\n    //dump_to_file_2d(d_state_buffer,  n_ringbuffer_frames, n_states,  "dump/state_buffer.dump");\n    //dump_to_file_2d(d_batch_idxs,    n_tot_frames,        n_seqs,    "dump/batch_idxs.dump");\n    //dump_to_file_2d(d_index,         n_seg_frames,        n_batches, "dump/index.dump");\n    //dump_to_file_3d(d_edge_buffer,   n_tot_frames,        n_seg_frames, n_edges,     "dump/edges.dump");\n    //dump_to_file_3d(d_am_scores,     n_seg_frames,        n_batches,    n_emissions, "dump/am_scores.dump");\n    //dump_to_file_3d(d_out,           n_seg_frames,        n_batches,    n_emissions, "dump/targets.dump");\n\n    if (dump_targets and iter %% dump_targets_interval == 0) {\n      std::stringstream ss;\n      ss << "dump/targets_" << iter << ".dump";\n      dump_to_file_3d(d_out, n_seg_frames, n_batches, n_emissions, ss.str());\n      ss.str("");\n      ss.clear();\n      ss << "dump/norm_factors_" << iter << ".dump";\n      dump_to_file_2d(d_norm_factors, n_seg_frames, n_batches, ss.str());\n      ss.str("");\n      ss.clear();\n      ss << "dump/posterior_weights_" << iter << ".dump";\n      dump_to_file_2d(d_posterior_weights, n_seg_frames, n_batches, ss.str());\n    }\n\n    iter += 1;\n\n    device_free(d_state_buffer);\n    device_free(d_edge_buffer);\n  '[source]#
class returnn.native_op.FastViterbiOp[source]#
inputs:
param am_scores:

scores in +log space. 3d (time,batch,dim)

param am_seq_len:

(batch,)

param edges:

edges of the graph (from,to,emission_idx,sequence_idx), i.e. (4, n_edges)

param weights:

weights of the edges (n_edges,)

param start_end_states:

(2, batch)

param n_states:

scalar, int32

outputs:
param output:

Viterbi (hard) alignment, scores in +log space. 2d (time,batch)

param scores:

(batch,)
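
For illustration, a minimal NumPy sketch of this input layout (not a call to the op itself; the tiny FSA here is made up): one batch entry with two states, where state 0 loops on label 0, state 1 loops on label 1, and a single transition 0 -> 1 emits label 1.

import numpy as np

# Rows of edges: from, to, emission_idx, sequence_idx; shape (4, n_edges).
edges = np.array(
    [[0, 0, 1],   # from
     [0, 1, 1],   # to
     [0, 1, 1],   # emission_idx
     [0, 0, 0]],  # sequence_idx (all edges belong to batch entry 0)
    dtype="int32")
weights = np.zeros((edges.shape[1],), dtype="float32")  # per-edge weights, +log space
start_end_states = np.array([[0], [1]], dtype="int32")  # (2, batch): start=0, end=1
n_states = np.array(2, dtype="int32")                   # scalar, int32
# am_scores would then have shape (time, batch, dim) with dim >= 2 (labels 0 and 1),
# and am_seq_len shape (batch,), both matching this single-sequence batch.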

in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 'am_scores', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'am_seq_len', 'ndim': 1, 'need_contiguous': True, 'shape': ((0, 0),)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (4, None)}, {'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': ((3, 1),)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'start_end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (2, (0, 0))}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'n_states', 'ndim': 0, 'need_contiguous': True, 'shape': ()})[source]#
out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'output', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'name': 'scores', 'ndim': 1, 'need_contiguous': True, 'shape': ((0, 1),)})[source]#
c_extra_support_code: Dict[str, str] = {'01_IdxAndVal': '\n      struct __attribute__((__packed__)) IdxAndVal {\n        int idx;\n        float val;\n      };\n    ', '04_select_max': '\n      DEV_FUNC\n      void select_max(IdxAndVal* a, IdxAndVal b) {\n        // fast path\n        if(b.val < a->val)\n          return;\n        // Maybe we could use double compare-and-swap (https://stackoverflow.com/questions/55941382/).\n        // But not sure how.\n        // So instead, we use double-wide compare-and-swap.\n        union U {\n          IdxAndVal s;\n          unsigned long long int v64;\n        };\n        while(true) {\n          U prev;\n          prev.s = *a;\n          if(b.val < prev.s.val)\n            return;\n          if(b.val == prev.s.val && b.idx >= prev.s.idx)\n            return;\n          U updated;\n          updated.s = b;\n\n          U old;\n          old.v64 = elem_atomic_cas((unsigned long long int*) a, prev.v64, updated.v64);\n          if(old.v64 == prev.v64)\n            return;\n          // Not the same, so repeat.\n        }\n      }\n    ', '05_init_buffer': '\n      DEF_KERNEL\n      void init_buffer\n      (\n        int n_time,\n        int n_states, // for the whole batch\n        IdxAndVal* buffer // (time+1,n_states), states for the whole batch\n      )\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < (n_time + 1) * n_states) {\n          buffer[idx].val = -INF_F;\n          buffer[idx].idx = -1;\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '06_init_first_frame': '\n      DEF_KERNEL\n      void init_first_frame\n      (\n        int n_batch,\n        int n_states, // for the whole batch\n        IdxAndVal* frame, // (n_states,), states for the whole batch\n        const int32_t* d_start_states // (n_batch,)\n      )\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_batch) {\n          int state_idx = d_start_states[idx];\n          frame[state_idx].val = 0;\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '08_next_frame': '\n      DEF_KERNEL\n      void next_frame\n      (\n        int n_time,\n        int n_states,\n        int n_edges,\n        int n_classes,\n        int t,\n        const float* d_am_scores,\n        const int32_t* d_am_seq_len,\n        const IdxAndVal* prev_frame,\n        IdxAndVal* frame,\n        const int32_t* d_edge_from,\n        const int32_t* d_edge_to,\n        const int32_t* d_edge_emission_idx,\n        const int32_t* d_edge_seq_idx,\n        const float* d_edge_weights,\n        const int32_t* d_end_states // (n_batch,)\n      )\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_edges) {\n          int from_idx = d_edge_from[idx];\n          //assert_cmp(0, <=, from_idx); assert_cmp(from_idx, <, n_states);\n\n          int seq_idx = d_edge_seq_idx[idx];\n          if(t < d_am_seq_len[seq_idx]) {\n            float prev_val = prev_frame[from_idx].val;\n            int emission_idx = d_edge_emission_idx[idx];\n            //assert_cmp(0, <=, emission_idx); assert_cmp(emission_idx, <, n_classes);\n            int to_idx = d_edge_to[idx];\n            //assert_cmp(0, <=, to_idx); assert_cmp(to_idx, <, n_states);\n            IdxAndVal candidate;\n            candidate.val = prev_val + d_edge_weights[idx] + d_am_scores[seq_idx * n_classes + emission_idx];\n            candidate.idx = idx;\n            select_max(&frame[to_idx], candidate);\n          }\n\n   
       idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '11_select_scores': '\n      DEF_KERNEL\n      void select_scores\n      (\n        int n_batch,\n        int n_states,\n        int buffer_stride,\n        const IdxAndVal* buffer,\n        const int32_t* d_am_seq_len, // (n_batch,)\n        const int32_t* d_end_states, // (n_batch,)\n        float* d_score // (n_batch,)\n      )\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_batch) {\n          const IdxAndVal* last_frame = buffer + d_am_seq_len[idx] * buffer_stride;\n          int end_state_idx = d_end_states[idx];\n          d_score[idx] = last_frame[end_state_idx].val;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '13_select_best_path': '\n      DEF_KERNEL\n      void select_best_path\n      (\n        int n_batch,\n        int n_states,\n        int n_edges,\n        int t,\n        int32* cur_state, // (n_batch,)\n        const IdxAndVal* frame,\n        const int32_t* d_am_seq_len,\n        const int32_t* d_edge_from,\n        const int32_t* d_edge_to,\n        const int32_t* d_edge_emission_idx,\n        int32_t* output\n      )\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_batch) {\n          if(t < d_am_seq_len[idx]) {\n            int state_idx = cur_state[idx];\n            //assert_cmp(0, <=, state_idx); assert_cmp(state_idx, <, n_states);\n            int edge_idx = frame[state_idx].idx;\n            if(edge_idx >= 0) {\n              //assert_cmp(0, <=, edge_idx); assert_cmp(edge_idx, <, n_edges);\n              //assert_cmp(state_idx, ==, d_edge_to[edge_idx]);\n              cur_state[idx] = d_edge_from[edge_idx];\n              output[idx] = d_edge_emission_idx[edge_idx];\n            }\n            else  // no path found\n              output[idx] = 0;\n          }\n          else {\n            output[idx] = 0;\n          }\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    using namespace std;\n    // am_scores, am_seq_len, edges, weights, start_end_states, n_states = input_names\n    // output, scores = output_names\n    assert(n_inputs == 6);\n    assert(n_outputs == 2);\n    Ndarray* am_scores = inputs[0];\n    Ndarray* am_seq_len = inputs[1];\n    Ndarray* edges = inputs[2];\n    Ndarray* weights = inputs[3];\n    Ndarray* start_end_states = inputs[4];\n    Ndarray* n_states_ref = inputs[5];\n    Ndarray* output = *outputs[0];\n    Ndarray* score = *outputs[1];\n\n    assert_cmp(Ndarray_NDIM(am_scores), ==, 3);\n    assert_cmp(Ndarray_NDIM(am_seq_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(edges), ==, 2);\n    assert_cmp(Ndarray_NDIM(weights), ==, 1);\n    assert_cmp(Ndarray_NDIM(start_end_states), ==, 2);\n    assert_cmp(Ndarray_NDIM(n_states_ref), ==, 0);\n    assert_cmp(Ndarray_NDIM(output), ==, 2);\n    assert_cmp(Ndarray_NDIM(score), ==, 1);\n    int n_time = Ndarray_DIMS(am_scores)[0];\n    int n_batch = Ndarray_DIMS(am_scores)[1];\n    int n_classes = Ndarray_DIMS(am_scores)[2];\n    assert_cmp(Ndarray_DIMS(am_scores)[0], ==, n_time);\n    assert_cmp(Ndarray_DIMS(am_scores)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(am_scores)[2], ==, n_classes);\n    assert_cmp(Ndarray_DIMS(am_seq_len)[0], ==, n_batch);\n    int n_edges = Ndarray_DIMS(edges)[1];\n    assert_cmp(Ndarray_DIMS(edges)[0], ==, 4);\n    assert_cmp(Ndarray_DIMS(edges)[1], ==, n_edges);\n    assert_cmp(Ndarray_DIMS(weights)[0], ==, n_edges);\n    assert_cmp(Ndarray_DIMS(start_end_states)[0], ==, 2);\n    assert_cmp(Ndarray_DIMS(start_end_states)[1], ==, n_batch);\n    int n_states = Ndarray_DEV_DATA_int32_scalar(n_states_ref);\n    assert_cmp(Ndarray_DIMS(output)[0], ==, n_time);\n    assert_cmp(Ndarray_DIMS(output)[1], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(score)[0], ==, n_batch);\n\n    int32_t* d_edge_from = Ndarray_DEV_DATA_int32(edges) + 0 * Ndarray_STRIDE(edges, 0);\n    int32_t* d_edge_to = Ndarray_DEV_DATA_int32(edges) + 1 * Ndarray_STRIDE(edges, 0);\n    int32_t* d_edge_emission_idx = Ndarray_DEV_DATA_int32(edges) + 2 * Ndarray_STRIDE(edges, 0);\n    int32_t* d_edge_seq_idx = Ndarray_DEV_DATA_int32(edges) + 3 * Ndarray_STRIDE(edges, 0);\n    float* d_edge_weights = Ndarray_DEV_DATA(weights);\n    float* d_am_scores = Ndarray_DEV_DATA(am_scores);\n    int am_scores_stride = Ndarray_STRIDE(am_scores, 0);\n    int32_t* d_am_seq_len = Ndarray_DEV_DATA_int32(am_seq_len);\n    int32_t* d_start_states = Ndarray_DEV_DATA_int32(start_end_states) + 0 * Ndarray_STRIDE(start_end_states, 0);\n    int32_t* d_end_states = Ndarray_DEV_DATA_int32(start_end_states) + 1 * Ndarray_STRIDE(start_end_states, 0);\n    int32_t* d_output = Ndarray_DEV_DATA_int32(output);\n    int output_stride = Ndarray_STRIDE(output, 0);\n    float* d_score = Ndarray_DEV_DATA(score);\n\n    IdxAndVal* d_buffer = (IdxAndVal*) device_malloc((n_time + 1) * n_states * sizeof(IdxAndVal));\n    int buffer_stride = n_states;\n    start_dev_kernel(init_buffer, (n_time, n_states, d_buffer));\n    start_dev_kernel(init_first_frame, (n_batch, n_states, d_buffer, d_start_states));\n    HANDLE_LAST_ERROR();\n\n    for(int t = 0; t < n_time; ++t) {\n      start_dev_kernel(next_frame, (\n        n_time,\n        n_states,\n        n_edges,\n        n_classes,\n        t,\n        d_am_scores + t * am_scores_stride,\n        d_am_seq_len,\n        d_buffer + t * buffer_stride,\n        d_buffer + (t + 1) * buffer_stride,\n        d_edge_from,\n        d_edge_to,\n        d_edge_emission_idx,\n        
d_edge_seq_idx,\n        d_edge_weights,\n        d_end_states\n      ));\n    }\n    HANDLE_LAST_ERROR();\n\n    start_dev_kernel(select_scores, (\n      n_batch,\n      n_states,\n      buffer_stride,\n      d_buffer,\n      d_am_seq_len,\n      d_end_states,\n      d_score // out\n    ));\n\n    int32_t* d_cur_state = (int32_t*) device_malloc(n_batch * sizeof(int32_t));\n    Ndarray_memcpy(d_cur_state, d_end_states, n_batch * sizeof(int32_t));\n\n    for(int t = n_time - 1; t >= 0; --t) {\n      start_dev_kernel(select_best_path, (\n        n_batch,\n        n_states,\n        n_edges,\n        t,\n        d_cur_state,\n        d_buffer + (t + 1) * buffer_stride,\n        d_am_seq_len,\n        d_edge_from,\n        d_edge_to,\n        d_edge_emission_idx,\n        d_output + t * output_stride // out\n      ));\n    }\n    HANDLE_LAST_ERROR();\n\n    device_free(d_cur_state);\n    device_free(d_buffer);\n  '[source]#
c_bw_code: str = None[source]#
class returnn.native_op.GetCtcFsaFastBwOp[source]#

This implements Fsa.get_ctc_fsa_fast_bw() as a native op. This is for constructing an FSA with a CTC topology. The output format is compatible with the FastBaumWelch native op.

inputs:
param targets:

shape (batch,time), int32

param seq_lens:

shape (batch), int32

param blank_idx:

scalar, int32

param weights:

shape (num_edges,), float32 (not used, except for target shape)

param label_loop:

scalar, int32 (cast from bool). True -> normal CTC; False -> RNA-like

outputs:
param edges:

(4,num_edges), int32, edges of the graph (from,to,emission_idx,sequence_idx)

param start_end_states:

(2,batch), int32, (start,end) state idx in FSA

To construct the weights (for FastBaumWelch), weights should simply be tf.zeros((num_edges,)). num_edges should be n_batch * (5 * (n_time - 1) + 10) (see the construction in the kernel for why that is the number).
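
For illustration, a minimal TensorFlow sketch of constructing such a weights tensor from the shape of the targets (the helper name is made up):

import tensorflow as tf

def make_ctc_fsa_weights(targets):
    """targets: int32 tensor of shape (n_batch, n_time)."""
    n_batch = tf.shape(targets)[0]
    n_time = tf.shape(targets)[1]
    num_edges = n_batch * (5 * (n_time - 1) + 10)
    # The values of weights are not used; only its length defines num_edges.
    return tf.zeros([num_edges], dtype=tf.float32)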

in_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'gradient': 'disconnected', 'name': 'targets', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'seq_lens', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'blank_idx', 'ndim': 0, 'need_contiguous': True, 'shape': ()}, {'dtype': 'float32', 'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'label_loop', 'ndim': 0, 'need_contiguous': True, 'shape': ()})[source]#
out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (4, (3, 0))}, {'dtype': 'int32', 'name': 'start_end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (2, (1, 0))})[source]#
c_extra_support_code: Dict[str, str] = {'01_kernel': '\n      template<bool label_loop>\n      DEF_KERNEL\n      void construct_kernel\n        (\n        int n_batch, int n_time, int n_edges,\n        const int32_t* targets, const int32_t* seq_lens,\n        int32_t blank_idx,\n        int32_t* edges, int32_t* start_end_states\n        )\n      {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        // n_edges should be n_batch * (5 * (n_time - 1) + 10).\n        assert(n_edges % n_batch == 0);\n        while(idx < n_edges) {\n          int batch_idx = idx / (n_edges / n_batch);\n          int rel_edge_idx = idx % (n_edges / n_batch);\n          int32_t seq_len = seq_lens[batch_idx];\n          // state_idx: 0 b, 1 l, 2 b, 3 l, ..., (T-1)*2 b, T*2-1 l, T*2 b, T*2+1 dummy, T*2+2 end\n          // i.e. T*2+3 states per seq.\n          int state_idx_offset = (n_time * 2 + 3) * batch_idx;\n          int t = -1; // pos in targets\n          int srel_edge_idx = -1; // state relative edge\n          // (seq_len * 2) - 1 is last label state idx. seq_len * 2 is last blank state idx.\n          int32_t dummy_state_idx = seq_len * 2 + 1;\n          int32_t end_state_idx = seq_len * 2 + 2;\n          int32_t state_idx = dummy_state_idx;\n          int32_t to_state_idx = dummy_state_idx;\n          if(rel_edge_idx == 0) {\n            start_end_states[0 * n_batch + batch_idx] = state_idx_offset; // start\n            start_end_states[1 * n_batch + batch_idx] = state_idx_offset + end_state_idx; // end\n          }\n          int32_t emission_idx = blank_idx;\n          int32_t label_idx = -1, next_label_idx = -1;\n          if(seq_len == 0) {\n            t = -1;\n            emission_idx = blank_idx;\n            // 1 single blank loop\n            if(rel_edge_idx == 0) {\n              state_idx = 0;\n              to_state_idx = 0;\n              srel_edge_idx = 0;\n            }\n            else if(rel_edge_idx == 1) {\n              state_idx = 0;\n              to_state_idx = end_state_idx;\n              srel_edge_idx = 1;\n            }\n            else {\n              state_idx = dummy_state_idx;\n              srel_edge_idx = -1;\n            }\n          }\n          else if(seq_len == 1) {\n            label_idx = targets[batch_idx * n_time + 0];\n            // 3 edges for first / prev last blank\n            if(rel_edge_idx < 3) {\n              t = 0;\n              state_idx = 0;\n              srel_edge_idx = rel_edge_idx;\n              if(srel_edge_idx == 0) {\n                to_state_idx = state_idx;\n                emission_idx = blank_idx;\n              }\n              else if(srel_edge_idx == 1) {\n                to_state_idx = state_idx + 1;\n                emission_idx = label_idx;\n              }\n              else if(srel_edge_idx == 2) {\n                to_state_idx = end_state_idx;\n                emission_idx = label_idx;\n              }\n            }\n            // 4 edges for first / last label\n            else if(rel_edge_idx < 7) {\n              t = 0;\n              state_idx = 1;\n              srel_edge_idx = rel_edge_idx - 3;\n              if(srel_edge_idx == 0) {\n                to_state_idx = label_loop ? state_idx : dummy_state_idx;\n                emission_idx = label_idx;\n              }\n              else if(srel_edge_idx == 1) {\n                to_state_idx = state_idx + 1;\n                emission_idx = blank_idx;\n              }\n              else if(srel_edge_idx == 2) {\n                to_state_idx = label_loop ? 
end_state_idx : dummy_state_idx;\n                emission_idx = label_idx;\n              }\n              else if(srel_edge_idx == 3) {\n                to_state_idx = end_state_idx;\n                emission_idx = blank_idx;\n              }\n            }\n            // 2 edges for last blank\n            else if(rel_edge_idx < 9) {\n              t = -1;\n              emission_idx = blank_idx;\n              state_idx = 2;\n              srel_edge_idx = rel_edge_idx - 7;\n              if(srel_edge_idx == 0)\n                to_state_idx = state_idx;\n              else\n                to_state_idx = end_state_idx;\n            }\n            else {\n              t = -1;\n              state_idx = dummy_state_idx;\n              srel_edge_idx = -1;\n            }\n          }\n          else { // seq_len >= 2\n            // 2 edges for each blank, 3 for each label. up to prev last.\n            if(rel_edge_idx < 5 * (seq_len - 1)) {\n              t = rel_edge_idx / 5;\n              label_idx = targets[batch_idx * n_time + t];\n              next_label_idx = targets[batch_idx * n_time + t + 1];\n              state_idx = 2 * (rel_edge_idx / 5);\n              srel_edge_idx = rel_edge_idx % 5;\n              if(srel_edge_idx >= 2) {\n                srel_edge_idx -= 2;\n                state_idx += 1;\n              }\n              if(state_idx % 2 == 0) { // blank loop state\n                if(srel_edge_idx == 0) {\n                  to_state_idx = state_idx;\n                  emission_idx = blank_idx;\n                }\n                else if(srel_edge_idx == 1) {\n                  to_state_idx = state_idx + 1;\n                  emission_idx = label_idx;\n                }\n              }\n              else { // label loop state\n                if(srel_edge_idx == 0) {\n                  to_state_idx = label_loop ? 
state_idx : dummy_state_idx;\n                  emission_idx = label_idx;\n                }\n                else if(srel_edge_idx == 1) {\n                  to_state_idx = state_idx + 1;\n                  emission_idx = blank_idx;\n                }\n                else if(srel_edge_idx == 2) {\n                  // skip over blank to next label (if allowed <=> next label is different)\n                  if(label_idx != next_label_idx || !label_loop) {\n                    to_state_idx = state_idx + 2;\n                    emission_idx = next_label_idx;\n                  }\n                }\n              }\n            }\n            // 1 more edge for prev last label\n            else if(rel_edge_idx == 5 * (seq_len - 1)) {\n              t = seq_len - 2;\n              label_idx = targets[batch_idx * n_time + t];\n              next_label_idx = targets[batch_idx * n_time + t + 1];\n              state_idx = (seq_len - 2) * 2 + 1;\n              srel_edge_idx = 3;\n              // skip over blank to next label / end state (if allowed <=> next label is different)\n              if(label_idx != next_label_idx || !label_loop) {\n                to_state_idx = end_state_idx;\n                emission_idx = next_label_idx;\n              }\n            }\n            // 3 edges for prev last blank\n            else if(rel_edge_idx <= 5 * (seq_len - 1) + 3) {\n              t = seq_len - 1;\n              label_idx = targets[batch_idx * n_time + t];\n              state_idx = (seq_len - 1) * 2;\n              srel_edge_idx = rel_edge_idx - (5 * (seq_len - 1) + 1);\n              if(srel_edge_idx == 0) {\n                to_state_idx = state_idx;\n                emission_idx = blank_idx;\n              }\n              else if(srel_edge_idx == 1) {\n                to_state_idx = state_idx + 1;\n                emission_idx = label_idx;\n              }\n              else if(srel_edge_idx == 2) {\n                to_state_idx = end_state_idx;\n                emission_idx = label_idx;\n              }\n            }\n            // 4 edges for last label\n            else if(rel_edge_idx <= 5 * (seq_len - 1) + 7) {\n              t = seq_len - 1;\n              label_idx = targets[batch_idx * n_time + t];\n              state_idx = (seq_len - 1) * 2 + 1;\n              srel_edge_idx = rel_edge_idx - (5 * (seq_len - 1) + 4);\n              if(srel_edge_idx == 0) {\n                to_state_idx = label_loop ? state_idx : dummy_state_idx;\n                emission_idx = label_idx;\n              }\n              else if(srel_edge_idx == 1) {\n                to_state_idx = state_idx + 1;\n                emission_idx = blank_idx;\n              }\n              else if(srel_edge_idx == 2) {\n                to_state_idx = label_loop ? 
end_state_idx : dummy_state_idx;\n                emission_idx = label_idx;\n              }\n              else if(srel_edge_idx == 3) {\n                to_state_idx = end_state_idx;\n                emission_idx = blank_idx;\n              }\n            }\n            // 2 edges for last blank\n            else if(rel_edge_idx <= 5 * (seq_len - 1) + 9) {\n              t = -1;\n              emission_idx = blank_idx;\n              state_idx = (seq_len - 1) * 2 + 2;\n              srel_edge_idx = rel_edge_idx - (5 * (seq_len - 1) + 8);\n              if(srel_edge_idx == 0)\n                to_state_idx = state_idx;\n              else\n                to_state_idx = end_state_idx;\n            }\n            else {\n              t = -1;\n              state_idx = dummy_state_idx;\n              srel_edge_idx = -1;\n            }\n          }\n\n          edges[0 * n_edges + idx] = state_idx_offset + state_idx; // from\n          edges[1 * n_edges + idx] = state_idx_offset + to_state_idx; // to\n          edges[2 * n_edges + idx] = emission_idx; // emission\n          edges[3 * n_edges + idx] = batch_idx; // batch\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 5);\n    assert(n_outputs == 2);\n    Ndarray* targets = inputs[0];\n    Ndarray* seq_lens = inputs[1];\n    Ndarray* blank_idx_ref = inputs[2];\n    Ndarray* weights = inputs[3];\n    bool label_loop = (bool) Ndarray_DEV_DATA_int32_scalar(inputs[4]);\n    Ndarray* edges = *outputs[0];\n    Ndarray* start_end_states = *outputs[1];\n    assert_cmp(Ndarray_NDIM(targets), ==, 2);\n    assert_cmp(Ndarray_NDIM(seq_lens), ==, 1);\n    assert_cmp(Ndarray_NDIM(blank_idx_ref), ==, 0);\n    assert_cmp(Ndarray_NDIM(weights), ==, 1);\n    assert_cmp(Ndarray_NDIM(edges), ==, 2);\n    assert_cmp(Ndarray_NDIM(start_end_states), ==, 2);\n    int n_batch = Ndarray_DIMS(seq_lens)[0];\n    assert_cmp(Ndarray_DIMS(targets)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(seq_lens)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(start_end_states)[1], ==, n_batch);\n    int n_time = Ndarray_DIMS(targets)[1];\n    int n_edges = Ndarray_DIMS(weights)[0];\n    assert_cmp(Ndarray_DIMS(start_end_states)[0], ==, 2);\n    assert_cmp(Ndarray_DIMS(edges)[0], ==, 4);\n    assert_cmp(Ndarray_DIMS(edges)[1], ==, n_edges);\n\n    assert_cmp(n_edges, ==, n_batch * (5 * (n_time - 1) + 10));\n\n    Ndarray_memset(Ndarray_DEV_DATA_int32(edges), 255, 4 * n_edges * sizeof(int32_t));\n    Ndarray_memset(Ndarray_DEV_DATA_int32(start_end_states), 255, 2 * n_batch * sizeof(int32_t));\n    int32_t blank_idx = Ndarray_DEV_DATA_int32_scalar(blank_idx_ref);\n\n    if(label_loop) {\n      start_dev_kernel(construct_kernel<true>, (\n        n_batch, n_time, n_edges,\n        Ndarray_DEV_DATA_int32(targets), Ndarray_DEV_DATA_int32(seq_lens),\n        blank_idx,\n        Ndarray_DEV_DATA_int32(edges), Ndarray_DEV_DATA_int32(start_end_states)\n      ));\n    } else {\n      start_dev_kernel(construct_kernel<false>, (\n        n_batch, n_time, n_edges,\n        Ndarray_DEV_DATA_int32(targets), Ndarray_DEV_DATA_int32(seq_lens),\n        blank_idx,\n        Ndarray_DEV_DATA_int32(edges), Ndarray_DEV_DATA_int32(start_end_states)\n      ));\n    }\n    HANDLE_LAST_ERROR();\n  '[source]#
class returnn.native_op.EditDistanceOp[source]#

Similar to tf.edit_distance(). Calculates the edit distance / Levenshtein distance.

The naive implementation iterates over a and then over b, which results in O(|a|*|b|) time complexity. Each new entry in the table (indexed by positions in a and b) depends on the previous entry along a (left; deletion error), the previous entry along b (up; insertion error), and the upper-left diagonal entry (substitution error, or no error).

To take advantage of the parallelism of the GPU, we follow a diagonal iteration scheme, such that in every iteration, all entries on the diagonal can be computed in parallel, as they do not depend on each other. After implementing this, we found that this algorithm is described here:

Using GPUs to Speed-Up Levenshtein Edit Distance Computation, Balhaf et al, 2016,
https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=7476090&tag=1
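
For illustration, a minimal NumPy sketch of the same diagonal scheme on CPU for a single pair of sequences (names and structure are illustrative, not the actual kernel):

import numpy as np

def edit_distance_diagonal(a, b):
    n_a, n_b = len(a), len(b)
    # dist[t_a, t_b] = edit distance between a[:t_a] and b[:t_b].
    dist = np.zeros((n_a + 1, n_b + 1), dtype=np.int32)
    dist[:, 0] = np.arange(n_a + 1)  # delete all of a[:t_a]
    dist[0, :] = np.arange(n_b + 1)  # insert all of b[:t_b]
    # Iterate over anti-diagonals: entries on one diagonal only depend on the
    # two previous diagonals, so they are independent of each other and could
    # be computed in parallel (as the CUDA kernel does, per batch entry).
    for diag in range(2, n_a + n_b + 1):
        for t_a in range(max(1, diag - n_b), min(n_a, diag - 1) + 1):
            t_b = diag - t_a
            up_left = dist[t_a - 1, t_b - 1] + int(a[t_a - 1] != b[t_b - 1])
            up = dist[t_a - 1, t_b] + 1
            left = dist[t_a, t_b - 1] + 1
            dist[t_a, t_b] = min(up_left, up, left)
    return int(dist[n_a, n_b])

assert edit_distance_diagonal("kitten", "sitting") == 3
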
inputs:
param a:

symbols. 2d (batch,time), int32

param a_len:

1d (batch,), int32

param b:

symbols. 2d (batch,time), int32

param b_len:

1d (batch,), int32

outputs:
param output:

1d (batch,), int32, unnormalized edit distance

in_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)})[source]#
out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'output', 'ndim': 1, 'need_contiguous': True, 'shape': ((0, 0),)},)[source]#
c_extra_support_code: Dict[str, str] = {'001_next_step': '\n      DEF_KERNEL\n      void next_step_kernel(\n            int n_batch, int n_a_max_len, int n_b_max_len,\n            int diag_idx,\n            const int32_t* a, const int32_t* b,\n            const int32_t* a_len, const int32_t* b_len,\n            const int32_t* last1_dist, const int32_t* last2_dist, int32_t* cur_dist,\n            int32_t* result) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        // We are going diagonal!\n        int num_entries;\n        if(diag_idx <= n_a_max_len) {\n          num_entries = diag_idx + 1;\n          if(num_entries > n_b_max_len + 1)\n            num_entries = n_b_max_len + 1;\n        } else {\n          num_entries = n_b_max_len + 1 - (diag_idx - n_a_max_len);\n          if(num_entries > n_a_max_len + 1)\n            num_entries = n_a_max_len + 1;\n        }\n        int max_num_entries = n_a_max_len + 1;\n        if(max_num_entries > n_b_max_len + 1)\n          max_num_entries = n_b_max_len + 1;\n        while(idx < n_batch * num_entries) {\n          int batch_idx = idx / num_entries;\n          int entry_idx = idx % num_entries;\n          int dist_idx = batch_idx * max_num_entries + entry_idx;\n\n          int t_a, t_b;\n          if(diag_idx <= n_a_max_len) {\n            t_a = diag_idx - entry_idx;\n            t_b = entry_idx;\n          } else {\n            t_a = n_a_max_len - entry_idx;\n            t_b = diag_idx - n_a_max_len + entry_idx;\n          }\n\n          if(t_a == 0)\n            cur_dist[dist_idx] = t_b;  // distance == how much to delete from b\n          else if(t_b == 0)\n            cur_dist[dist_idx] = t_a;  // distance == how much to delete from a\n          else {\n            // last1 is with diag_idx - 2. Needed for substitution cost.\n            // last2 is with diag_idx - 1. Needed for insertion or deletion cost.\n            // last2 refers to the first, for deletion. last2_idx + 1 is for insertion.\n            int last1_idx, last2_idx;\n            if(diag_idx - 1 < n_a_max_len)\n              last1_idx = dist_idx - 1;\n            else if(diag_idx - 1 == n_a_max_len)\n              last1_idx = dist_idx;\n            else\n              last1_idx = dist_idx + 1;\n            if(diag_idx <= n_a_max_len)\n              last2_idx = dist_idx - 1;\n            else\n              last2_idx = dist_idx;\n\n            int del_cost, ins_cost, sub_cost;\n            del_cost = last2_dist[last2_idx] + 1;\n            ins_cost = last2_dist[last2_idx + 1] + 1;\n            sub_cost = last1_dist[last1_idx];\n            if(a[batch_idx * n_a_max_len + t_a - 1] != b[batch_idx * n_b_max_len + t_b - 1])\n              ++sub_cost;\n            //printf("t_a %i, t_b %i, del %i, ins %i, sub %i\\n", t_a, t_b, del_cost, ins_cost, sub_cost);\n            int min_cost = del_cost;\n            if(min_cost > ins_cost) min_cost = ins_cost;\n            if(min_cost > sub_cost) min_cost = sub_cost;\n            cur_dist[dist_idx] = min_cost;\n          }\n          //printf("t_a %i, t_b %i, dist %i\\n", t_a, t_b, cur_dist[dist_idx]);\n\n          if(t_a == a_len[batch_idx] && t_b == b_len[batch_idx])\n            result[batch_idx] = cur_dist[dist_idx];\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 4);\n    assert(n_outputs == 1);\n    Ndarray* a = inputs[0];\n    Ndarray* a_len = inputs[1];\n    Ndarray* b = inputs[2];\n    Ndarray* b_len = inputs[3];\n    Ndarray* out = *outputs[0];\n    assert_cmp(Ndarray_NDIM(a), ==, 2);\n    assert_cmp(Ndarray_NDIM(a_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(b), ==, 2);\n    assert_cmp(Ndarray_NDIM(b_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(out), ==, 1);\n    int n_batch = Ndarray_DIMS(out)[0];\n    assert_cmp(Ndarray_DIMS(a)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(a_len)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b_len)[0], ==, n_batch);\n    int n_a_max_len = Ndarray_DIMS(a)[1];\n    int n_b_max_len = Ndarray_DIMS(b)[1];\n    Ndarray_memset(Ndarray_DEV_DATA_int32(out), 255, n_batch * sizeof(int32_t));\n\n    // Working buffer.\n    int max_num_entries = std::min(n_a_max_len + 1, n_b_max_len + 1);\n    int32_t* buffer = (int32_t*) device_malloc(3 * n_batch * max_num_entries * sizeof(int32_t));\n    int32_t* last1_dist = buffer;\n    int32_t* last2_dist = buffer + n_batch * max_num_entries;\n    int32_t* cur_dist = buffer + 2 * n_batch * max_num_entries;\n\n    int num_diag = n_a_max_len + n_b_max_len + 1;\n    for(int diag_idx = 0; diag_idx < num_diag; ++diag_idx) {\n      start_dev_kernel(next_step_kernel, (\n        n_batch, n_a_max_len, n_b_max_len,\n        diag_idx,\n        Ndarray_DEV_DATA_int32(a), Ndarray_DEV_DATA_int32(b),\n        Ndarray_DEV_DATA_int32(a_len), Ndarray_DEV_DATA_int32(b_len),\n        last1_dist, last2_dist, cur_dist,\n        Ndarray_DEV_DATA_int32(out)));\n      // Rotate. last1_dist not needed anymore.\n      int32_t* tmp = last1_dist;\n      last1_dist = last2_dist;\n      last2_dist = cur_dist;\n      cur_dist = tmp;\n    }\n    HANDLE_LAST_ERROR();\n\n    device_free(buffer);\n  '[source]#
c_bw_code: str = None[source]#
class returnn.native_op.OptimalCompletionEditDistanceOp[source]#

Given some prefix a, what is the minimum possible edit distance to b when any possible suffix can be appended to a? This is described in Optimal Completion Distillation (OCD). The implementation is derived from EditDistanceOp.
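
For illustration, a minimal CPU sketch of the underlying idea (single pair of sequences, illustrative names): the best possible suffix is exactly the not-yet-covered rest of b, so the result is the minimum over the last row of the edit distance table, i.e. the minimum edit distance between a and any prefix of b.

import numpy as np

def optimal_completion_edit_distance(a, b):
    n_b = len(b)
    row = np.arange(n_b + 1, dtype=np.int32)  # distances for the empty prefix of a
    for t_a in range(1, len(a) + 1):
        prev, row = row, np.empty(n_b + 1, dtype=np.int32)
        row[0] = t_a
        for t_b in range(1, n_b + 1):
            up_left = prev[t_b - 1] + int(a[t_a - 1] != b[t_b - 1])
            row[t_b] = min(up_left, prev[t_b] + 1, row[t_b - 1] + 1)
    # row[t_b] is now the edit distance between a and b[:t_b].
    return int(row.min())

assert optimal_completion_edit_distance("kitt", "kitten") == 0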

inputs:
param a:

symbols. 2d (batch,time), int32. prefix.

param a_len:

1d (batch,), int32

param b:

symbols. 2d (batch,time), int32

param b_len:

1d (batch,), int32

outputs:
param output:

1d (batch,), int32, unnormalized edit distance

in_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)})[source]#
out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'output', 'ndim': 1, 'need_contiguous': True, 'shape': ((0, 0),)},)[source]#
c_extra_support_code: Dict[str, str] = {'001_init_result': '\n      DEF_KERNEL\n      void init_result_kernel(int n_batch, int32_t* result) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_batch) {\n          result[idx] = 2147483647;  // biggest int32\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '002_next_step': '\n      DEF_KERNEL\n      void next_step_kernel(\n            int n_batch, int n_a_max_len, int n_b_max_len,\n            int diag_idx,\n            const int32_t* a, const int32_t* b,\n            const int32_t* a_len, const int32_t* b_len,\n            const int32_t* last1_dist, const int32_t* last2_dist, int32_t* cur_dist,\n            int32_t* result) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        // We are going diagonal!\n        int num_entries;\n        if(diag_idx <= n_a_max_len) {\n          num_entries = diag_idx + 1;\n          if(num_entries > n_b_max_len + 1)\n            num_entries = n_b_max_len + 1;\n        } else {\n          num_entries = n_b_max_len + 1 - (diag_idx - n_a_max_len);\n          if(num_entries > n_a_max_len + 1)\n            num_entries = n_a_max_len + 1;\n        }\n        int max_num_entries = n_a_max_len + 1;\n        if(max_num_entries > n_b_max_len + 1)\n          max_num_entries = n_b_max_len + 1;\n        while(idx < n_batch * num_entries) {\n          int batch_idx = idx / num_entries;\n          int entry_idx = idx % num_entries;\n          int dist_idx = batch_idx * max_num_entries + entry_idx;\n\n          int t_a, t_b;\n          if(diag_idx <= n_a_max_len) {\n            t_a = diag_idx - entry_idx;\n            t_b = entry_idx;\n          } else {\n            t_a = n_a_max_len - entry_idx;\n            t_b = diag_idx - n_a_max_len + entry_idx;\n          }\n\n          if(t_a == 0)\n            cur_dist[dist_idx] = t_b;  // distance == how much to delete from b\n          else if(t_b == 0)\n            cur_dist[dist_idx] = t_a;  // distance == how much to delete from a\n          else {\n            // last1 is with diag_idx - 2. Needed for substitution cost.\n            // last2 is with diag_idx - 1. Needed for insertion or deletion cost.\n            // last2 refers to the first, for deletion. last2_idx + 1 is for insertion.\n            int last1_idx, last2_idx;\n            if(diag_idx - 1 < n_a_max_len)\n              last1_idx = dist_idx - 1;\n            else if(diag_idx - 1 == n_a_max_len)\n              last1_idx = dist_idx;\n            else\n              last1_idx = dist_idx + 1;\n            if(diag_idx <= n_a_max_len)\n              last2_idx = dist_idx - 1;\n            else\n              last2_idx = dist_idx;\n\n            int del_cost, ins_cost, sub_cost;\n            del_cost = last2_dist[last2_idx] + 1;\n            ins_cost = last2_dist[last2_idx + 1] + 1;\n            sub_cost = last1_dist[last1_idx];\n            if(a[batch_idx * n_a_max_len + t_a - 1] != b[batch_idx * n_b_max_len + t_b - 1])\n              ++sub_cost;\n            int min_cost = del_cost;\n            if(min_cost > ins_cost) min_cost = ins_cost;\n            if(min_cost > sub_cost) min_cost = sub_cost;\n            cur_dist[dist_idx] = min_cost;\n          }\n\n          if(t_a == a_len[batch_idx] && t_b <= b_len[batch_idx])\n            elem_atomic_min(&result[batch_idx], cur_dist[dist_idx]);\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 4);\n    assert(n_outputs == 1);\n    Ndarray* a = inputs[0];\n    Ndarray* a_len = inputs[1];\n    Ndarray* b = inputs[2];\n    Ndarray* b_len = inputs[3];\n    Ndarray* out = *outputs[0];\n    assert_cmp(Ndarray_NDIM(a), ==, 2);\n    assert_cmp(Ndarray_NDIM(a_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(b), ==, 2);\n    assert_cmp(Ndarray_NDIM(b_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(out), ==, 1);\n    int n_batch = Ndarray_DIMS(out)[0];\n    assert_cmp(Ndarray_DIMS(a)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(a_len)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b_len)[0], ==, n_batch);\n    int n_a_max_len = Ndarray_DIMS(a)[1];\n    int n_b_max_len = Ndarray_DIMS(b)[1];\n    start_dev_kernel(init_result_kernel, (n_batch, Ndarray_DEV_DATA_int32(out)));\n\n    // Working buffer.\n    int max_num_entries = std::min(n_a_max_len + 1, n_b_max_len + 1);\n    int32_t* buffer = (int32_t*) device_malloc(3 * n_batch * max_num_entries * sizeof(int32_t));\n    int32_t* last1_dist = buffer;\n    int32_t* last2_dist = buffer + n_batch * max_num_entries;\n    int32_t* cur_dist = buffer + 2 * n_batch * max_num_entries;\n\n    int num_diag = n_a_max_len + n_b_max_len + 1;\n    for(int diag_idx = 0; diag_idx < num_diag; ++diag_idx) {\n      start_dev_kernel(next_step_kernel, (\n        n_batch, n_a_max_len, n_b_max_len,\n        diag_idx,\n        Ndarray_DEV_DATA_int32(a), Ndarray_DEV_DATA_int32(b),\n        Ndarray_DEV_DATA_int32(a_len), Ndarray_DEV_DATA_int32(b_len),\n        last1_dist, last2_dist, cur_dist,\n        Ndarray_DEV_DATA_int32(out)));\n      // Rotate. last1_dist not needed anymore.\n      int32_t* tmp = last1_dist;\n      last1_dist = last2_dist;\n      last2_dist = cur_dist;\n      cur_dist = tmp;\n    }\n    HANDLE_LAST_ERROR();\n\n    device_free(buffer);\n  '[source]#
c_bw_code: str = None[source]#
class returnn.native_op.OptimalCompletionEditDistancePerSuccessorOp[source]#

Given some prefix a + successor, what is the minimum possible edit distance to b when any possible suffix can be appended to a + successor? This is computed for every successor in successors. This is described in Optimal Completion Distillation (OCD). The implementation is derived from OptimalCompletionEditDistanceOp.
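
For illustration, a minimal CPU sketch (single batch entry, illustrative names) of what the kernels compute: given the last edit distance row of the prefix a, the score for each successor c only needs the substitution/match terms against the symbols of b, plus the case of the empty prefix of b and the case of appending c after all of b was already matched.

import numpy as np

def oc_edit_distance_per_successor(a, b, successors):
    # Last edit distance row of the prefix a: a_last_row[t_b] = dist(a, b[:t_b]).
    n_b = len(b)
    a_last_row = np.arange(n_b + 1, dtype=np.int32)
    for t_a in range(1, len(a) + 1):
        prev, a_last_row = a_last_row, np.empty(n_b + 1, dtype=np.int32)
        a_last_row[0] = t_a
        for t_b in range(1, n_b + 1):
            up_left = prev[t_b - 1] + int(a[t_a - 1] != b[t_b - 1])
            a_last_row[t_b] = min(up_left, prev[t_b] + 1, a_last_row[t_b - 1] + 1)
    out = np.empty(len(successors), dtype=np.int32)
    for i, c in enumerate(successors):
        # Empty prefix of b, or c appended after all of b was already matched.
        best = min(len(a) + 1, int(a_last_row[n_b]) + 1)
        for t_b in range(n_b):  # substitution / match against each symbol of b
            best = min(best, int(a_last_row[t_b]) + int(c != b[t_b]))
        out[i] = best
    return out

print(oc_edit_distance_per_successor("ki", "kitten", ["t", "x"]))  # [0 1]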

inputs:
param a:

symbols. 2d (batch,time), int32. prefix.

param a_len:

1d (batch,), int32

param b:

symbols. 2d (batch,time), int32

param b_len:

1d (batch,), int32

param successors:

1d (num_labels,), int32

outputs:
param output:

2d (batch,num_labels), int32, unnormalized edit distance

in_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'successors', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)})[source]#
out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'output', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (4, 0))},)[source]#
c_extra_support_code: Dict[str, str] = {'001_next_step': '\n      DEF_KERNEL\n      void next_step_kernel(\n            int n_batch, int n_a_max_len, int n_b_max_len,\n            int diag_idx,\n            const int32_t* a, const int32_t* b,\n            const int32_t* a_len, const int32_t* b_len,\n            const int32_t* last1_dist, const int32_t* last2_dist, int32_t* cur_dist, int32_t* a_last_row) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        // We are going diagonal!\n        int num_entries;\n        if(diag_idx <= n_a_max_len) {\n          num_entries = diag_idx + 1;\n          if(num_entries > n_b_max_len + 1)\n            num_entries = n_b_max_len + 1;\n        } else {\n          num_entries = n_b_max_len + 1 - (diag_idx - n_a_max_len);\n          if(num_entries > n_a_max_len + 1)\n            num_entries = n_a_max_len + 1;\n        }\n        int max_num_entries = n_a_max_len + 1;\n        if(max_num_entries > n_b_max_len + 1)\n          max_num_entries = n_b_max_len + 1;\n        while(idx < n_batch * num_entries) {\n          int batch_idx = idx / num_entries;\n          int entry_idx = idx % num_entries;\n          int dist_idx = batch_idx * max_num_entries + entry_idx;\n\n          int t_a, t_b;\n          if(diag_idx <= n_a_max_len) {\n            t_a = diag_idx - entry_idx;\n            t_b = entry_idx;\n          } else {\n            t_a = n_a_max_len - entry_idx;\n            t_b = diag_idx - n_a_max_len + entry_idx;\n          }\n\n          if(t_a == 0)\n            cur_dist[dist_idx] = t_b;  // distance == how much to delete from b\n          else if(t_b == 0)\n            cur_dist[dist_idx] = t_a;  // distance == how much to delete from a\n          else {\n            // last1 is with diag_idx - 2. Needed for substitution cost.\n            // last2 is with diag_idx - 1. Needed for insertion or deletion cost.\n            // last2 refers to the first, for deletion. 
last2_idx + 1 is for insertion.\n            int last1_idx, last2_idx;\n            if(diag_idx - 1 < n_a_max_len)\n              last1_idx = dist_idx - 1;\n            else if(diag_idx - 1 == n_a_max_len)\n              last1_idx = dist_idx;\n            else\n              last1_idx = dist_idx + 1;\n            if(diag_idx <= n_a_max_len)\n              last2_idx = dist_idx - 1;\n            else\n              last2_idx = dist_idx;\n\n            int del_cost, ins_cost, sub_cost;\n            del_cost = last2_dist[last2_idx] + 1;\n            ins_cost = last2_dist[last2_idx + 1] + 1;\n            sub_cost = last1_dist[last1_idx];\n            if(a[batch_idx * n_a_max_len + t_a - 1] != b[batch_idx * n_b_max_len + t_b - 1])\n              ++sub_cost;\n            int min_cost = del_cost;\n            if(min_cost > ins_cost) min_cost = ins_cost;\n            if(min_cost > sub_cost) min_cost = sub_cost;\n            cur_dist[dist_idx] = min_cost;\n          }\n\n          if(t_a == a_len[batch_idx] && t_b <= b_len[batch_idx])\n            a_last_row[batch_idx * (n_b_max_len + 1) + t_b] = cur_dist[dist_idx];\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '002_init_result': '\n      DEF_KERNEL\n      void init_result_kernel(\n            int n_batch, int n_b_max_len, int n_labels,\n            const int32_t* a_len, const int32_t* b_len,\n            const int32_t* a_last_row,\n            int32_t* result\n      ) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_batch * n_labels) {\n          int batch_idx = idx / n_labels;\n          int successor_idx = idx % n_labels;\n\n          // Initial insertion, last deletion.\n          int t_a = a_len[batch_idx] + 1;\n          int min_cost = t_a;\n          int last_del_cost = a_last_row[batch_idx * (n_b_max_len + 1) + b_len[batch_idx]] + 1;\n          if(min_cost > last_del_cost) min_cost = last_del_cost;\n          result[batch_idx * n_labels + successor_idx] = min_cost;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    ', '003_expand': '\n      DEF_KERNEL\n      void expand_kernel(\n            int n_batch, int n_b_max_len, int n_labels,\n            const int32_t* b,\n            const int32_t* b_len,\n            const int32_t* a_last_row,\n            const int32_t* successors,\n            int32_t* result\n      ) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_batch * n_labels * n_b_max_len) {\n          int batch_idx = idx / n_b_max_len / n_labels;\n          int successor_idx = (idx / n_b_max_len) % n_labels;\n          int t_b = idx % n_b_max_len;\n          int successor = successors[successor_idx];\n\n          if(t_b < b_len[batch_idx]) {\n            // We can ignore insertion/deletion\n            // (except initial insertion / last deletion, see init_result_kernel).\n            int sub_cost = a_last_row[batch_idx * (n_b_max_len + 1) + t_b];\n            if(successor != b[batch_idx * n_b_max_len + t_b])\n              ++sub_cost;\n            elem_atomic_min(&result[batch_idx * n_labels + successor_idx], sub_cost);\n          }\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 5);\n    assert(n_outputs == 1);\n    Ndarray* a = inputs[0];\n    Ndarray* a_len = inputs[1];\n    Ndarray* b = inputs[2];\n    Ndarray* b_len = inputs[3];\n    Ndarray* successors = inputs[4];\n    Ndarray* out = *outputs[0];\n    assert_cmp(Ndarray_NDIM(a), ==, 2);\n    assert_cmp(Ndarray_NDIM(a_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(b), ==, 2);\n    assert_cmp(Ndarray_NDIM(b_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(successors), ==, 1);\n    assert_cmp(Ndarray_NDIM(out), ==, 2);\n    int n_batch = Ndarray_DIMS(out)[0];\n    int n_labels = Ndarray_DIMS(successors)[0];\n    assert_cmp(Ndarray_DIMS(out)[1], ==, n_labels);\n    assert_cmp(Ndarray_DIMS(a)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(a_len)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b_len)[0], ==, n_batch);\n    int n_a_max_len = Ndarray_DIMS(a)[1];\n    int n_b_max_len = Ndarray_DIMS(b)[1];\n    Ndarray_memset(Ndarray_DEV_DATA_int32(out), 255, n_batch * n_labels * sizeof(int32_t));\n\n    // Working buffer.\n    int max_num_entries = std::min(n_a_max_len + 1, n_b_max_len + 1);\n    int32_t* buffer = (int32_t*) device_malloc(3 * n_batch * max_num_entries * sizeof(int32_t));\n    int32_t* last1_dist = buffer;\n    int32_t* last2_dist = buffer + n_batch * max_num_entries;\n    int32_t* cur_dist = buffer + 2 * n_batch * max_num_entries;\n    int32_t* a_last_row = (int32_t*) device_malloc(n_batch * (n_b_max_len + 1) * sizeof(int32_t));\n\n    int num_diag = n_a_max_len + n_b_max_len + 1;\n    for(int diag_idx = 0; diag_idx < num_diag; ++diag_idx) {\n      start_dev_kernel(next_step_kernel, (\n        n_batch, n_a_max_len, n_b_max_len,\n        diag_idx,\n        Ndarray_DEV_DATA_int32(a), Ndarray_DEV_DATA_int32(b),\n        Ndarray_DEV_DATA_int32(a_len), Ndarray_DEV_DATA_int32(b_len),\n        last1_dist, last2_dist, cur_dist, a_last_row\n      ));\n      // Rotate. last1_dist not needed anymore.\n      int32_t* tmp = last1_dist;\n      last1_dist = last2_dist;\n      last2_dist = cur_dist;\n      cur_dist = tmp;\n    }\n    HANDLE_LAST_ERROR();\n\n    start_dev_kernel(init_result_kernel, (\n      n_batch, n_b_max_len, n_labels,\n      Ndarray_DEV_DATA_int32(a_len), Ndarray_DEV_DATA_int32(b_len),\n      a_last_row,\n      Ndarray_DEV_DATA_int32(out)\n    ));\n    HANDLE_LAST_ERROR();\n\n    start_dev_kernel(expand_kernel, (\n      n_batch, n_b_max_len, n_labels,\n      Ndarray_DEV_DATA_int32(b),\n      Ndarray_DEV_DATA_int32(b_len),\n      a_last_row,\n      Ndarray_DEV_DATA_int32(successors),\n      Ndarray_DEV_DATA_int32(out)\n    ));\n    HANDLE_LAST_ERROR();\n\n    device_free(buffer);\n    device_free(a_last_row);\n  '[source]#
c_bw_code: str = None[source]#
class returnn.native_op.NextEditDistanceRowOp[source]#

Given the previous row of the edit distance table, this computes the next row, i.e. it advances the table by one symbol of a (see the NumPy sketch after the parameter list below). If the full sequence a is known in advance, EditDistanceOp should be faster; this iterative op is useful when a is constructed step by step.

inputs:
param last_row:

2d (batch,b_time + 1), int32. last edit distances

param a:

current symbol of a, per batch entry. 1d (batch,), int32

param a_n:

1d (batch,), int32. current position in a

param a_ended:

1d (batch,), int32 (cast from bool, since int32 is easier to handle)

param b:

symbols. 2d (batch,b_time), int32

param b_len:

1d (batch,), int32

outputs:
param output:

2d (batch,b_time + 1), int32, next (unnormalized) edit distance row

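A minimal NumPy sketch of what the forward pass computes for one batch entry, mirroring next_row_kernel above (names are illustrative; the real op runs this per batch entry via the generated kernel)::

    import numpy as np

    def next_edit_distance_row_ref(last_row, a_sym, a_n, a_ended, b, b_len):
        """last_row: (b_time + 1,) int32; returns the next row, incl. padding."""
        next_row = np.array(last_row, dtype=np.int32, copy=True)
        if not a_ended:
            last_dist = a_n + 1  # a has consumed a_n + 1 symbols, b prefix is empty: all deletions
            next_row[0] = last_dist
            for t_b in range(1, b_len + 1):
                up = last_row[t_b] + 1                               # cell above + 1
                left = last_dist + 1                                 # cell to the left + 1
                diag = last_row[t_b - 1] + int(a_sym != b[t_b - 1])  # diagonal + substitution cost
                last_dist = min(up, left, diag)
                next_row[t_b] = last_dist
        # else: a already ended, the row stays as it was (copied above)
        next_row[b_len + 1:] = next_row[b_len]  # repeat the last valid entry as padding
        return next_row

The returned row contains the edit distances of the a prefix up to and including the current symbol against every prefix of b.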
in_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'gradient': 'disconnected', 'name': 'last_row', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a_n', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a_ended', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)})[source]#
out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'output', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))},)[source]#
c_extra_support_code: Dict[str, str] = {'001_next_row': '\n      DEF_KERNEL\n      void next_row_kernel(\n            int n_batch, int n_b_max_len,\n            const int32_t* last_row,\n            const int32_t* a, const int32_t* a_n, const int32_t* a_ended,\n            const int32_t* b, const int32_t* b_len,\n            int32_t* next_row\n      ) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_batch) {\n          int batch_idx = idx;\n\n          int last_dist;\n          if(!a_ended[batch_idx]) {\n            last_dist = a_n[batch_idx] + 1;  // Initial deletion error.\n            next_row[batch_idx * (n_b_max_len + 1)] = last_dist;\n            for(int t_b = 1; t_b <= b_len[batch_idx]; ++t_b) {\n              int ins_error = last_row[batch_idx * (n_b_max_len + 1) + t_b] + 1;\n              int del_error = last_dist + 1;\n              int sub_error = last_row[batch_idx * (n_b_max_len + 1) + t_b - 1];\n              if(a[batch_idx] != b[batch_idx * n_b_max_len + t_b - 1])\n                ++sub_error;\n              last_dist = ins_error;\n              if(last_dist > del_error) last_dist = del_error;\n              if(last_dist > sub_error) last_dist = sub_error;\n              next_row[batch_idx * (n_b_max_len + 1) + t_b] = last_dist;\n            }\n          }\n          else {  // a ended\n            // Just copy over.\n            for(int t_b = 0; t_b <= b_len[batch_idx]; ++t_b) {\n              last_dist = last_row[batch_idx * (n_b_max_len + 1) + t_b];\n              next_row[batch_idx * (n_b_max_len + 1) + t_b] = last_dist;\n            }\n          }\n          // Repeat last entry.\n          for(int t_b = b_len[batch_idx] + 1; t_b < n_b_max_len + 1; ++t_b)\n            next_row[batch_idx * (n_b_max_len + 1) + t_b] = last_dist;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 6);\n    assert(n_outputs == 1);\n    Ndarray* last_row = inputs[0];\n    Ndarray* a = inputs[1];\n    Ndarray* a_n = inputs[2];\n    Ndarray* a_ended = inputs[3];\n    Ndarray* b = inputs[4];\n    Ndarray* b_len = inputs[5];\n    Ndarray* out = *outputs[0];\n    assert_cmp(Ndarray_NDIM(last_row), ==, 2);\n    assert_cmp(Ndarray_NDIM(a), ==, 1);\n    assert_cmp(Ndarray_NDIM(a_n), ==, 1);\n    assert_cmp(Ndarray_NDIM(a_ended), ==, 1);\n    assert_cmp(Ndarray_NDIM(b), ==, 2);\n    assert_cmp(Ndarray_NDIM(b_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(out), ==, 2);\n    int n_batch = Ndarray_DIMS(out)[0];\n    int n_b_max_len = Ndarray_DIMS(b)[1];\n    assert_cmp(Ndarray_DIMS(out)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(out)[1], ==, n_b_max_len + 1);\n    assert_cmp(Ndarray_DIMS(last_row)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(last_row)[1], ==, n_b_max_len + 1);\n    assert_cmp(Ndarray_DIMS(a)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(a_n)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(a_ended)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b)[1], ==, n_b_max_len);\n    assert_cmp(Ndarray_DIMS(b_len)[0], ==, n_batch);\n\n    start_dev_kernel(next_row_kernel, (\n      n_batch, n_b_max_len,\n      Ndarray_DEV_DATA_int32(last_row),\n      Ndarray_DEV_DATA_int32(a), Ndarray_DEV_DATA_int32(a_n), Ndarray_DEV_DATA_int32(a_ended),\n      Ndarray_DEV_DATA_int32(b), Ndarray_DEV_DATA_int32(b_len),\n      Ndarray_DEV_DATA_int32(out)\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
c_bw_code: str = None[source]#
class returnn.native_op.NextEditDistanceReduceOp[source]#

Code derived from NextEditDistanceRowOp. Instead of returning the full next row, this reduces it to one score per candidate label: either the minimum over the row (optimal completion) or its last entry (see the NumPy sketch after the parameter list below).

inputs:
param last_row:

2d (batch,b_time + 1), int32. last edit distances

param a:

current candidate symbols, one per label. 2d (batch|1,n_labels), int32. The batch dim can be 1 to broadcast over the batch.

param a_n:

1d (batch,), int32. current position in a

param a_ended:

1d (batch,), int32 (cast from bool, since int32 is easier to handle)

param b:

symbols. 2d (batch,b_time), int32

param b_len:

1d (batch,), int32

param optimal_completion:

scalar, int32 (cast from bool). True -> reduce_min over the row; False -> last entry of the row

param a_blank_idx:

scalar, int32. index of the blank label, which is treated like an already ended sequence (the row is left unchanged); use -1 to disable

outputs:
param output:

2d (batch,n_labels), int32, next (unnormalized) edit distance per candidate label, or the optimal completion edit distance if optimal_completion is set

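A minimal NumPy sketch of the reduction for one batch entry, mirroring calc_result_kernel above (names are illustrative; batch broadcasting of a is omitted)::

    import numpy as np

    def next_edit_distance_reduce_ref(last_row, a_labels, a_n, a_ended, b, b_len,
                                      optimal_completion, a_blank_idx):
        """last_row: (b_time + 1,) int32; returns one score per candidate label."""
        result = np.empty(len(a_labels), dtype=np.int32)
        for label_idx, a_label in enumerate(a_labels):
            if a_ended or a_label == a_blank_idx:
                # The row would not change: reduce directly over the previous row.
                row = np.asarray(last_row[:b_len + 1])
                result[label_idx] = row.min() if optimal_completion else row[b_len]
            else:
                last_dist = a_n + 1  # first cell of the hypothetical next row
                total_min = last_dist
                for t_b in range(1, b_len + 1):
                    up = last_row[t_b] + 1                                 # cell above + 1
                    left = last_dist + 1                                   # cell to the left + 1
                    diag = last_row[t_b - 1] + int(a_label != b[t_b - 1])  # diagonal + substitution cost
                    last_dist = min(up, left, diag)
                    total_min = min(total_min, last_dist)
                result[label_idx] = total_min if optimal_completion else last_dist
        return result

Apart from the blank handling, this is equivalent to computing the next row for each candidate label with NextEditDistanceRowOp and then taking either its minimum (optimal completion) or its last entry, without materializing the rows.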
in_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'gradient': 'disconnected', 'name': 'last_row', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a_n', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a_ended', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'optimal_completion', 'ndim': 0, 'shape': ()}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'a_blank_idx', 'ndim': 0, 'shape': ()})[source]#
out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'output', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (1, 1))},)[source]#
c_extra_support_code: Dict[str, str] = {'001_calc_result': '\n      DEF_KERNEL\n      void calc_result_kernel(\n            int n_batch, int n_b_max_len, int n_labels,\n            const int32_t* last_row,\n            const int32_t* a, const int32_t* a_n, const int32_t* a_ended,\n            const int32_t* b, const int32_t* b_len,\n            int32_t* result,\n            bool optimal_completion,\n            bool a_broadcast_batch,\n            int32_t a_blank_idx\n      ) {\n        int idx = threadIdx.x + blockDim.x * blockIdx.x;\n        while(idx < n_batch * n_labels) {\n          int batch_idx = idx / n_labels;\n          int label_idx = idx % n_labels;\n          int a_label = a[(a_broadcast_batch ? 0 : batch_idx) * n_labels + label_idx];\n\n          int total_min_error;\n          int last_dist;\n          if(!a_ended[batch_idx] && a_label != a_blank_idx) {\n            last_dist = a_n[batch_idx] + 1;  // Initial deletion error.\n            total_min_error = last_dist;\n            for(int t_b = 1; t_b <= b_len[batch_idx]; ++t_b) {\n              int ins_error = last_row[batch_idx * (n_b_max_len + 1) + t_b] + 1;\n              int del_error = last_dist + 1;\n              int sub_error = last_row[batch_idx * (n_b_max_len + 1) + t_b - 1];\n              if(a_label != b[batch_idx * n_b_max_len + t_b - 1])\n                ++sub_error;\n              int min_error = ins_error;\n              if(min_error > del_error) min_error = del_error;\n              if(min_error > sub_error) min_error = sub_error;\n              last_dist = min_error;\n              if(total_min_error > last_dist) total_min_error = last_dist;\n            }\n          }\n          else {  // a ended or blank\n            // Just copy over.\n            total_min_error = last_row[batch_idx * (n_b_max_len + 1)];\n            for(int t_b = 0; t_b <= b_len[batch_idx]; ++t_b) {\n              last_dist = last_row[batch_idx * (n_b_max_len + 1) + t_b];\n              if(total_min_error > last_dist) total_min_error = last_dist;\n            }\n          }\n\n          result[batch_idx * n_labels + label_idx] = optimal_completion ? total_min_error : last_dist;\n\n          idx += gridDim.x * blockDim.x;\n        }\n      }\n    '}[source]#
c_fw_code: str = '\n    assert(n_inputs == 8);\n    assert(n_outputs == 1);\n    Ndarray* last_row = inputs[0];\n    Ndarray* a = inputs[1];\n    Ndarray* a_n = inputs[2];\n    Ndarray* a_ended = inputs[3];\n    Ndarray* b = inputs[4];\n    Ndarray* b_len = inputs[5];\n    bool optimal_completion = (bool) Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n    int32_t a_blank_idx = Ndarray_DEV_DATA_int32_scalar(inputs[7]);\n    Ndarray* out = *outputs[0];\n    assert_cmp(Ndarray_NDIM(last_row), ==, 2);\n    assert_cmp(Ndarray_NDIM(a), ==, 2);\n    assert_cmp(Ndarray_NDIM(a_n), ==, 1);\n    assert_cmp(Ndarray_NDIM(a_ended), ==, 1);\n    assert_cmp(Ndarray_NDIM(b), ==, 2);\n    assert_cmp(Ndarray_NDIM(b_len), ==, 1);\n    assert_cmp(Ndarray_NDIM(out), ==, 2);\n    int n_batch = Ndarray_DIMS(out)[0];\n    int n_labels = Ndarray_DIMS(out)[1];\n    int n_b_max_len = Ndarray_DIMS(b)[1];\n    assert_cmp(Ndarray_DIMS(out)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(out)[1], ==, n_labels);\n    assert_cmp(Ndarray_DIMS(last_row)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(last_row)[1], ==, n_b_max_len + 1);\n    bool a_broadcast_batch = Ndarray_DIMS(a)[0] == 1;\n    if(!a_broadcast_batch)\n      assert_cmp(Ndarray_DIMS(a)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(a)[1], ==, n_labels);\n    assert_cmp(Ndarray_DIMS(a_n)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(a_ended)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b)[0], ==, n_batch);\n    assert_cmp(Ndarray_DIMS(b)[1], ==, n_b_max_len);\n    assert_cmp(Ndarray_DIMS(b_len)[0], ==, n_batch);\n\n    start_dev_kernel(calc_result_kernel, (\n      n_batch, n_b_max_len, n_labels,\n      Ndarray_DEV_DATA_int32(last_row),\n      Ndarray_DEV_DATA_int32(a), Ndarray_DEV_DATA_int32(a_n), Ndarray_DEV_DATA_int32(a_ended),\n      Ndarray_DEV_DATA_int32(b), Ndarray_DEV_DATA_int32(b_len),\n      Ndarray_DEV_DATA_int32(out),\n      optimal_completion,\n      a_broadcast_batch, a_blank_idx\n    ));\n    HANDLE_LAST_ERROR();\n  '[source]#
c_bw_code: str = None[source]#
returnn.native_op.sparse_splice_offset_numpy(s0, idx)[source]#

Like sparse_slice_offset().