returnn.native_op#

Generic interface which automatically creates:

- CPU and GPU (CUDA) op
- inplace and not inplace
- grad variants

See returnn.tf.native_op and returnn.theano.native_op for usage in TensorFlow and Theano.
See Native operations for more background.
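As a rough orientation for the classes documented below: a native op is described declaratively by a NativeOpGenBase subclass that sets in_info, out_info and the C/CUDA kernel strings as class attributes. The following is only a hedged sketch of that layout with a made-up DoubleOp (y = 2*x); it is not part of RETURNN, and the kernel body is elided.

    from returnn.native_op import NativeOpGenBase

    class DoubleOp(NativeOpGenBase):  # hypothetical, for illustration only
        """Element-wise y = 2 * x."""
        in_info = (
            {"name": "x", "ndim": 2, "shape": (None, None), "need_contiguous": True},
        )
        out_info = (
            # shape refs (in_idx, dim): output y gets the same shape as input 0
            {"name": "y", "ndim": 2, "shape": ((0, 0), (0, 1)), "need_contiguous": True},
        )
        c_fw_code = """
          assert(n_inputs == 1); assert(n_outputs == 1);
          Ndarray* x = inputs[0];
          Ndarray* y = *outputs[0];
          // ... launch a kernel that writes y = 2 * x ...
        """

The op descriptions below (LstmGenericBase, LstmLowMem, NativeLstm2, TwoDLSTM) follow this layout, with c_bw_code and c_extra_support_code added for the gradient and shared kernels.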
- class returnn.native_op.NativeOpBaseMixin(in_info, out_info, c_fw_code, c_bw_code=None, c_extra_support_code=None, code_version=None, cpu_support=True, grad_input_map=None, name=None)[source]#
The purpose of having this as a separate base class is to make this independent of any Theano specific functionality so that we can also use this base for example for TensorFlow.
- Parameters:
in_info (list[dict(str)]) –
each dict describes one input var. attribs in the dict:
- int ndim: the ndim.
- tuple shape: tuple and can contain None for specific dimensions.
- optional attribs:
  - str dtype: "float32" by default.
  - bool need_contiguous: false by default.
  - int want_inplace: -1 by default. try to optimize to destroy input, on output-index. "dummy_out" is a special value which will add another output.
  - bool is_inplace: false by default. whether the optimization was applied.
  - str gradient: can be "disconnected". see grad().
  - bool bw_input: True by default. add this param to the bw input.
- other attribs are just ignored.
out_info (list[dict(str)]) –
like in_info. slightly different behavior for:
- shape: we also allow refs to the in_info in the form (in-idx, dim). see infer_shape() and the sketch after this class entry.
- need_contiguous/want_inplace: used for bw, in case bw_input == True.
c_fw_code (str) – C code for forward pass
c_extra_support_code (str|dict[str]) – C support code (for c_support_code)
c_bw_code (str|None) – C code for backward pass (for gradient)
code_version (tuple[int]) – will be returned by c_code_cache_version.
cpu_support (bool) –
grad_input_map (tuple[int]|callable) – selection of grad inputs. by default, we get all inputs + all outputs + all grad outputs.
name (str) – name
- infer_shape(node, input_shapes)[source]#
- Parameters:
node –
input_shapes –
- Return type:
list[tuple[int]]
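To make the shape conventions above concrete, here is a small illustrative sketch (plain Python, no RETURNN calls; the helper name resolve_out_shape is ours) of how an out_info shape entry with (in-idx, dim) references resolves against concrete input shapes, in the spirit of infer_shape():

    def resolve_out_shape(out_shape, input_shapes):
        # Resolve entries of the form (in_idx, dim_idx) against the given input shapes;
        # plain ints stay as-is, None means "unknown at this point".
        resolved = []
        for dim in out_shape:
            if isinstance(dim, tuple):
                in_idx, dim_idx = dim
                resolved.append(input_shapes[in_idx][dim_idx])
            else:
                resolved.append(dim)
        return tuple(resolved)

    # Example in the style of LstmGenericBase below: output Y has shape ((0, 0), (0, 1), (1, 0)),
    # i.e. (Z.shape[0], Z.shape[1], V_h.shape[0]) = (time, batch, dim).
    input_shapes = [(50, 8, 512 * 4), (512, 512 * 4), (8, 512), (50, 8)]  # Z, V_h, c, i
    assert resolve_out_shape(((0, 0), (0, 1), (1, 0)), input_shapes) == (50, 8, 512)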
- class returnn.native_op.NativeOpGenBase[source]#
Base interface for op generation. See NativeOp.__init__() for attribs.
- class returnn.native_op.LstmGenericBase[source]#
- inputs:
- param Z:
{input,output,forget} gate + cell state. 3d (time,batch,dim*4)
- param V_h:
recurrent matrix. 2d (dim,dim*4)
- param c:
initial cell state. 2d (batch,dim)
- param i:
index. 2d (time,batch) -> 0 or 1
- outputs:
- param Y:
output. 3d (time,batch,dim)
- param H:
gates and cell state. 3d (time,batch,dim*4)
- param d:
final cell state. 2d (batch,dim)
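For readability, a minimal NumPy restatement of what one forward time step computes, derived from the lstm_kernel below (block layout of H along the last axis: cell-in, input gate, forget gate, output gate, each of width dim). This is an illustrative sketch with our own function names, not the actual implementation:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_generic_step(z_t, y_prev, c_prev, i_t, V_h, dim):
        # z_t: (batch, dim*4) slice Z[t]; y_prev, c_prev: (batch, dim); i_t: (batch,) index mask in {0, 1}
        h = z_t + y_prev @ V_h                    # affine_y_x: add the recurrent part
        cell_in = np.tanh(h[:, 0 * dim:1 * dim])  # block 0: cell input
        inp = sigmoid(h[:, 1 * dim:2 * dim])      # block 1: input gate
        fgt = sigmoid(h[:, 2 * dim:3 * dim])      # block 2: forget gate
        out = sigmoid(h[:, 3 * dim:4 * dim])      # block 3: output gate
        m = i_t[:, None]
        c = (inp * cell_in + fgt * c_prev) * m + c_prev * (1.0 - m)  # keep old state where masked out
        y = out * np.tanh(c) * m
        return y, c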
- in_info: Tuple[Dict[str]] = ({'bw_out_var': {'shape': ((2, 0), (2, 1), (0, 1))}, 'name': 'Z', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None), 'want_inplace': 1}, {'name': 'V_h', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'c', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'i', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)})[source]#
- out_info: Tuple[Dict[str]] = ({'bw_grad_var': {'want_inplace': 'dummy_out'}, 'name': 'Y', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (1, 0))}, {'bw_in_var': {'want_inplace': 0}, 'name': 'H', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': 'd', 'ndim': 2, 'need_contiguous': True, 'shape': ((2, 0), (2, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'lstm_bwd_kernel': '\n DEF_KERNEL\n void lstm_bwd_kernel(\n float* delta, float* epsilon, const float* next_epsilon, const float* old_state,\n bool old_state_strided, const float* Y, int n_cells, int n_batch, const float* i) {\n //layout:\n //delta[0*n_cells..1*n_cells-1] : input gate\n //delta[1*n_cells..2*n_cells-1] : forget gate\n //delta[2*n_cells..3*n_cells-1] : output gate\n //delta[3*n_cells..4*n_cells-1] : cell state\n //epsilon[0*n_cells..1*n_cells-1]: cell output derivative (later overwritten, see below)\n //next_epsilon[0*n_cells..1*n_cells-1]: cell state derivative * forget_gate (of next timestep)\n //repeated for every mini-batch\n\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_cells * n_batch) {\n int batch_idx = idx / n_cells;\n int batch_offset = batch_idx * 4 * n_cells;\n int cell_offset = idx % n_cells;\n int start = batch_offset + cell_offset;\n float i_batch = i[batch_idx];\n\n float inpGate = delta[start + n_cells];\n float fgtGate = delta[start + 2 * n_cells];\n float outGate = delta[start + 3 * n_cells];\n float oldState = old_state_strided ? old_state[start] : old_state[idx];\n float state = delta[start];\n float eps = epsilon[idx];\n\n //avoid division by 0\n float gc = tanhf(state); //g(c(t))\n float gzc = (state - fgtGate * oldState) / fmaxf(inpGate, float(1e-16)); //g(z_c(t))\n\n //delta_output\n delta[start + 3 * n_cells] = outGate * (1.f - outGate) * gc * eps * i_batch;\n\n //epsilon_c\n float epsilon_c = (1.f - (gc * gc)) * outGate * eps;\n epsilon_c += next_epsilon[idx];\n epsilon[idx] = epsilon_c * fgtGate * i_batch + next_epsilon[idx] * (1.f - i_batch);\n\n //delta_cell\n delta[start] = inpGate * (1.f - (gzc * gzc)) * epsilon_c * i_batch;\n\n //delta_forget\n delta[start + 2 * n_cells] = fgtGate * (1.f - fgtGate) * oldState * epsilon_c * i_batch;\n\n //delta_input\n delta[start + n_cells] = inpGate * (1.f - inpGate) * gzc * epsilon_c * i_batch;\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', 'lstm_kernel': '\n DEF_KERNEL\n void lstm_kernel(float* data, const float* old_state, bool old_state_strided,\n float* output, float* state_out, int n_cells, int n_batch, const float* i) {\n //layout:\n //data[0*n_cells..1*n_cells-1] : cell state\n //data[1*n_cells..2*n_cells-1] : input gate\n //data[2*n_cells..3*n_cells-1] : forget gate\n //data[3*n_cells..4*n_cells-1] : output gate\n //output[0*n_cells..1*n_cells-1]: cell output\n //repeated for every mini-batch\n\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_cells * n_batch) {\n int batch_idx = idx / n_cells;\n int start = batch_idx * 4 * n_cells + idx % n_cells;\n float i_batch = i[batch_idx];\n\n //input, forget and output gates\n float inpGate = 1.f / (1.f + expf(-data[start + n_cells]));\n float fgtGate = 1.f / (1.f + expf(-data[start + 2 * n_cells]));\n float outGate = 1.f / (1.f + expf(-data[start + 3 * n_cells]));\n float state = inpGate * tanhf(data[start]);\n float old_state_batch = old_state_strided ? old_state[start] : old_state[idx];\n\n state += fgtGate * old_state_batch;\n state = state * i_batch + old_state_batch * (1.f - i_batch);\n\n //cell output\n output[idx] = outGate * tanhf(state) * i_batch;\n\n data[start] = state;\n data[start + n_cells] = inpGate;\n data[start + 2 * n_cells] = fgtGate;\n data[start + 3 * n_cells] = outGate;\n if(state_out)\n state_out[idx] = state;\n\n idx += gridDim.x * blockDim.x;\n }\n }\n '}[source]#
- c_fw_code: str = '\n // Z*, V_h, c, i = input_names (*: inplace)\n // Y, H, d = output_names\n assert(n_inputs == 4);\n assert(n_outputs == 3);\n Ndarray* V_h = inputs[1];\n Ndarray* c = inputs[2];\n Ndarray* i = inputs[3];\n Ndarray* Y = *outputs[0];\n Ndarray* H = *outputs[1]; // inplace on Z\n Ndarray* d = *outputs[2];\n\n long T = Ndarray_DIMS(i)[0];\n int n_batch = Ndarray_DIMS(i)[1];\n assert(Ndarray_DIMS(H)[2] %% 4 == 0); // 3 gates + cell\n int n_cells = Ndarray_DIMS(H)[2] / 4;\n\n assert(T > 0);\n for(int x = 0; x < T; ++x) {\n if(x > 0) {\n //H += Y[x-1]*V_h\n affine_y_x(x-1, Y, x, V_h, x, H);\n }\n\n start_dev_kernel(lstm_kernel, (\n data_ptr(H, x),\n x > 0 ? data_ptr(H, x - 1) : Ndarray_DEV_DATA(c),\n x > 0,\n data_ptr(Y, x),\n (x == T - 1) ? Ndarray_DEV_DATA(d) : 0,\n n_cells,\n n_batch,\n Ndarray_DEV_DATA(i) + x * n_batch\n ));\n }\n HANDLE_LAST_ERROR();\n '[source]#
- c_bw_code: str = '\n // V_h, c, i, Y, H*, DY*, Dd = input_names (*: inplace)\n // DZ, DV_h, Dc, tmpDc = output_names\n assert(n_inputs == 7);\n assert(n_outputs == 4);\n Ndarray* V_h = inputs[0];\n Ndarray* c = inputs[1];\n Ndarray* i = inputs[2];\n Ndarray* Y = inputs[3];\n Ndarray* Dd = inputs[6];\n Ndarray* DZ = *outputs[0]; // inplace on H\n Ndarray* DV_h = *outputs[1];\n Ndarray* Dc = *outputs[2];\n Ndarray* tmpDc = *outputs[3]; // (old DY), inplace buffer\n\n long T = Ndarray_DIMS(i)[0];\n int n_batch = Ndarray_DIMS(i)[1];\n assert(Ndarray_DIMS(DZ)[2] %% 4 == 0); // 3 gates + cell\n int n_cells = Ndarray_DIMS(DZ)[2] / 4;\n\n assert(T > 0);\n for(int x = T - 1; x >= 0; --x) {\n // add recurrent\n bool rightBorder = (x == T - 1);\n if(!rightBorder)\n affine_y_x(x+1, DZ, x, V_h, x, tmpDc, false, true);\n\n start_dev_kernel(lstm_bwd_kernel, (\n data_ptr(DZ, x),\n data_ptr(tmpDc, x),\n rightBorder ? Ndarray_DEV_DATA(Dd) : data_ptr(tmpDc, x + 1),\n x > 0 ? data_ptr(DZ, x - 1) : Ndarray_DEV_DATA(c),\n x > 0,\n data_ptr(Y, x),\n n_cells,\n n_batch,\n Ndarray_DEV_DATA(i) + x * n_batch\n ));\n }\n\n //DV_h = Y[0..end-1]^T * DZ[1..end]\n affine_global(Y, DZ, DV_h, true, false, 1, 0.0f);\n\n Ndarray_DIMS_Type Dc_dim = Ndarray_HOST_DIMS(Dc);\n Ndarray_memcpy(\n Ndarray_DEV_DATA(Dc), Ndarray_DEV_DATA(tmpDc),\n Dc_dim[0] * Dc_dim[1] * sizeof(float));\n HANDLE_LAST_ERROR();\n '[source]#
- class returnn.native_op.LstmLowMem[source]#
This is designed to require minimal memory during training. It only stores the outputs and the cell states, i.e. it requires time * cells * 2 floats for memory in total.
- inputs:
- param X:
(time,batch,in_dim)
- param W:
forward+recurrent matrix. 2d (in_dim+dim,dim*4)
- param b:
bias. 1d (dim*4,)
- param y0:
initial output|hidden state. 2d (batch,dim)
- param c0:
initial cell state. 2d (batch,dim)
- param i:
index. 2d (time,batch) -> 0 or 1
- param start:
where to start. must be >=0, default is usually 0. dtype int, scalar.
- param step:
+1 for fwd, -1 for bwd direction. can also be |step|>1 for wider steps. dtype int, scalar. for bwd (<0), will start at T-start-1.
- outputs:
- param Y:
output. 3d (time,batch,dim)
- param C:
cell states. 3d (time,batch,dim). gradient ignored!
- param d:
final cell state. 2d (batch,dim)
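As a hedged sketch of one fused step (our own NumPy restatement of the kernels below, not the implementation): the input X[t] and the previous output Y[t-1] are concatenated and multiplied with the single matrix W, and only Y and C are kept across time; the intermediate gate activations are recomputed from X, W, b in the backward pass, which is where the memory saving comes from.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_low_mem_step(x_t, y_prev, c_prev, i_t, W, b):
        # x_t: (batch, in_dim); y_prev, c_prev: (batch, dim); W: (in_dim+dim, dim*4); b: (dim*4,)
        x_h = np.concatenate([x_t, y_prev], axis=1)   # copy_x_h_kernel
        intern = x_h @ W + b                          # affine_raw + add_bias_kernel
        cell_in, inp, fgt, out = np.split(intern, 4, axis=1)
        m = i_t[:, None]                              # index mask, 0 or 1
        c = (c_prev * sigmoid(fgt) + np.tanh(cell_in) * sigmoid(inp)) * m + c_prev * (1.0 - m)
        y = np.tanh(c) * sigmoid(out) * m             # only Y and C are stored
        return y, c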
- in_info: Tuple[Dict[str]] = ({'name': 'X', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'name': 'W', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'b', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'name': 'y0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'c0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'i', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'start', 'ndim': 0, 'shape': ()}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'step', 'ndim': 0, 'shape': ()})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'Y', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (4, 1))}, {'name': 'C', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (4, 1))}, {'name': 'd', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 1), (4, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'add_bias_kernel': '\n DEF_KERNEL\n void add_bias_kernel(int n_batch, int n_dim, float* x, float* b) {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_batch * n_dim) {\n int dim_idx = idx % n_dim;\n x[idx] += b[dim_idx];\n idx += gridDim.x * blockDim.x;\n }\n }\n ', 'copy_x_h_kernel': '\n DEF_KERNEL\n void copy_x_h_kernel(\n int n_batch, int n_in, int n_cells,\n float* x_h,\n float* x,\n float* h)\n {\n int n_total_in = n_in + n_cells;\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_batch * n_total_in) {\n int batch_idx = idx / n_total_in;\n int in_dim_idx = idx % n_total_in;\n\n if(in_dim_idx < n_in)\n x_h[idx] = x[batch_idx * n_in + in_dim_idx];\n else\n x_h[idx] = h[batch_idx * n_cells + in_dim_idx - n_in];\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', 'inv_copy_x_h_kernel': '\n DEF_KERNEL\n void inv_copy_x_h_kernel(\n int n_batch, int n_in, int n_cells,\n float* x_h,\n float* x,\n float* h)\n {\n int n_total_in = n_in + n_cells;\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_batch * n_total_in) {\n int batch_idx = idx / n_total_in;\n int in_dim_idx = idx % n_total_in;\n\n if(in_dim_idx < n_in)\n x[batch_idx * n_in + in_dim_idx] = x_h[idx];\n else\n h[batch_idx * n_cells + in_dim_idx - n_in] = x_h[idx];\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', 'lstm_bwd_kernel': '\n DEF_KERNEL\n void lstm_bwd_kernel(\n int n_batch, int n_in, int n_cells, const float* mask,\n float* x_h,\n float* intern,\n float* prev_c,\n float* y,\n float* c,\n float* d_y,\n float* d_h,\n float* d_c,\n float* d_intern,\n float* d_b)\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_cells * n_batch) {\n int batch_idx = idx / n_cells;\n int cell_idx = idx % n_cells;\n int intern_offset = batch_idx * 4 * n_cells + cell_idx;\n float mask_b = mask[batch_idx];\n float d_y_b = d_y[idx] * mask_b + d_h[idx];\n float d_c_b = d_c[idx] * mask_b;\n float prev_c_b = prev_c[idx];\n\n // cell-in + input, forget and output gates\n float cellIn = tanhf(intern[intern_offset]);\n float inpGate = 1.f / (1.f + expf(-intern[intern_offset + n_cells]));\n float fgtGate = 1.f / (1.f + expf(-intern[intern_offset + 2 * n_cells]));\n float outGate = 1.f / (1.f + expf(-intern[intern_offset + 3 * n_cells]));\n\n float c_b = prev_c_b * fgtGate + cellIn * inpGate;\n float gc = tanhf(c_b);\n float d_outGate_in = (1.f - outGate) * outGate * gc * d_y_b;\n float d_c2 = d_c_b + outGate * d_y_b * (1.f - gc * gc);\n float d_cellIn_in = (1.f - cellIn * cellIn) * inpGate * d_c2;\n float d_inpGate_in = (1.f - inpGate) * inpGate * cellIn * d_c2;\n float d_fgtGate_in = (1.f - fgtGate) * fgtGate * prev_c_b * d_c2;\n d_c[idx] = fgtGate * d_c2 + d_c[idx] * (1.f - mask_b);\n\n d_intern[intern_offset] = d_cellIn_in;\n d_intern[intern_offset + n_cells] = d_inpGate_in;\n d_intern[intern_offset + 2 * n_cells] = d_fgtGate_in;\n d_intern[intern_offset + 3 * n_cells] = d_outGate_in;\n\n elem_atomic_add(&d_b[cell_idx], d_cellIn_in);\n elem_atomic_add(&d_b[cell_idx + n_cells], d_inpGate_in);\n elem_atomic_add(&d_b[cell_idx + 2 * n_cells], d_fgtGate_in);\n elem_atomic_add(&d_b[cell_idx + 3 * n_cells], d_outGate_in);\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', 'lstm_kernel': '\n DEF_KERNEL\n void lstm_kernel(\n int n_batch, int n_cells, const float* mask,\n float* intern,\n float* prev_c,\n float* y,\n float* c)\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_cells * n_batch) {\n int batch_idx = idx / n_cells;\n int 
cell_idx = idx % n_cells;\n int intern_offset = batch_idx * 4 * n_cells + cell_idx;\n float prev_c_b = prev_c[idx];\n float mask_b = mask[batch_idx];\n\n // cell-in + input, forget and output gates\n float cellIn = tanhf(intern[intern_offset]);\n float inpGate = 1.f / (1.f + expf(-intern[intern_offset + n_cells]));\n float fgtGate = 1.f / (1.f + expf(-intern[intern_offset + 2 * n_cells]));\n float outGate = 1.f / (1.f + expf(-intern[intern_offset + 3 * n_cells]));\n\n float c_b = (prev_c_b * fgtGate + cellIn * inpGate) * mask_b\n + prev_c_b * (1.f - mask_b);\n c[idx] = c_b;\n y[idx] = tanhf(c_b) * outGate * mask_b;\n\n idx += gridDim.x * blockDim.x;\n }\n }\n '}[source]#
- c_fw_code: str = '\n // X, W, b, y0, c0, i, start, step = input_names\n // Y, C, d = output_names\n assert(n_inputs == 8);\n assert(n_outputs == 3);\n Ndarray* X = inputs[0];\n Ndarray* W = inputs[1];\n Ndarray* b = inputs[2];\n Ndarray* y0 = inputs[3];\n Ndarray* c0 = inputs[4];\n Ndarray* i = inputs[5];\n assert_cmp(Ndarray_NDIM(inputs[6]), ==, 0);\n assert_cmp(Ndarray_NDIM(inputs[7]), ==, 0);\n int start = Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n int step = Ndarray_DEV_DATA_int32_scalar(inputs[7]);\n Ndarray* Y = *outputs[0];\n Ndarray* C = *outputs[1];\n Ndarray* d = *outputs[2];\n\n assert_cmp(Ndarray_NDIM(X), ==, 3);\n assert_cmp(Ndarray_NDIM(W), ==, 2);\n assert_cmp(Ndarray_NDIM(b), ==, 1);\n assert_cmp(Ndarray_NDIM(y0), ==, 2);\n assert_cmp(Ndarray_NDIM(c0), ==, 2);\n assert_cmp(Ndarray_NDIM(i), ==, 2);\n assert_cmp(Ndarray_NDIM(Y), ==, 3);\n assert_cmp(Ndarray_NDIM(C), ==, 3);\n assert_cmp(Ndarray_NDIM(d), ==, 2);\n long T = Ndarray_DIMS(i)[0];\n int n_batch = Ndarray_DIMS(i)[1];\n int n_cells = Ndarray_DIMS(y0)[1];\n int n_in = Ndarray_DIMS(X)[2];\n assert_cmp(Ndarray_DIMS(X)[0], ==, T);\n assert_cmp(Ndarray_DIMS(X)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(W)[0], ==, n_in + n_cells);\n assert_cmp(Ndarray_DIMS(W)[1], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(b)[0], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(y0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(y0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(c0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(c0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(Y)[0], ==, T);\n assert_cmp(Ndarray_DIMS(Y)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Y)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(C)[0], ==, T);\n assert_cmp(Ndarray_DIMS(C)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(C)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(d)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(d)[1], ==, n_cells);\n\n float* x_h = (float*) device_malloc(n_batch * (n_in + n_cells) * sizeof(float));\n float* intern = (float*) device_malloc(n_batch * n_cells * 4 * sizeof(float)); // 3 gates + in\n\n assert_cmp(T, >, 0);\n assert_cmp(start, >=, 0);\n assert_cmp(start, <, T);\n assert_cmp(step, !=, 0);\n int end = T - 1;\n if(step < 0) {\n end = start;\n start = T - start - 1;\n }\n int t = start;\n for(; (step > 0) ? (t <= end) : (t >= end); t += step) {\n // x_h = X[t], Y[t-1]\n start_dev_kernel(copy_x_h_kernel,\n (n_batch, n_in, n_cells, x_h, data_ptr(X, t), (t != start) ? data_ptr(Y, t-step) : Ndarray_DEV_DATA(y0)));\n // intern = x_h * W\n affine_raw(\n x_h, n_batch, n_in + n_cells,\n Ndarray_DEV_DATA(W), n_in + n_cells, n_cells * 4,\n intern, n_batch, n_cells * 4,\n false, false, 0.0);\n // intern += b\n start_dev_kernel(add_bias_kernel, (\n n_batch, n_cells * 4, intern, Ndarray_DEV_DATA(b)));\n\n start_dev_kernel(lstm_kernel, (\n n_batch,\n n_cells,\n Ndarray_DEV_DATA(i) + t * n_batch,\n intern,\n (t != start) ? data_ptr(C, t-step) : Ndarray_DEV_DATA(c0),\n data_ptr(Y, t), // out\n data_ptr(C, t) // out\n ));\n }\n HANDLE_LAST_ERROR();\n\n device_free(x_h);\n device_free(intern);\n\n Ndarray_memcpy(Ndarray_DEV_DATA(d), data_ptr(C, t - step), n_batch * n_cells * sizeof(float));\n '[source]#
- c_bw_code: str = '\n // X, W, b, y0, c0, i, start, step, Y, C, DY, Dd = input_names\n // DX, DW, Db, Dh, Dc = output_names\n assert(n_inputs == 12);\n assert(n_outputs == 5);\n Ndarray* X = inputs[0];\n Ndarray* W = inputs[1];\n Ndarray* b = inputs[2];\n Ndarray* y0 = inputs[3];\n Ndarray* c0 = inputs[4];\n Ndarray* i = inputs[5];\n assert_cmp(Ndarray_NDIM(inputs[6]), ==, 0);\n assert_cmp(Ndarray_NDIM(inputs[7]), ==, 0);\n int start = Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n int step = Ndarray_DEV_DATA_int32_scalar(inputs[7]);\n Ndarray* Y = inputs[8];\n Ndarray* C = inputs[9];\n Ndarray* DY = inputs[10];\n Ndarray* Dd = inputs[11];\n Ndarray* DX = *outputs[0];\n Ndarray* DW = *outputs[1];\n Ndarray* Db = *outputs[2];\n Ndarray* Dh = *outputs[3];\n Ndarray* Dc = *outputs[4];\n\n assert_cmp(Ndarray_NDIM(X), ==, 3);\n assert_cmp(Ndarray_NDIM(W), ==, 2);\n assert_cmp(Ndarray_NDIM(b), ==, 1);\n assert_cmp(Ndarray_NDIM(y0), ==, 2);\n assert_cmp(Ndarray_NDIM(c0), ==, 2);\n assert_cmp(Ndarray_NDIM(i), ==, 2);\n assert_cmp(Ndarray_NDIM(Y), ==, 3);\n assert_cmp(Ndarray_NDIM(C), ==, 3);\n assert_cmp(Ndarray_NDIM(DY), ==, 3);\n assert_cmp(Ndarray_NDIM(Dd), ==, 2);\n assert_cmp(Ndarray_NDIM(DX), ==, 3);\n assert_cmp(Ndarray_NDIM(DW), ==, 2);\n assert_cmp(Ndarray_NDIM(Db), ==, 1);\n assert_cmp(Ndarray_NDIM(Dh), ==, 2);\n assert_cmp(Ndarray_NDIM(Dc), ==, 2);\n long T = Ndarray_DIMS(i)[0];\n int n_batch = Ndarray_DIMS(i)[1];\n int n_cells = Ndarray_DIMS(y0)[1];\n int n_in = Ndarray_DIMS(X)[2];\n assert_cmp(Ndarray_DIMS(X)[0], ==, T);\n assert_cmp(Ndarray_DIMS(X)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(W)[0], ==, n_in + n_cells);\n assert_cmp(Ndarray_DIMS(W)[1], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(b)[0], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(y0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(y0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(c0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(c0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(Y)[0], ==, T);\n assert_cmp(Ndarray_DIMS(Y)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Y)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(C)[0], ==, T);\n assert_cmp(Ndarray_DIMS(C)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(C)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(DY)[0], ==, T);\n assert_cmp(Ndarray_DIMS(DY)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(DY)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(Dd)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Dd)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(DX)[0], ==, T);\n assert_cmp(Ndarray_DIMS(DX)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(DX)[2], ==, n_in);\n assert_cmp(Ndarray_DIMS(DW)[0], ==, n_in + n_cells);\n assert_cmp(Ndarray_DIMS(DW)[1], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(Db)[0], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(Dh)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Dh)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(Dc)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Dc)[1], ==, n_cells);\n\n float* x_h = (float*) device_malloc(n_batch * (n_in + n_cells) * sizeof(float));\n float* intern = (float*) device_malloc(n_batch * n_cells * 4 * sizeof(float)); // 3 gates + in\n float* Dx_h = (float*) device_malloc(n_batch * (n_in + n_cells) * sizeof(float));\n float* Dintern = (float*) device_malloc(n_batch * n_cells * 4 * sizeof(float)); // 3 gates + in\n\n // We will work inplace on DX/DW/Db.\n Ndarray_memset(Ndarray_DEV_DATA(DX), 0, T * n_batch * n_in * sizeof(float));\n Ndarray_memset(Ndarray_DEV_DATA(DW), 0, (n_in + n_cells) * n_cells * 4 * sizeof(float));\n Ndarray_memset(Ndarray_DEV_DATA(Db), 0, 
n_cells * 4 * sizeof(float));\n // We will work inplace on Dh.\n Ndarray_memset(Ndarray_DEV_DATA(Dh), 0, n_batch * n_cells * sizeof(float));\n // We will work inplace on Dc, and init it with Dd.\n Ndarray_memcpy(Ndarray_DEV_DATA(Dc), Ndarray_DEV_DATA(Dd), n_batch * n_cells * sizeof(float));\n\n assert_cmp(T, >, 0);\n assert_cmp(start, >=, 0);\n assert_cmp(start, <, T);\n assert_cmp(step, !=, 0);\n int end = T - 1;\n if(step < 0) {\n end = start;\n start = T - start - 1;\n }\n int t = end; // go backwards\n for(; (step > 0) ? (t >= start) : (t <= start); t -= step) {\n bool right = (step > 0) ? (t - step >= start) : (t - step <= start);\n\n // TODO: correct handling of mask in grad, fwd, initial cell,hidden, etc\n // x_h = X[t], Y[t-1]\n start_dev_kernel(copy_x_h_kernel,\n (n_batch, n_in, n_cells,\n x_h, data_ptr(X, t), right ? data_ptr(Y, t-step) : Ndarray_DEV_DATA(y0)));\n\n // intern = x_h * W\n affine_raw(\n x_h, n_batch, n_in + n_cells,\n Ndarray_DEV_DATA(W), n_in + n_cells, n_cells * 4,\n intern, n_batch, n_cells * 4,\n false, false, 0.0);\n // intern += b\n start_dev_kernel(add_bias_kernel, (\n n_batch, n_cells * 4, intern, Ndarray_DEV_DATA(b)));\n\n start_dev_kernel(lstm_bwd_kernel, (\n n_batch,\n n_in,\n n_cells,\n Ndarray_DEV_DATA(i) + t * n_batch,\n x_h,\n intern,\n right ? data_ptr(C, t-step) : Ndarray_DEV_DATA(c0),\n data_ptr(Y, t),\n data_ptr(C, t),\n data_ptr(DY, t),\n Ndarray_DEV_DATA(Dh), // error from prev frame, excluding DY. updated below\n Ndarray_DEV_DATA(Dc), // in+out, working inplace. also error from prev frame, initially Dd\n Dintern, // out\n Ndarray_DEV_DATA(Db) // out\n ));\n\n // Dx_h = Dintern * W^T\n affine_raw(\n Dintern, n_batch, n_cells * 4,\n Ndarray_DEV_DATA(W), n_in + n_cells, n_cells * 4,\n Dx_h, n_batch, n_in + n_cells,\n false, true, 0.0);\n\n // DW += x_h^T * Dintern\n affine_raw(\n x_h, n_batch, n_in + n_cells,\n Dintern, n_batch, n_cells * 4,\n Ndarray_DEV_DATA(DW), n_in + n_cells, n_cells * 4,\n true, false);\n\n // DX[t], Dh = Dx_h\n start_dev_kernel(inv_copy_x_h_kernel,\n (n_batch, n_in, n_cells, Dx_h, data_ptr(DX, t), Ndarray_DEV_DATA(Dh)));\n }\n HANDLE_LAST_ERROR();\n\n device_free(x_h);\n device_free(intern);\n device_free(Dx_h);\n device_free(Dintern);\n '[source]#
- class returnn.native_op.NativeLstm2[source]#
Yet another LSTM kernel. This kernel is about 27% faster than NativeLstm, and also has some more options (like the direction). But it requires time * batch * cells more memory, thus time * batch * cells * 6 in total.
- inputs:
- param X:
(time,batch,dim*4)
- param W:
recurrent matrix. 2d (dim,dim*4)
- param y0:
initial output|hidden state. 2d (batch,dim)
- param c0:
initial cell state. 2d (batch,dim)
- param i:
index. 2d (time,batch) -> 0 or 1
- param start:
where to start. must be >=0, default is usually 0. dtype int, scalar.
- param step:
+1 for fwd, -1 for bwd direction. can also be |step|>1 for wider steps. dtype int, scalar. for bwd (<0), will start at T-start-1.
- outputs:
- param Y:
output. 3d (time,batch,dim)
- param C:
cell states. 3d (time,batch,dim). gradient ignored!
- param H:
cell-in + gates. 3d (time,batch,dim*4). gradient ignored!
- param d:
final cell state. 2d (batch,dim)
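The start/step parameters above can be non-obvious; the following small sketch (ours, derived from the loop bounds in c_fw_code below) lists the time indices the forward pass visits. For step < 0 the loop starts at T - start - 1 and runs down to 0, matching the parameter description.

    def native_lstm2_time_indices(T, start=0, step=1):
        # Mirror of the loop bounds in c_fw_code: end = T-1 for step > 0,
        # otherwise start at T-start-1 and run down to 0.
        assert T > 0 and 0 <= start < T and step != 0
        t, end = (start, T - 1) if step > 0 else (T - start - 1, 0)
        indices = []
        while (t <= end) if step > 0 else (t >= end):
            indices.append(t)
            t += step
        return indices

    assert native_lstm2_time_indices(10, start=0, step=1) == list(range(10))
    assert native_lstm2_time_indices(10, start=0, step=-1) == list(range(9, -1, -1))
    assert native_lstm2_time_indices(10, start=2, step=-3) == [7, 4, 1]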
- in_info: Tuple[Dict[str]] = ({'name': 'X', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'name': 'W', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'c0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'i', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'start', 'ndim': 0, 'shape': ()}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'step', 'ndim': 0, 'shape': ()})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'Y', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (1, 0))}, {'name': 'C', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (1, 0))}, {'name': 'H', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (1, 1))}, {'name': 'd', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 1), (1, 0))})[source]#
- c_extra_support_code: Dict[str, str] = {'lstm_bwd_kernel': '\n DEF_KERNEL\n void lstm_bwd_kernel(\n int n_batch, int n_cells, const float* mask,\n float* h,\n float* prev_c,\n float* y,\n float* c,\n float* d_y,\n float* d_h,\n float* d_c,\n float* d_x,\n float* d_x0)\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_cells * n_batch) {\n int batch_idx = idx / n_cells;\n int cell_idx = idx % n_cells;\n int intern_offset = batch_idx * 4 * n_cells + cell_idx;\n float mask_b = mask[batch_idx];\n float d_y_b = (d_y[idx] + d_h[idx]) * mask_b;\n float d_c_b = d_c[idx] * mask_b;\n float prev_c_b = prev_c[idx];\n\n // cell-in + input, forget and output gates\n float cellIn = h[intern_offset];\n float inpGate = h[intern_offset + n_cells];\n float fgtGate = h[intern_offset + 2 * n_cells];\n float outGate = h[intern_offset + 3 * n_cells];\n\n float c_b = prev_c_b * fgtGate + cellIn * inpGate;\n float gc = tanhf(c_b);\n float d_outGate_in = (1.f - outGate) * outGate * gc * d_y_b;\n float d_c2 = d_c_b + outGate * d_y_b * (1.f - gc * gc);\n float d_cellIn_in = (1.f - cellIn * cellIn) * inpGate * d_c2;\n float d_inpGate_in = (1.f - inpGate) * inpGate * cellIn * d_c2;\n float d_fgtGate_in = (1.f - fgtGate) * fgtGate * prev_c_b * d_c2;\n d_c[idx] = fgtGate * d_c2 + d_c[idx] * (1.f - mask_b);\n\n d_x[intern_offset] = d_cellIn_in;\n d_x[intern_offset + n_cells] = d_inpGate_in;\n d_x[intern_offset + 2 * n_cells] = d_fgtGate_in;\n d_x[intern_offset + 3 * n_cells] = d_outGate_in;\n\n #define set_x0(off) { d_x0[off] = d_x[off] + d_x0[off] * (1.f - mask_b); }\n set_x0(intern_offset);\n set_x0(intern_offset + n_cells);\n set_x0(intern_offset + 2 * n_cells);\n set_x0(intern_offset + 3 * n_cells);\n #undef set_x0\n\n // Reset if used frame, otherwise leave as-is.\n d_h[idx] *= (1.f - mask_b);\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', 'lstm_kernel': '\n DEF_KERNEL\n void lstm_kernel(\n int n_batch, int n_cells, const float* mask,\n float* h,\n float* prev_y,\n float* prev_c,\n float* y,\n float* c,\n float* y_prev_out)\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_cells * n_batch) {\n int batch_idx = idx / n_cells;\n int cell_idx = idx % n_cells;\n int intern_offset = batch_idx * 4 * n_cells + cell_idx;\n float prev_c_b = prev_c[idx];\n float mask_b = mask[batch_idx];\n\n // cell-in + input, forget and output gates\n float cellIn = tanhf(h[intern_offset]);\n float inpGate = 1.f / (1.f + expf(-h[intern_offset + n_cells]));\n float fgtGate = 1.f / (1.f + expf(-h[intern_offset + 2 * n_cells]));\n float outGate = 1.f / (1.f + expf(-h[intern_offset + 3 * n_cells]));\n\n h[intern_offset] = cellIn;\n h[intern_offset + n_cells] = inpGate;\n h[intern_offset + 2 * n_cells] = fgtGate;\n h[intern_offset + 3 * n_cells] = outGate;\n\n float c_b = (prev_c_b * fgtGate + cellIn * inpGate) * mask_b\n + prev_c_b * (1.f - mask_b);\n c[idx] = c_b;\n float y_b = tanhf(c_b) * outGate * mask_b;\n y[idx] = y_b;\n y_prev_out[idx] = y_b + prev_y[idx] * (1.f - mask_b);\n\n idx += gridDim.x * blockDim.x;\n }\n }\n '}[source]#
- c_fw_code: str = '\n // X, W, y0, c0, i, start, step = input_names\n // Y, C, H, d = output_names\n assert(n_inputs == 7);\n assert(n_outputs == 4);\n Ndarray* X = inputs[0];\n Ndarray* W = inputs[1];\n Ndarray* y0 = inputs[2];\n Ndarray* c0 = inputs[3];\n Ndarray* i = inputs[4];\n assert_cmp(Ndarray_NDIM(inputs[5]), ==, 0);\n assert_cmp(Ndarray_NDIM(inputs[6]), ==, 0);\n int start = Ndarray_DEV_DATA_int32_scalar(inputs[5]);\n int step = Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n Ndarray* Y = *outputs[0];\n Ndarray* C = *outputs[1];\n Ndarray* H = *outputs[2];\n Ndarray* d = *outputs[3];\n\n assert_cmp(Ndarray_NDIM(X), ==, 3);\n assert_cmp(Ndarray_NDIM(W), ==, 2);\n assert_cmp(Ndarray_NDIM(y0), ==, 2);\n assert_cmp(Ndarray_NDIM(c0), ==, 2);\n assert_cmp(Ndarray_NDIM(i), ==, 2);\n assert_cmp(Ndarray_NDIM(Y), ==, 3);\n assert_cmp(Ndarray_NDIM(C), ==, 3);\n assert_cmp(Ndarray_NDIM(H), ==, 3);\n assert_cmp(Ndarray_NDIM(d), ==, 2);\n long T = Ndarray_DIMS(i)[0];\n int n_batch = Ndarray_DIMS(i)[1];\n int n_cells = Ndarray_DIMS(y0)[1];\n assert_cmp(Ndarray_DIMS(X)[0], ==, T);\n assert_cmp(Ndarray_DIMS(X)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(X)[2], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(W)[0], ==, n_cells);\n assert_cmp(Ndarray_DIMS(W)[1], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(y0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(y0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(c0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(c0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(Y)[0], ==, T);\n assert_cmp(Ndarray_DIMS(Y)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Y)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(C)[0], ==, T);\n assert_cmp(Ndarray_DIMS(C)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(C)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(H)[0], ==, T);\n assert_cmp(Ndarray_DIMS(H)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(H)[2], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(d)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(d)[1], ==, n_cells);\n\n if(T == 0) {\n Ndarray_memcpy(Ndarray_DEV_DATA(d), Ndarray_DEV_DATA(c0), n_batch * n_cells * sizeof(float));\n\n } else { // T > 0\n // It makes the backprop with step<0 easier to implement,\n // esp. the DW = Y[0..T-2]^T * DX[1..T-1] calculation,\n // if we can have Y[t] = 0 where mask[t] = 0.\n // That is why we need to keep track of Y[t-1] explicitly.\n float* y_prev = (float*) device_malloc(n_batch * n_cells * sizeof(float));\n\n // H = X\n Ndarray_memcpy(Ndarray_DEV_DATA(H), Ndarray_DEV_DATA(X), T * n_batch * n_cells * 4 * sizeof(float));\n\n assert_cmp(T, >, 0);\n assert_cmp(start, >=, 0);\n assert_cmp(start, <, T);\n assert_cmp(step, !=, 0);\n int end = T - 1;\n if(step < 0) {\n end = 0;\n start = T - start - 1;\n }\n int t = start;\n for(; (step > 0) ? (t <= end) : (t >= end); t += step) {\n // H[t] += Y[t-1] * W\n affine_raw(\n (t != start) ? y_prev : Ndarray_DEV_DATA(y0), n_batch, n_cells,\n Ndarray_DEV_DATA(W), n_cells, n_cells * 4,\n data_ptr(H, t), n_batch, n_cells * 4,\n false, false);\n\n start_dev_kernel(lstm_kernel, (\n n_batch,\n n_cells,\n Ndarray_DEV_DATA(i) + t * n_batch,\n data_ptr(H, t), // inplace\n (t != start) ? y_prev : Ndarray_DEV_DATA(y0),\n (t != start) ? data_ptr(C, t-step) : Ndarray_DEV_DATA(c0),\n data_ptr(Y, t), // out\n data_ptr(C, t), // out\n y_prev // out\n ));\n }\n HANDLE_LAST_ERROR();\n\n Ndarray_memcpy(Ndarray_DEV_DATA(d), data_ptr(C, t - step), n_batch * n_cells * sizeof(float));\n\n device_free(y_prev);\n }\n '[source]#
- c_bw_code: str = '\n // X, W, y0, c0, i, start, step, Y, C, H, DY, Dd = input_names\n // DX, DW, Dy0, Dc0 = output_names\n assert(n_inputs == 12);\n assert(n_outputs == 4);\n Ndarray* X = inputs[0];\n Ndarray* W = inputs[1];\n Ndarray* y0 = inputs[2];\n Ndarray* c0 = inputs[3];\n Ndarray* i = inputs[4];\n assert_cmp(Ndarray_NDIM(inputs[5]), ==, 0);\n assert_cmp(Ndarray_NDIM(inputs[6]), ==, 0);\n int start = Ndarray_DEV_DATA_int32_scalar(inputs[5]);\n int step = Ndarray_DEV_DATA_int32_scalar(inputs[6]);\n Ndarray* Y = inputs[7];\n Ndarray* C = inputs[8];\n Ndarray* H = inputs[9];\n Ndarray* DY = inputs[10];\n Ndarray* Dd = inputs[11];\n Ndarray* DX = *outputs[0];\n Ndarray* DW = *outputs[1];\n Ndarray* Dy0 = *outputs[2];\n Ndarray* Dc0 = *outputs[3];\n\n assert_cmp(Ndarray_NDIM(X), ==, 3);\n assert_cmp(Ndarray_NDIM(W), ==, 2);\n assert_cmp(Ndarray_NDIM(y0), ==, 2);\n assert_cmp(Ndarray_NDIM(c0), ==, 2);\n assert_cmp(Ndarray_NDIM(i), ==, 2);\n assert_cmp(Ndarray_NDIM(Y), ==, 3);\n assert_cmp(Ndarray_NDIM(C), ==, 3);\n assert_cmp(Ndarray_NDIM(H), ==, 3);\n assert_cmp(Ndarray_NDIM(DY), ==, 3);\n assert_cmp(Ndarray_NDIM(Dd), ==, 2);\n assert_cmp(Ndarray_NDIM(DX), ==, 3);\n assert_cmp(Ndarray_NDIM(DW), ==, 2);\n assert_cmp(Ndarray_NDIM(Dy0), ==, 2);\n assert_cmp(Ndarray_NDIM(Dc0), ==, 2);\n long T = Ndarray_DIMS(i)[0];\n int n_batch = Ndarray_DIMS(i)[1];\n int n_cells = Ndarray_DIMS(y0)[1];\n assert_cmp(Ndarray_DIMS(X)[0], ==, T);\n assert_cmp(Ndarray_DIMS(X)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(X)[2], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(W)[0], ==, n_cells);\n assert_cmp(Ndarray_DIMS(W)[1], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(y0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(y0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(c0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(c0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(Y)[0], ==, T);\n assert_cmp(Ndarray_DIMS(Y)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Y)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(C)[0], ==, T);\n assert_cmp(Ndarray_DIMS(C)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(C)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(H)[0], ==, T);\n assert_cmp(Ndarray_DIMS(H)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(H)[2], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(DY)[0], ==, T);\n assert_cmp(Ndarray_DIMS(DY)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(DY)[2], ==, n_cells);\n assert_cmp(Ndarray_DIMS(Dd)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Dd)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(DX)[0], ==, T);\n assert_cmp(Ndarray_DIMS(DX)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(DX)[2], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(DW)[0], ==, n_cells);\n assert_cmp(Ndarray_DIMS(DW)[1], ==, n_cells * 4);\n assert_cmp(Ndarray_DIMS(Dy0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Dy0)[1], ==, n_cells);\n assert_cmp(Ndarray_DIMS(Dc0)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(Dc0)[1], ==, n_cells);\n\n // We will work inplace on DW.\n Ndarray_memset(Ndarray_DEV_DATA(DW), 0, n_cells * n_cells * 4 * sizeof(float));\n // We will work inplace on (Dy0) DY[t], initially 0.\n Ndarray_memset(Ndarray_DEV_DATA(Dy0), 0, n_batch * n_cells * sizeof(float));\n // We will work inplace on (Dc0) DC[t], and init it with Dd.\n Ndarray_memcpy(Ndarray_DEV_DATA(Dc0), Ndarray_DEV_DATA(Dd), n_batch * n_cells * sizeof(float));\n\n if(T == 0) {\n // just do nothing. 
at least do not crash\n\n } else {\n // Need to keep track of (logical) DX[0], which in practice (masking, step<0)\n // can be different from data_ptr(DX, start).\n float* dx0 = (float*) device_malloc(n_batch * n_cells * 4 * sizeof(float));\n Ndarray_memset(dx0, 0, n_batch * n_cells * 4 * sizeof(float));\n\n assert_cmp(T, >, 0);\n assert_cmp(start, >=, 0);\n assert_cmp(start, <, T);\n assert_cmp(step, !=, 0);\n int abs_step = std::abs(step);\n\n if(abs_step > 1 || start > 0)\n // Normally the kernel would visit and reset all DX.\n // But with abs_step>1 or start>0, we will not visit all. Reset now.\n Ndarray_memset(Ndarray_DEV_DATA(DX), 0, T * n_batch * n_cells * 4 * sizeof(float));\n\n // e.g.:\n // step=1, start=0, T=10 -> num_steps=10=T\n // step=5, start=0, T=10 -> num_steps=2=T/step\n // step=5, start=0, T=9 -> num_steps=2=(T+step-1)/step\n // step=5, start=0, T=6 -> num_steps=2=(T+step-1)/step\n // step=5, start=0, T=5 -> num_steps=1=(T+step-1)/step\n // step=5, start=4, T=10 -> num_steps=2=(T-start+step-1)/step\n // step=-5, start=0, T=10 -> num_steps=2=T/abs_step\n // step=-5, start=0, T=9 -> num_steps=2=(T+abs_step-1)/abs_step\n // step=-5, start=4, T=10 -> num_steps=2=(T-start+abs_step-1)/abs_step\n int num_steps = (T - start + abs_step - 1) / abs_step;\n assert_cmp(num_steps, >, 0);\n if(step < 0)\n start = T - start - 1;\n int end = start + (num_steps - 1) * step; // inclusive\n assert_cmp(end, >=, 0);\n assert_cmp(end, <, T);\n int t = end; // go backwards\n for(; (step > 0) ? (t >= start) : (t <= start); t -= step) {\n bool right = (step > 0) ? (t - step >= start) : (t - step <= start);\n\n start_dev_kernel(lstm_bwd_kernel, (\n n_batch,\n n_cells,\n Ndarray_DEV_DATA(i) + t * n_batch,\n data_ptr(H, t),\n right ? data_ptr(C, t-step) : Ndarray_DEV_DATA(c0),\n data_ptr(Y, t),\n data_ptr(C, t),\n data_ptr(DY, t),\n Ndarray_DEV_DATA(Dy0), // in+out, error from prev frame, excluding DY. reset here, updated below\n Ndarray_DEV_DATA(Dc0), // in+out, working inplace. also error from prev frame, initially Dd\n data_ptr(DX, t), // out\n dx0 // out\n ));\n\n // (Dy0) DY[t-1] += DX[t] * W^T\n affine_raw(\n data_ptr(DX, t), n_batch, n_cells * 4,\n Ndarray_DEV_DATA(W), n_cells, n_cells * 4,\n Ndarray_DEV_DATA(Dy0), n_batch, n_cells,\n false, true);\n }\n\n //DW = Y[0..T-2]^T * DX[1..T-1] (if step==1)\n if(num_steps > 1) {\n if(abs_step == 1) {\n affine_raw(\n data_ptr(Y, std::min(start, end) + std::max(0, -step)), (num_steps - 1) * n_batch, n_cells,\n data_ptr(DX, std::min(start, end) + std::max(0, step)), (num_steps - 1) * n_batch, n_cells * 4,\n Ndarray_DEV_DATA(DW), n_cells, n_cells * 4,\n true, false, 0.0f, 1.0f);\n } else {\n // Unfortunately we cannot do efficient striding. Thus loop again.\n t = end - step; // one before\n for(; (step > 0) ? (t >= start) : (t <= start); t -= step) {\n affine_raw(\n data_ptr(Y, t), n_batch, n_cells,\n data_ptr(DX, t + step), n_batch, n_cells * 4,\n Ndarray_DEV_DATA(DW), n_cells, n_cells * 4,\n true, false);\n }\n }\n }\n HANDLE_LAST_ERROR();\n\n //DW += y0^T * DX[0]\n affine_raw(\n Ndarray_DEV_DATA(y0), n_batch, n_cells,\n dx0, n_batch, n_cells * 4,\n Ndarray_DEV_DATA(DW), n_cells, n_cells * 4,\n true, false);\n\n device_free(dx0);\n }\n '[source]#
- class returnn.native_op.TwoDLSTM[source]#
- inputs:
- param X:
{input,output,forget,lambda} gate + cell state. 4d (timeT,timeS,batch,dim*5) // dim*5 or dim*1 ?
- param V_h:
recurrent matrix. 2d (dim,dim*5)
- param V_v:
recurrent matrix. 2d (dim,dim*5)
- param W:
recurrent matrix. 2d (dim,dim*5)
- param b:
bias. 2d (batch,dim)
- param ptr_storage:
ptr_storage. 1d (1 * 5 * max_diag_size * sizeof(float*) / sizeof(float))
- param valid:
used internally to store which cells are valid (have to be computed). 1d (1 * max_diag_size * n_minibatch)
- param workmem2:
used internally. 3d (H[0], H[2], H[3])
- param sizes:
height (target) x width (source) of the unpadded sentences. 2d (batch, 2)
- outputs:
- param CompleteY:
output. 4d (timeS,timeT,batch,dim)
- param H:
gates and cell state. 4d (timeS,timeT,batch,dim*5) ?
- param d:
final cell state. 3d (timeT,batch,dim)
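As a reading aid for the 2D recurrence (an illustrative NumPy sketch derived from lstm_stable_cell_kernel_batched below; the names are ours): the lambda gate interpolates between the cell states of the vertical (y-1, x) and horizontal (y, x-1) predecessors.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def twod_lstm_cell(h_yx, c_up, c_left, valid):
        # h_yx: (batch, dim*5) pre-activations at position (y, x),
        #       blocks: input gate, forget gate, lambda gate, output gate, cell-in;
        # c_up, c_left: cell states of the (y-1, x) and (y, x-1) neighbours (zeros at the border);
        # valid: (batch,) 1.0 inside the unpadded image, 0.0 otherwise.
        inp, fgt, lam, out, cell_in = np.split(h_yx, 5, axis=1)
        inp, fgt, lam, out = sigmoid(inp), sigmoid(fgt), sigmoid(lam), sigmoid(out)
        v = valid[:, None]
        c = (inp * np.tanh(cell_in) + fgt * (lam * c_up + (1.0 - lam) * c_left)) * v
        y = out * np.tanh(c) * v
        return y, c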
- in_info: Tuple[Dict[str]] = ({'bw_out_var': {'shape': ((0, 0), (0, 1), (0, 2), (0, 3))}, 'name': 'X', 'ndim': 4, 'need_contiguous': True, 'shape': (None, None, None, None)}, {'name': 'V_h', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'V_v', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'W', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'b', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'ptr_storage_fwd', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'ptr_storage_bwd', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'valid', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'workmem', 'ndim': 5, 'need_contiguous': True, 'shape': (None, None, None, None, None)}, {'gradient': 'disconnected', 'name': 'workmem2', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'sizes', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'DYDummy', 'ndim': 4, 'need_contiguous': True, 'shape': (None, None, None, None)}, {'gradient': 'disconnected', 'name': 'initialState', 'ndim': 4, 'need_contiguous': True, 'shape': (None, None, None, None)}, {'gradient': 'disconnected', 'name': 'initialOutput', 'ndim': 4, 'need_contiguous': True, 'shape': (None, None, None, None)}, {'gradient': 'disconnected', 'name': 'iteration', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'CompleteY', 'ndim': 4, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2), (1, 0))}, {'name': 'H', 'ndim': 4, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2), (3, 1))})[source]#
- classmethod grad_input_map(X, V_h, V_v, W, b, ptr_storage_fwd, ptr_storage_bwd, valid, workmem, workmem2, sizes, DYDummy, initialState, initialOutput, iteration, CompleteY, H, DCompleteY, DH)[source]#
- classmethod map_layer_inputs_to_op(Zs, Zt, V_h, V_v, W, b, ptr_storage)[source]#
- Parameters:
inputs –
- Returns:
inputs
- c_extra_support_code: Dict[str, str] = {'01_repvec': '\n DEF_KERNEL\n void repvec(const float * v, int vlen, int nCopies, float * dest)\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < vlen * nCopies)\n {\n dest[idx] = v[idx % vlen];\n idx += gridDim.x * blockDim.x;\n }\n }\n ', '02_fillmat': '\n void fillmat(OpKernelContext* context, const Ndarray * b, Ndarray * dst)\n {\n const float * data_b = Ndarray_DEV_DATA(b);\n float * data_dst = Ndarray_DEV_DATA(dst);\n Ndarray_DIMS_Type dims_b = Ndarray_HOST_DIMS(b);\n int dims_dst[2];\n lastTwoDims(dst, dims_dst);\n assert(dims_b[0] == dims_dst[1]);\n start_dev_kernel(repvec, (\n data_b,\n dims_dst[1],\n Ndarray_SIZE(dst)/dims_dst[1],\n data_dst\n ));\n }\n ', '03_data_ptr': '\n // if nd is 2 then assume a weight matrix and just return beginning of data\n // else nd should be 3 and we pick the x part\n float* data_ptr(const Ndarray* a, int y, int x, int outer_dim=0) {\n assert(Ndarray_NDIM(a) == 2 || Ndarray_NDIM(a) == 4 || Ndarray_NDIM(a) == 5);\n if(Ndarray_NDIM(a) == 2)\n return Ndarray_DEV_DATA(a);\n else if(Ndarray_NDIM(a) == 4) {\n Ndarray_DIMS_Type dims = Ndarray_HOST_DIMS(a);\n return Ndarray_DEV_DATA(a)\n + y * dims[1] * dims[2] * dims[3]\n + x * dims[2] * dims[3]; // row-major or minor?\n }\n else {\n Ndarray_DIMS_Type dims = Ndarray_HOST_DIMS(a);\n return Ndarray_DEV_DATA(a)\n + outer_dim * dims[1] * dims[2] * dims[3] * dims[4]\n + y * dims[2] * dims[3] * dims[4]\n + x * dims[3] * dims[4];\n }\n }\n\n float * data_ptr(Ndarray * a, int y, int x, int outer_dim=0)\n {\n const Ndarray * ca = a;\n return const_cast<float *>(data_ptr(ca, y, x, outer_dim));\n }\n ', '04_affine_y_x_batched_onedir': "\n // ys and xs: base indices, offset by y_A, x_A (-1,0,1)\n void affine_y_x_batched_onedir(OpKernelContext* context, int y_A, int x_A,\n const Ndarray * A1,\n const Ndarray * B1,\n Ndarray * C1,\n const std::vector<int>& ys, const std::vector<int>& xs, Ndarray * ptr_storage, int height, int width,\n cudaStream_t stream = 0, bool transpose_A=false, bool transpose_B=false)\n {\n const int batch_size = ys.size();\n if(batch_size == 0)\n {\n return;\n }\n std::vector<const float*> ABC_ptrs(3 * 1 * batch_size); //content layout: 3x1xbatch_size (3: A,B,C, 1: dirs)\n\n for(int i = 0; i < batch_size; ++i)\n {\n //A\n //y not flipped, x not flipped\n ABC_ptrs[0 * 1 * batch_size + 0 * batch_size + i] = data_ptr(A1, y_A + ys[i], x_A + xs[i]);\n\n //B\n //index doesent matter here, as B is only 2dimensional\n ABC_ptrs[1 * 1 * batch_size + 0 * batch_size + i] = data_ptr(B1, 0, 0);\n\n //we write the result (C) in the same destination (y,x) as the source (A), so we don't need to flip later\n //C\n //y not flipped, x not flipped\n ABC_ptrs[2 * 1 * batch_size + 0 * batch_size + i] = data_ptr(C1, ys[i], xs[i]);\n }\n const float ** ptr_storage_data = reinterpret_cast<const float**>(&(ABC_ptrs[0]));\n const float ** A_ptrs_data = (const float**) ptr_storage_data + 0 * 1 * batch_size;\n const float ** B_ptrs_data = (const float**) ptr_storage_data + 1 * 1 * batch_size;\n const float ** C_ptrs_data = ptr_storage_data + 2 * 1 * batch_size;\n\n int A_dim[2], B_dim[2];\n lastTwoDims(A1, A_dim);\n lastTwoDims(B1, B_dim);\n int ldB = B_dim[1];\n int ldA = A_dim[1];\n char transA = transpose_A ? 'T' : 'N';\n char transB = transpose_B ? 
'T' : 'N';\n if (transpose_A)\n {\n std::swap(A_dim[0], A_dim[1]);\n }\n if (transpose_B)\n {\n std::swap(B_dim[0], B_dim[1]);\n }\n\n const float alpha = 1;\n const float beta = 1;\n\n Ndarray_sgemm_batched(\n transB, transA, B_dim[1], A_dim[0], A_dim[1], &alpha,\n B_ptrs_data, ldB, A_ptrs_data, ldA, &beta,\n C_ptrs_data, B_dim[1], 1 * batch_size, batch_size == 1);\n }\n ", '05_lstm_stable_cell_kernel_batched': '\n DEF_KERNEL\n void lstm_stable_cell_kernel_batched(float ** datas, const float ** old_state_ys, const float ** old_state_xs,\n float ** outputs, const float ** valids, int n_outer_batch, int n_cells, int n_minibatch)\n {\n //layout (for every outer batch):\n //data[0*n_cells..1*n_cells-1] : input gate\n //data[1*n_cells..2*n_cells-1] : forget gate\n //data[2*n_cells..3*n_cells-1] : lambda gate\n //data[3*n_cells..4*n_cells-1] : output gate\n //data[5*n_cells..6*n_cells-1] : cell state\n //output[0*n_cells..1*n_cells-1]: cell output\n //valids: either 1.0 or 0.0, indicating if the current (y,x) position\n // is still inside the image in this minibatch\n //repeated for every mini-batch\n\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_outer_batch * n_cells * n_minibatch)\n {\n int size_per_outer_batch = n_cells * n_minibatch;\n int outer_batch_idx = idx / size_per_outer_batch;\n float * data = datas[outer_batch_idx];\n const float * old_state_y = old_state_ys[outer_batch_idx];\n const float * old_state_x = old_state_xs[outer_batch_idx];\n float * output = outputs[outer_batch_idx];\n const float * valid = valids[outer_batch_idx];\n\n int inner_idx = idx % size_per_outer_batch;\n int minibatch_idx = inner_idx / n_cells;\n int batch_offset = minibatch_idx * 5 * n_cells;\n int cell_offset = inner_idx % n_cells;\n int start = batch_offset + cell_offset;\n\n float valid_batch = valid[minibatch_idx];\n\n //input, forget and output gates\n float inpGate = 1.f / (1.f + expf(-data[start]));\n float fgtGate = 1.f / (1.f + expf(-data[start + n_cells]));\n float lambdaGate = 1.f / (1.f + expf(-data[start + 2 * n_cells]));\n float outGate = 1.f / (1.f + expf(-data[start + 3 * n_cells]));\n float state = inpGate * tanhf(data[start + 4 * n_cells]);\n if (old_state_y)\n {\n state += fgtGate * lambdaGate * old_state_y[start];\n }\n if (old_state_x)\n {\n state += fgtGate * (1.0f - lambdaGate) * old_state_x[start];\n }\n state *= valid_batch;\n\n //cell output\n output[inner_idx] = outGate * tanhf(state) * valid_batch;\n\n data[start] = inpGate;\n data[start + n_cells] = fgtGate;\n data[start + 2 * n_cells] = lambdaGate;\n data[start + 3 * n_cells] = outGate;\n data[start + 4 * n_cells] = state;\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', '06_do_lstm_batched_onedir': '\n // H, CompleteY, ys, xs, ptr_storage\n void do_lstm_batched_onedir(\n OpKernelContext* context, Ndarray* H, Ndarray* initialState, float iteration, Ndarray* completeOut,\n const std::vector<int>& ys, const std::vector<int>& xs,\n Ndarray* ptr_storage, Ndarray* valid_storage, Ndarray* sizes)\n {\n int n_outer_batch = ys.size();\n Ndarray_DIMS_Type H_dims = Ndarray_HOST_DIMS(H);\n int height = H_dims[0];\n int width = H_dims[1];\n int n_minibatch = H_dims[2];\n int n_cells = H_dims[3] / 5;\n assert(H_dims[3] % 5 == 0); //4 gates + cell\n\n std::vector<float*> ptrs(1 * 5 * n_outer_batch); //1 dirs * 5 arrays\n std::vector<float> valid(1 * n_minibatch * n_outer_batch, 1.0f);\n\n float* h_sizes; // the sizes array is stored on the GPU, we have to copy it to the CPU\n int dsize =\n (n_outer_batch) * (n_minibatch) * 
sizeof(float) * 2; // (*2), because we have 2 (height, width) numbers\n h_sizes = (float*)malloc(dsize);\n HANDLE_ERROR(cudaMemcpy(h_sizes, Ndarray_DEV_DATA(sizes), dsize, cudaMemcpyDeviceToHost));\n\n for(int i = 0; i < n_outer_batch; ++i)\n {\n int y = ys[i];\n int x = xs[i];\n\n //fill valid\n for(int n = 0; n < n_minibatch; ++n) // iterates through all examples in the current batch\n {\n float img_height = *(h_sizes + 2*n);\n float img_width = *(h_sizes + 2*n +1);\n\n valid[i * 1 * n_minibatch + 0 * n_minibatch + n] = float(y < img_height && x < img_width);\n }\n\n //y not flipped, x not flipped\n float * data_H = data_ptr(H, y, x);\n\n //y not flipped, x not flipped\n float * data_old_state_y;\n data_old_state_y = y > 0 ? data_ptr(H, y - 1, x) + 4 * n_cells : data_ptr(initialState, 0, x) + 4 * n_cells;\n\n //y not flipped, x not flipped\n float * data_old_state_x = x > 0 ? data_ptr(H, y, x - 1) + 4 * n_cells : 0;\n\n //y not flipped, x not flipped\n float * data_out = data_ptr(completeOut, y, x);\n\n float * valid = Ndarray_DEV_DATA(valid_storage) + i * 1 * n_minibatch + 0 * n_minibatch;\n\n ptrs[0 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_H;\n ptrs[1 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_old_state_y;\n ptrs[2 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_old_state_x;\n ptrs[3 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_out;\n ptrs[4 * 1 * n_outer_batch + 0 * n_outer_batch + i] = valid;\n }\n\n free(h_sizes);\n\n HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(valid_storage), valid.data(),\n valid.size() * sizeof(float), cudaMemcpyHostToDevice));\n HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(ptr_storage), ptrs.data(),\n ptrs.size() * sizeof(float*), cudaMemcpyHostToDevice));\n float ** ptr_storage_data = reinterpret_cast<float**>(Ndarray_DEV_DATA(ptr_storage));\n float ** data_Hs = ptr_storage_data + 0 * 1 * n_outer_batch;\n const float ** data_old_state_ys = (const float**) ptr_storage_data + 1 * 1 * n_outer_batch;\n const float ** data_old_state_xs = (const float**) ptr_storage_data + 2 * 1 * n_outer_batch;\n float ** data_outs = ptr_storage_data + 3 * 1 * n_outer_batch;\n const float ** data_valids = (const float**) (ptr_storage_data + 4 * 1 * n_outer_batch);\n\n start_dev_kernel(lstm_stable_cell_kernel_batched, (\n data_Hs,\n data_old_state_ys,\n data_old_state_xs,\n data_outs,\n data_valids,\n 1 * n_outer_batch,\n n_cells,\n n_minibatch\n ));\n }\n ', '07_lstm_bwd_stable_cell_kernel_batched': '\n DEF_KERNEL\n void lstm_bwd_stable_cell_kernel_batched(float ** deltas, const float ** epsilons,\n const float ** next_epsilon_ys, const float ** next_epsilon_xs, float ** epsilon_ys, float ** epsilon_xs,\n const float ** last_state_ys, const float ** last_state_xs, const float ** Ys, const float ** valids,\n int n_outer_batch, int n_cells, int n_minibatch)\n {\n //layout (for every outer batch):\n //delta[0*n_cells..1*n_cells-1] : input gate\n //delta[1*n_cells..2*n_cells-1] : forget gate\n //delta[2*n_cells..3*n_cells-1] : lambda gate\n //delta[3*n_cells..4*n_cells-1] : output gate\n //delta[4*n_cells..5*n_cells-1] : cell state\n //epsilon[0*n_cells..1*n_cells-1]: cell output derivative\n //next_epsilon_y[0*n_cells..1*n_cells-1]: cell state derivative * forget_gate * lambda_gate (of next timestep)\n //next_epsilon_x[0*n_cells..1*n_cells-1]:\n // cell state derivative * forget_gate * (-1*)lambda_gate (of next timestep)\n //epsilon_y[0*n_cells..1*n_cells-1]:\n // cell state derivative * forget_gate * lambda_gate (of current timestep, as output)\n 
//epsilon_x[0*n_cells..1*n_cells-1]:\n // cell state derivative * forget_gate * (1-lambda_gate) (of current timestep, as output)\n //valids: either 1.0 or 0.0, indicating if the current (y,x) position\n // is still inside the image in this minibatch\n //repeated for every mini-batch\n\n float near_zero = 0.00000000001f;\n\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while (idx < n_outer_batch * n_cells * n_minibatch)\n {\n int size_per_outer_batch = n_cells * n_minibatch;\n int outer_batch_idx = idx / size_per_outer_batch;\n const float * valid = valids[outer_batch_idx];\n\n float * delta = deltas[outer_batch_idx];\n const float * epsilon = epsilons[outer_batch_idx];\n const float * next_epsilon_y = next_epsilon_ys[outer_batch_idx];\n const float * next_epsilon_x = next_epsilon_xs[outer_batch_idx];\n float * epsilon_y = epsilon_ys[outer_batch_idx];\n float * epsilon_x = epsilon_xs[outer_batch_idx];\n const float * last_state_y = last_state_ys[outer_batch_idx];\n const float * last_state_x = last_state_xs[outer_batch_idx];\n const float * Y = Ys[outer_batch_idx];\n\n int inner_idx = idx % size_per_outer_batch;\n int minibatch_idx = inner_idx / n_cells;\n int batch_offset = minibatch_idx * 5 * n_cells;\n int cell_offset = inner_idx % n_cells;\n int start = batch_offset + cell_offset;\n float valid_batch = valid[minibatch_idx];\n\n float inpGate = delta[start];\n float fgtGate = delta[start + n_cells];\n float lambdaGate = delta[start + 2 * n_cells];\n float outGate = delta[start + 3 * n_cells];\n float state = delta[start + 4 * n_cells];\n float lastState_y = last_state_y ? last_state_y[start] : 0.f;\n float lastState_x = last_state_x ? last_state_x[start] : 0.f;\n float eps = epsilon[inner_idx];\n\n //avoid division by 0\n float gc = 0.f; //g(c(t))\n float gzc = 0.f; //g(z_c(t))\n if (outGate < -near_zero || outGate > near_zero)\n {\n gc = Y[inner_idx] / outGate;\n }\n\n if (inpGate < -near_zero || inpGate > near_zero)\n {\n gzc = (state - fgtGate * lambdaGate * lastState_y - fgtGate * (1.0f - lambdaGate) * lastState_x) / inpGate;\n }\n\n //delta_output\n delta[start + 3 * n_cells] = outGate * (1.f - outGate) * gc * eps * valid_batch;\n\n //epsilon_c\n float epsilon_c = (1.f - (gc * gc)) * outGate * eps;\n if (next_epsilon_y)\n {\n epsilon_c += next_epsilon_y[inner_idx];\n }\n if (next_epsilon_x)\n {\n epsilon_c += next_epsilon_x[inner_idx];\n }\n\n //TODO: clip epsilon_c?\n //epsilon_c = max(epsilon_c, -10.f);\n //epsilon_c = min(epsilon_c, 10.f);\n\n epsilon_y[inner_idx] = epsilon_c * fgtGate * lambdaGate * valid_batch;\n epsilon_x[inner_idx] = epsilon_c * fgtGate * (1.0f - lambdaGate) * valid_batch;\n\n //delta_cell\n delta[start + 4 * n_cells] = inpGate * (1.f - (gzc * gzc)) * epsilon_c * valid_batch;\n\n //delta_forget\n delta[start + n_cells] = fgtGate * (1.f - fgtGate) * epsilon_c *\n (lastState_y * lambdaGate + lastState_x * (1.0f - lambdaGate)) * valid_batch;\n\n //delta_lambda\n delta[start + 2 * n_cells] = fgtGate * lambdaGate * (1.f - lambdaGate) * epsilon_c\n * (lastState_y - lastState_x) * valid_batch;\n\n //delta_input\n delta[start] = inpGate * (1.f - inpGate) * gzc * epsilon_c * valid_batch;\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', '08_do_lstm_bwd_batched_onedir': '\n //epsilon are the derivates w.r.t. 
Z, delta stores the gate and cell activations\n // and will store the derivatives later\n void do_lstm_bwd_batched_onedir(OpKernelContext* context, Ndarray * delta1, Ndarray * epsilon1,\n const Ndarray* CompleteY, Ndarray * workmem1,\n int height, int width, const std::vector<int>& ys, const std::vector<int>& xs,\n Ndarray * ptr_storage, Ndarray * valid_storage, Ndarray* sizes, int iteration, cudaStream_t stream=0)\n {\n int n_outer_batch = ys.size();\n int dims[2];\n lastTwoDims(delta1, dims);\n assert(dims[1] % 5 == 0); //4 gates + cell\n int n_cells = dims[1] / 5;\n int n_minibatch = dims[0];\n\n std::vector<const float*> ptrs(1 * 10 * n_outer_batch); //1 dirs * 10 arrays\n std::vector<float> valid(1 * n_minibatch * n_outer_batch, 1.0f);\n\n float* h_sizes; // the sizes array is stored on the GPU, we have to copy it to the CPU\n int dsize =\n (n_outer_batch) * (n_minibatch) * sizeof(float) * 2; // (*2), because we have 2 (height, width) numbers\n h_sizes = (float*)malloc(dsize);\n HANDLE_ERROR(cudaMemcpy(h_sizes, Ndarray_DEV_DATA(sizes), dsize, cudaMemcpyDeviceToHost));\n\n for(int i = 0; i < n_outer_batch; ++i)\n {\n int y = ys[i];\n int x = xs[i];\n //fill valid\n for(int n = 0; n < n_minibatch; ++n)\n {\n //these are the sizes of a single image in the batch, while height/width are the maximum sizes in the batch\n float img_height = *(h_sizes + 2*n);\n float img_width = *(h_sizes + 2*n +1);\n valid[i * 1 * n_minibatch + 0 * n_minibatch + n] = float(y < img_height && x < img_width);\n }\n\n bool botBorder = (y == height-1);\n bool rightBorder = (x == width-1);\n int yp1 = y + 1;\n int xp1 = x + 1;\n int ym1 = y - 1;\n int xm1 = x - 1;\n\n float * data_delta1 = data_ptr(delta1, y, x);\n const float * data_epsilon1 = data_ptr(epsilon1, y, x);\n const float * data_next_epsilon_y1 = botBorder ? 0 : data_ptr(workmem1, (iteration-1)%2, x, 0);\n const float * data_next_epsilon_x1 = rightBorder ? 0 : data_ptr(workmem1, (iteration-1)%2, xp1, 1);\n float * data_epsilon_y1 = data_ptr(workmem1, iteration%2, x, 0);\n float * data_epsilon_x1 = data_ptr(workmem1, iteration%2, x, 1);\n const float * data_last_state_y1 = y > 0 ? data_ptr(delta1, ym1, x) + 4 * n_cells : 0;\n const float * data_last_state_x1 = x > 0 ? 
data_ptr(delta1, y, xm1) + 4 * n_cells : 0;\n const float * data_Y1 = data_ptr(CompleteY, y, x);\n float * valid1 = Ndarray_DEV_DATA(valid_storage) + i * 1 * n_minibatch + 0 * n_minibatch;\n\n ptrs[0 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_delta1;\n ptrs[1 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_epsilon1;\n ptrs[2 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_next_epsilon_y1;\n ptrs[3 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_next_epsilon_x1;\n ptrs[4 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_epsilon_y1;\n ptrs[5 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_epsilon_x1;\n ptrs[6 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_last_state_y1;\n ptrs[7 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_last_state_x1;\n ptrs[8 * 1 * n_outer_batch + 0 * n_outer_batch + i] = data_Y1;\n ptrs[9 * 1 * n_outer_batch + 0 * n_outer_batch + i] = valid1;\n }\n\n free(h_sizes);\n\n HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(valid_storage), valid.data(),\n valid.size() * sizeof(float), cudaMemcpyHostToDevice));\n HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(ptr_storage), ptrs.data(),\n ptrs.size() * sizeof(float*), cudaMemcpyHostToDevice));\n float ** ptr_storage_data = reinterpret_cast<float**>(Ndarray_DEV_DATA(ptr_storage));\n float ** data_deltas = ptr_storage_data + 0 * 1 * n_outer_batch;\n const float ** data_epsilons = (const float**) ptr_storage_data + 1 * 1 * n_outer_batch;\n const float ** data_next_epsilon_ys = (const float**) ptr_storage_data + 2 * 1 * n_outer_batch;\n const float ** data_next_epsilon_xs = (const float**) ptr_storage_data + 3 * 1 * n_outer_batch;\n float ** data_epsilon_ys = ptr_storage_data + 4 * 1 * n_outer_batch;\n float ** data_epsilon_xs = ptr_storage_data + 5 * 1 * n_outer_batch;\n const float ** data_last_state_ys = (const float**) ptr_storage_data + 6 * 1 * n_outer_batch;\n const float ** data_last_state_xs = (const float**) ptr_storage_data + 7 * 1 * n_outer_batch;\n const float ** data_Ys = (const float**) ptr_storage_data + 8 * 1 * n_outer_batch;\n const float ** data_valids = (const float**) (ptr_storage_data + 9 * 1 * n_outer_batch);\n\n start_dev_kernel(lstm_bwd_stable_cell_kernel_batched, (\n data_deltas,\n data_epsilons,\n data_next_epsilon_ys,\n data_next_epsilon_xs,\n data_epsilon_ys,\n data_epsilon_xs,\n data_last_state_ys,\n data_last_state_xs,\n data_Ys,\n data_valids,\n 1 * n_outer_batch,\n n_cells,\n n_minibatch\n ));\n }\n '}[source]#
- c_fw_code: str = '\n // X*, V_h, V_v, W, b, ptr_storage_fwd, ptr_storage_bwd, valid, workmem, sizes, DYDummy,\n // initialState, initialOutput, iteration = input_names (*: inplace)\n // CompleteY, H = output_names\n\n assert(n_inputs == 15);\n assert(n_outputs == 2);\n\n Ndarray* X = inputs[0];\n Ndarray* V_h = inputs[1];\n Ndarray* V_v = inputs[2];\n Ndarray* W = inputs[3];\n Ndarray* b = inputs[4];\n Ndarray* ptr_storage_fwd = inputs[5];\n Ndarray* ptr_storage_bwd = inputs[6]; // not used in fwd\n Ndarray* valid = inputs[7];\n Ndarray* workmem = inputs[8]; // not used in fwd\n Ndarray* workmem2 = inputs[9]; // not used in fwd\n Ndarray* sizes = inputs[10];\n Ndarray* DYDummy = inputs[11]; // not used in fwd\n Ndarray* initialState = inputs[12];\n Ndarray* initialOutput = inputs[13];\n Ndarray* iteration = inputs[14];\n\n assert(sizeof(float) == 4 && "ptr_storage has wrong size if sizeof(float) != 4");\n assert(sizeof(float*) == 8 && "ptr_storage has wrong size if sizeof(float*) != 8");\n\n Ndarray* CompleteY = *outputs[0];\n Ndarray* H = *outputs[1];\n\n Ndarray_DIMS_Type X_dim = Ndarray_DIMS(X);\n Ndarray_DIMS_Type W_dim = Ndarray_DIMS(W);\n Ndarray_DIMS_Type V_dim = Ndarray_DIMS(V_h);\n assert(W_dim[1] %% 5 == 0 && "W has wrong shape");\n assert(5 * V_dim[0] == V_dim[1] && "V has wrong shape");\n assert(W_dim[1] == V_dim[1]);\n assert(W_dim[0] == X_dim[3]);\n const long long Y_dim[] = {X_dim[0], X_dim[1], X_dim[2], W_dim[1] / 5};\n const long long H_dim[] = {X_dim[0], X_dim[1], X_dim[2], W_dim[1]};\n const long long height = X_dim[0];\n const long long width = X_dim[1];\n const long long n_minibatch = X_dim[2];\n const long long max_diag_size = std::min(height, width);\n const long long n_diags = width + height - 1;\n\n //H = XW (+ b, currently always 0)\n fillmat(context, b, H);\n affine_global(X, W, H);\n\n // The iteration is stored on the GPU, but we need it on the CPU to controll the programm flow (use explicitly\n // provided previous state/output on first iteration). 
Maybe this could be optimized by storing the tensor\n // directly on the CPU?\n // We only look at the first value of the tensor with shape (batch,), as every entry has the same value by design\n float h_iteration;\n HANDLE_ERROR(cudaMemcpy(&h_iteration, Ndarray_DEV_DATA(iteration), 1*sizeof(float), cudaMemcpyDeviceToHost));\n\n for(long long diag = 0; diag < n_diags; ++diag)\n {\n int diag_size = min(diag+1, min((long long) abs(n_diags-diag), min(width, height)));\n int y_high = min(diag, height-1);\n int x_low = max(diag-height+1,(long long) 0);\n std::vector<int> ys_h, xs_h, ys_v, xs_v, ys, xs;\n for(int idx = 0; idx < diag_size; ++idx)\n {\n int y = y_high - idx;\n int x = x_low + idx;\n if(x > 0)\n {\n ys_h.push_back(y);\n xs_h.push_back(x);\n }\n if(y > 0 || h_iteration >= 1) {\n ys_v.push_back(y);\n xs_v.push_back(x);\n }\n ys.push_back(y);\n xs.push_back(x);\n }\n\n affine_y_x_batched_onedir(context, 0, -1,\n CompleteY, V_h, H, ys_h, xs_h, ptr_storage_fwd, height, width);\n\n // If it\'s not the first iteration, we need to use the explicitly provided initial output\n if(h_iteration >= 1) {\n assert(ys_v.size() == 1); // Otherwise, the target length would be != 1, we don\'t support that yet.\n affine_y_x_batched_onedir(context, 0, 0,\n initialOutput, V_v, H, ys_v, xs_v, ptr_storage_fwd, height, width);\n }\n else {\n affine_y_x_batched_onedir(context, -1, 0,\n CompleteY, V_v, H, ys_v, xs_v, ptr_storage_fwd, height, width);\n }\n\n do_lstm_batched_onedir(context, H, initialState, h_iteration, CompleteY, ys, xs, ptr_storage_fwd, valid, sizes);\n }\n '[source]#
- c_bw_code: str = "\n // X, V_h, V_v, W, b, ptr_storage_fwd, ptr_storage_bwd, valid, workmem, workmem2, sizes, DYDummy, initialState,\n // initialOutput, iteration, CompleteY, H, DCompleteY, DH = inputs\n // DX, DV_h, DV_v, DW, Db = outputs\n\n assert(n_inputs == 19);\n assert(n_outputs == 5);\n\n Ndarray* X = inputs[0];\n Ndarray* V_h = inputs[1];\n Ndarray* V_v = inputs[2];\n Ndarray* W = inputs[3];\n Ndarray* b = inputs[4];\n Ndarray* ptr_storage_fwd = inputs[5]; // not used in bwd\n Ndarray* ptr_storage_bwd = inputs[6];\n Ndarray* valid_storage = inputs[7];\n Ndarray* workmem = inputs[8];\n Ndarray* workmem2 = inputs[9];\n Ndarray* sizes = inputs[10];\n Ndarray* DYDummy = inputs[11];\n Ndarray* initialState = inputs[12];\n Ndarray* initialOutput = inputs[13];\n Ndarray* iteration = inputs[14]; // not used in bwd (only for asserting it's == 0)\n Ndarray* CompleteY = inputs[15];\n Ndarray* H = inputs[16];\n Ndarray* DCompleteY = inputs[17];\n Ndarray* DH = inputs[18];\n\n Ndarray* DX = *outputs[0];\n Ndarray* DV_h = *outputs[1];\n Ndarray* DV_v = *outputs[2];\n Ndarray* DW = *outputs[3];\n Ndarray* Db = *outputs[4];\n\n Ndarray_DIMS_Type X_dim = Ndarray_HOST_DIMS(X);\n Ndarray_DIMS_Type Y_dim = Ndarray_HOST_DIMS(CompleteY);\n Ndarray_DIMS_Type Vh_dim = Ndarray_HOST_DIMS(V_h);\n const int height = X_dim[0];\n const int width = X_dim[1];\n const int n_minibatch = X_dim[2];\n const int n_diags = width + height - 1;\n const int max_diag_size = std::min(Y_dim[0], Y_dim[1]);\n\n Ndarray * delta1 = H;\n Ndarray * epsilon = DYDummy;\n\n int size = X_dim[0] * X_dim[1] * X_dim[2] * Vh_dim[0] * sizeof(float);\n HANDLE_ERROR(cudaMemcpy(Ndarray_DEV_DATA(epsilon), Ndarray_DEV_DATA(DCompleteY), size, cudaMemcpyDeviceToDevice));\n\n for(int diag = n_diags-1; diag >= 0; --diag)\n {\n int diag_size = std::min(diag+1, std::min(std::abs(n_diags-diag), std::min(width, height)));\n int y_high = std::min(diag, height-1);\n int x_low = std::max(diag-height+1,0);\n std::vector<int> ys_h, xs_h, ys_v, xs_v, ys, xs;\n for(int idx = 0; idx < diag_size; ++idx)\n {\n int y = y_high - idx;\n int x = x_low + idx;\n bool rightBorder = (x == X_dim[1]-1);\n if(!rightBorder)\n {\n ys_h.push_back(y);\n xs_h.push_back(x);\n }\n bool botBorder = (y == X_dim[0]-1);\n if(!botBorder)\n {\n ys_v.push_back(y);\n xs_v.push_back(x);\n }\n ys.push_back(y);\n xs.push_back(x);\n }\n\n affine_y_x_batched_onedir(context, 0, 1, delta1, V_h,\n epsilon, ys_h, xs_h, ptr_storage_bwd, height, width, 0, false, true);\n affine_y_x_batched_onedir(context, 1, 0, delta1, V_v,\n epsilon, ys_v, xs_v, ptr_storage_bwd, height, width, 0, false, true);\n\n do_lstm_bwd_batched_onedir(\n context, delta1, epsilon, CompleteY, workmem,\n X_dim[0], X_dim[2], ys, xs, ptr_storage_bwd, valid_storage, sizes, diag+1);\n }\n\n //DW = X^T * delta\n affine_global(X, delta1, DW, true, false, 0, 0.0f);\n //important! mind the order, first use X, then update DX, which might be aligned to X\n //DX = delta * W^T\n affine_global(delta1, W, DX, false, true, 0, 0.0f);\n\n // Currently, the bias is not trained\n //Db = (1 ... 
1) * delta\n\n //copy left/right part to workmem2 and set to 0\n // (could be done more efficient, but profiling shows, it's not worth it)\n Ndarray_DIMS_Type H_dim = Ndarray_HOST_DIMS(H);\n const int block_size = H_dim[2] * H_dim[3];\n for(int y = 0; y < Y_dim[0]; ++y)\n {\n float * workmem2_1_data_ptr = Ndarray_DEV_DATA(workmem2) + y * block_size;\n float * delta1_data_ptr = data_ptr(delta1, y, 0);\n HANDLE_ERROR(cudaMemcpy(\n workmem2_1_data_ptr, delta1_data_ptr, block_size * sizeof(float), cudaMemcpyDeviceToDevice));\n HANDLE_ERROR(cudaMemset(delta1_data_ptr, 0, sizeof(float) * H_dim[2] * H_dim[3]));\n }\n\n //DV_h = Y[0..end-1]^T * delta[1..end]\n affine_global(CompleteY, delta1, DV_h, true, false, 1, 0.0f);\n\n //copy left/right part back\n for(int y = 0; y < Y_dim[0]; ++y)\n {\n float * workmem2_1_data_ptr = Ndarray_DEV_DATA(workmem2) + y * block_size;\n float * delta1_data_ptr = data_ptr(delta1, y, 0);\n HANDLE_ERROR(cudaMemcpy(\n delta1_data_ptr, workmem2_1_data_ptr, block_size * sizeof(float), cudaMemcpyDeviceToDevice));\n }\n\n //DV_v = Y[0..end-1]^T * delta[1..end]\n affine_global(CompleteY, delta1, DV_v, true, false, Y_dim[1], 0.0f);\n "[source]#
- class returnn.native_op.Chunking[source]#
Given a 3d input (n_time,n_batch,n_dim), we chunk up the time dimension into chunks of size chunk_size, starting every chunk_step frames. This results in a 3d output (chunk_size, n_batch * n_chunks, n_dim) where n_chunks = floor( max(n_time - chunk_size + chunk_step - 1, 0) / chunk_step ) + 1. Examples (see also the sketch below):
n_time=1,   chunk_size=50, chunk_step=10 -> n_chunks=1
n_time=49,  chunk_size=50, chunk_step=10 -> n_chunks=1
n_time=50,  chunk_size=50, chunk_step=10 -> n_chunks=1
n_time=51,  chunk_size=50, chunk_step=10 -> n_chunks=2
n_time=60,  chunk_size=50, chunk_step=10 -> n_chunks=2
n_time=61,  chunk_size=50, chunk_step=10 -> n_chunks=3
n_time=99,  chunk_size=50, chunk_step=10 -> n_chunks=6
n_time=100, chunk_size=50, chunk_step=10 -> n_chunks=6
n_time=101, chunk_size=50, chunk_step=10 -> n_chunks=7
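The chunk count can be reproduced directly from the formula above; a minimal Python sketch (the helper name n_chunks is only illustrative, not part of this module):

    def n_chunks(n_time, chunk_size, chunk_step):
        # floor(max(n_time - chunk_size + chunk_step - 1, 0) / chunk_step) + 1
        return max(n_time - chunk_size + chunk_step - 1, 0) // chunk_step + 1

    assert n_chunks(1, 50, 10) == 1
    assert n_chunks(51, 50, 10) == 2
    assert n_chunks(101, 50, 10) == 7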
- in_info: Tuple[Dict[str]] = ({'name': 'input', 'ndim': 3, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'output_buffer', 'ndim': 3, 'shape': (None, None, None), 'want_inplace': 0}, {'gradient': 'disconnected', 'name': 'oindex_buffer', 'ndim': 2, 'shape': (None, None), 'want_inplace': 1}, {'gradient': 'disconnected', 'name': 'chunk_params', 'ndim': 1, 'need_contiguous': True, 'shape': (2,)})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'shape': ((2, 0), (2, 1), (2, 2))}, {'name': 'oindex', 'ndim': 2, 'shape': ((3, 0), (3, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'copy_kernel': '\n DEF_KERNEL\n void copy_kernel(\n float* chunk_params,\n float* input, long in_dim0, long in_dim1, long in_dim2, long in_stride0, long in_stride1, long in_stride2,\n float* index, long idx_stride0, long idx_stride1,\n float* output, long out_dim0, long out_dim1, long out_stride0, long out_stride1, long out_stride2,\n float* oindex, long oidx_stride0, long oidx_stride1\n ) {\n assert_cmp(out_dim1 % in_dim1, ==, 0);\n const long n_chunks = out_dim1 / in_dim1;\n assert_cmp(n_chunks, >, 0);\n const long chunk_size = out_dim0;\n assert_cmp(long(chunk_params[0]), ==, chunk_size);\n const long chunk_step = long(chunk_params[1]);\n assert_cmp(chunk_step, >, 0);\n assert_cmp(chunk_step * (n_chunks - 1) + chunk_size, >=, in_dim0);\n assert_cmp(chunk_step * (n_chunks - 1), <, in_dim0);\n\n // Iterate over output (chunked) x/y coordinates.\n // In an inner loop, we will loop over z.\n const long max_idx = out_dim0 * out_dim1;\n for(\n long idx = threadIdx.x + blockDim.x * blockIdx.x;\n idx < max_idx;\n idx += gridDim.x * blockDim.x)\n {\n long out_x = idx % out_dim0; // time\n long out_y = idx / out_dim0; // batch\n\n long chunk_idx = out_y % n_chunks;\n long in_y = out_y / n_chunks;\n\n long in_x = chunk_step * chunk_idx + out_x;\n\n if(in_x < in_dim0 && index[in_x * idx_stride0 + in_y * idx_stride1] > 0.1) {\n for(long z = 0; z < in_dim2; ++z)\n output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] =\n input[in_x * in_stride0 + in_y * in_stride1 + z * in_stride2];\n oindex[out_x * oidx_stride0 + out_y * oidx_stride1] = 1;\n }\n else {\n for(long z = 0; z < in_dim2; ++z)\n output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] = 0;\n oindex[out_x * oidx_stride0 + out_y * oidx_stride1] = 0;\n }\n }\n }\n '}[source]#
- c_fw_code: str = '\n assert_cmp(n_inputs, ==, 5);\n assert_cmp(n_outputs, ==, 2);\n Ndarray* input = inputs[0];\n Ndarray* index = inputs[1];\n Ndarray* chunk_params = inputs[4];\n Ndarray* output = *outputs[0];\n Ndarray* oindex = *outputs[1];\n\n assert_cmp(Ndarray_NDIM(input), ==, 3);\n assert_cmp(Ndarray_NDIM(index), ==, 2);\n assert_cmp(Ndarray_DIMS(input)[0], ==, Ndarray_DIMS(index)[0]);\n assert_cmp(Ndarray_DIMS(input)[1], ==, Ndarray_DIMS(index)[1]);\n assert_cmp(Ndarray_NDIM(chunk_params), ==, 1);\n assert_cmp(Ndarray_DIMS(chunk_params)[0], ==, 2);\n assert_cmp(Ndarray_NDIM(output), ==, 3);\n assert_cmp(Ndarray_NDIM(oindex), ==, 2);\n assert_cmp(Ndarray_DIMS(output)[0], ==, Ndarray_DIMS(oindex)[0]);\n assert_cmp(Ndarray_DIMS(output)[1], ==, Ndarray_DIMS(oindex)[1]);\n assert_cmp(Ndarray_DIMS(output)[2], ==, Ndarray_DIMS(input)[2]);\n\n start_dev_kernel(copy_kernel, (\n Ndarray_DEV_DATA(chunk_params),\n Ndarray_DEV_DATA(input),\n Ndarray_DIMS(input)[0],\n Ndarray_DIMS(input)[1],\n Ndarray_DIMS(input)[2],\n Ndarray_STRIDE(input, 0),\n Ndarray_STRIDE(input, 1),\n Ndarray_STRIDE(input, 2),\n Ndarray_DEV_DATA(index),\n Ndarray_STRIDE(index, 0),\n Ndarray_STRIDE(index, 1),\n Ndarray_DEV_DATA(output),\n Ndarray_DIMS(output)[0],\n Ndarray_DIMS(output)[1],\n Ndarray_STRIDE(output, 0),\n Ndarray_STRIDE(output, 1),\n Ndarray_STRIDE(output, 2),\n Ndarray_DEV_DATA(oindex),\n Ndarray_STRIDE(oindex, 0),\n Ndarray_STRIDE(oindex, 1)\n ));\n HANDLE_LAST_ERROR();\n '[source]#
- class returnn.native_op.UnChunking[source]#
This reverses the output of Chunking, i.e. it undoes the chunking of the time dimension. We get a 3d input (chunk_size, n_batch * n_chunks, n_dim) and return a 3d output (n_time, n_batch, n_dim), where the chunks are of size chunk_size, starting every chunk_step frames. Because of overlaps, we have to combine the overlapping chunks somehow. We do that uniformly, i.e. we take the mean over all overlapping chunks per frame.
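The mean-combination over overlapping chunks can be illustrated with a small NumPy sketch. It assumes the chunk layout produced by Chunking (chunked batch index = in_y * n_chunks + chunk_idx) and ignores the index mask, so it is only an approximation of the actual kernel, not the implementation:

    import numpy as np

    def unchunk(chunked, n_time, n_batch, chunk_step):
        # chunked: (chunk_size, n_batch * n_chunks, n_dim), as produced by Chunking
        chunk_size, bc, n_dim = chunked.shape
        n_chunks = bc // n_batch
        out = np.zeros((n_time, n_batch, n_dim), dtype=chunked.dtype)
        counts = np.zeros((n_time, n_batch, 1), dtype=chunked.dtype)
        for c in range(n_chunks):
            start = c * chunk_step
            length = min(chunk_size, n_time - start)
            if length <= 0:
                continue
            # columns c, c + n_chunks, ... hold chunk c of batches 0, 1, ...
            out[start:start + length] += chunked[:length, c::n_chunks]
            counts[start:start + length] += 1
        return out / np.maximum(counts, 1)  # mean over all overlaps per frame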
- in_info: Tuple[Dict[str]] = ({'name': 'input', 'ndim': 3, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'output_buffer', 'ndim': 3, 'shape': (None, None, None), 'want_inplace': 0}, {'gradient': 'disconnected', 'name': 'oindex_buffer', 'ndim': 2, 'shape': (None, None), 'want_inplace': 1}, {'gradient': 'disconnected', 'name': 'ofactors_buffer', 'ndim': 2, 'shape': (None, None), 'want_inplace': 2}, {'gradient': 'disconnected', 'name': 'chunk_params', 'ndim': 1, 'need_contiguous': True, 'shape': (2,)})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'shape': ((2, 0), (2, 1), (2, 2))}, {'name': 'oindex', 'ndim': 2, 'shape': ((3, 0), (3, 1))}, {'name': 'ofactors', 'ndim': 2, 'shape': ((4, 0), (4, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'unchunk_kernel': '\n DEF_KERNEL\n void unchunk_kernel(\n float* chunk_params,\n float* input, long in_dim0, long in_dim1, long in_dim2, long in_stride0, long in_stride1, long in_stride2,\n float* index, long idx_stride0, long idx_stride1,\n float* output, long out_dim0, long out_dim1, long out_stride0, long out_stride1, long out_stride2,\n float* oindex, long oidx_stride0, long oidx_stride1,\n float* ofactors, long ofac_stride0, long ofac_stride1\n ) {\n assert_cmp(in_dim1 % out_dim1, ==, 0);\n const long n_chunks = in_dim1 / out_dim1;\n assert_cmp(n_chunks, >, 0);\n const long chunk_size = in_dim0;\n assert_cmp(long(chunk_params[0]), ==, chunk_size);\n const long chunk_step = long(chunk_params[1]);\n assert_cmp(chunk_step, >, 0);\n assert_cmp(chunk_step * (n_chunks - 1) + chunk_size, >=, out_dim0);\n assert_cmp(chunk_step * (n_chunks - 1), <, out_dim0);\n\n // Iterate over output (unchunked) x/y coordinates.\n // In an inner loop, we will loop over z.\n const long max_idx = out_dim0 * out_dim1;\n for(\n long idx = threadIdx.x + blockDim.x * blockIdx.x;\n idx < max_idx;\n idx += gridDim.x * blockDim.x)\n {\n long out_x = idx % out_dim0; // time\n long out_y = idx / out_dim0; // batch\n\n float c = 0;\n for(long z = 0; z < in_dim2; ++z)\n output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] = 0;\n\n // in_x = out_x - chunk_step * chunk_idx,\n // thus in_x < 0 when chunk_idx * chunk_step > out_x,\n // and in_x >= chunk_size when chunk_idx * chunk_step <= out_x - chunk_size,\n // thus we need chunk_idx <= out_x / chunk_step,\n // and chunk_idx > (out_x - chunk_size) / chunk_step.\n // Examples:\n // out_x=0, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,1\n // out_x=3, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,1\n // out_x=4, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,2\n // out_x=7, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,2\n // out_x=8, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,3\n // out_x=9, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=0,3\n // out_x=10, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=1,3\n // out_x=11, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=1,3\n // out_x=12, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=1,4\n // out_x=13, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=1,4\n // out_x=14, chunk_size=10, chunk_step=4 -> chunk_idx_start,end=2,4\n long chunk_idx_start = (out_x - chunk_size + chunk_step) / chunk_step;\n if(chunk_idx_start < 0) chunk_idx_start = 0;\n long chunk_idx_end = out_x / chunk_step + 1;\n if(chunk_idx_end > n_chunks) chunk_idx_end = n_chunks;\n assert_cmp(chunk_idx_start, <, chunk_idx_end);\n for(long chunk_idx = chunk_idx_start; chunk_idx < chunk_idx_end; ++chunk_idx) {\n long in_y = out_y * n_chunks + chunk_idx;\n long in_x = out_x - chunk_step * chunk_idx;\n assert_cmp(in_x, >=, 0);\n assert_cmp(in_x, <, chunk_size);\n if(index[in_x * idx_stride0 + in_y * idx_stride1] > 0.1) {\n c += 1;\n for(long z = 0; z < in_dim2; ++z)\n output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] +=\n input[in_x * in_stride0 + in_y * in_stride1 + z * in_stride2];\n }\n }\n\n if(c > 0.1) {\n for(long z = 0; z < in_dim2; ++z)\n output[out_x * out_stride0 + out_y * out_stride1 + z * out_stride2] /= c;\n oindex[out_x * oidx_stride0 + out_y * oidx_stride1] = 1;\n ofactors[out_x * ofac_stride0 + out_y * ofac_stride1] = 1.0 / c;\n } else {\n oindex[out_x * oidx_stride0 + out_y * oidx_stride1] = 0;\n ofactors[out_x * 
ofac_stride0 + out_y * ofac_stride1] = 1.0;\n }\n }\n }\n '}[source]#
- c_fw_code: str = '\n assert_cmp(n_inputs, ==, 6);\n assert_cmp(n_outputs, ==, 3);\n Ndarray* input = inputs[0];\n Ndarray* index = inputs[1];\n Ndarray* chunk_params = inputs[5];\n Ndarray* output = *outputs[0];\n Ndarray* oindex = *outputs[1];\n Ndarray* ofactors = *outputs[2];\n\n assert_cmp(Ndarray_NDIM(input), ==, 3);\n assert_cmp(Ndarray_NDIM(index), ==, 2);\n assert_cmp(Ndarray_DIMS(input)[0], ==, Ndarray_DIMS(index)[0]);\n assert_cmp(Ndarray_DIMS(input)[1], ==, Ndarray_DIMS(index)[1]);\n assert_cmp(Ndarray_NDIM(chunk_params), ==, 1);\n assert_cmp(Ndarray_DIMS(chunk_params)[0], ==, 2);\n assert_cmp(Ndarray_NDIM(output), ==, 3);\n assert_cmp(Ndarray_NDIM(oindex), ==, 2);\n assert_cmp(Ndarray_NDIM(ofactors), ==, 2);\n assert_cmp(Ndarray_DIMS(output)[0], ==, Ndarray_DIMS(oindex)[0]);\n assert_cmp(Ndarray_DIMS(output)[1], ==, Ndarray_DIMS(oindex)[1]);\n assert_cmp(Ndarray_DIMS(output)[2], ==, Ndarray_DIMS(input)[2]);\n assert_cmp(Ndarray_DIMS(oindex)[0], ==, Ndarray_DIMS(ofactors)[0]);\n assert_cmp(Ndarray_DIMS(oindex)[1], ==, Ndarray_DIMS(ofactors)[1]);\n\n start_dev_kernel(unchunk_kernel, (\n Ndarray_DEV_DATA(chunk_params),\n Ndarray_DEV_DATA(input),\n Ndarray_DIMS(input)[0],\n Ndarray_DIMS(input)[1],\n Ndarray_DIMS(input)[2],\n Ndarray_STRIDE(input, 0),\n Ndarray_STRIDE(input, 1),\n Ndarray_STRIDE(input, 2),\n Ndarray_DEV_DATA(index),\n Ndarray_STRIDE(index, 0),\n Ndarray_STRIDE(index, 1),\n Ndarray_DEV_DATA(output),\n Ndarray_DIMS(output)[0],\n Ndarray_DIMS(output)[1],\n Ndarray_STRIDE(output, 0),\n Ndarray_STRIDE(output, 1),\n Ndarray_STRIDE(output, 2),\n Ndarray_DEV_DATA(oindex),\n Ndarray_STRIDE(oindex, 0),\n Ndarray_STRIDE(oindex, 1),\n Ndarray_DEV_DATA(ofactors),\n Ndarray_STRIDE(ofactors, 0),\n Ndarray_STRIDE(ofactors, 1)\n ));\n HANDLE_LAST_ERROR();\n '[source]#
- class returnn.native_op.SubtensorBatchedIndex[source]#
- Suppose you have:
idx: 2d (n_time, n_batch) -> idx (in [0..n_dim-1])
x: 3d (n_time, n_batch, n_dim)
- Then, this op will calculate:
x[…, idx[…]]: 2d (n_time, n_batch)
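In NumPy terms, the forward pass is a batched gather along the last axis; a sketch of the semantics, not the actual implementation:

    import numpy as np

    n_time, n_batch, n_dim = 3, 2, 5
    x = np.random.randn(n_time, n_batch, n_dim).astype("float32")
    idx = np.random.randint(0, n_dim, size=(n_time, n_batch))

    # y[t, b] = x[t, b, idx[t, b]]
    y = np.take_along_axis(x, idx[:, :, None], axis=2)[:, :, 0]
    assert y.shape == (n_time, n_batch)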
- in_info: Tuple[Dict[str]] = ({'bw_in_var': {'want_inplace': 0}, 'name': 'x', 'ndim': 3, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'idx', 'ndim': 2, 'shape': (None, None)})[source]#
- c_extra_support_code: Dict[str, str] = {'select_bw_kernel': '\n DEF_KERNEL\n void select_bw_kernel(\n float* Dx, long Dx_dim0, long Dx_dim1, long Dx_dim2, long Dx_stride0, long Dx_stride1, long Dx_stride2,\n float* index, long idx_stride0, long idx_stride1,\n float* Dy, long Dy_stride0, long Dy_stride1\n ) {\n const long max_idx = Dx_dim0 * Dx_dim1;\n for(\n long idx = threadIdx.x + blockDim.x * blockIdx.x;\n idx < max_idx;\n idx += gridDim.x * blockDim.x)\n {\n long d0 = idx % Dx_dim0;\n long d1 = idx / Dx_dim0;\n long d2 = long(index[d0 * idx_stride0 + d1 * idx_stride1]);\n if(d2 < 0) d2 = 0;\n if(d2 >= Dx_dim2) d2 = Dx_dim2 - 1;\n Dx[d0 * Dx_stride0 + d1 * Dx_stride1 + d2 * Dx_stride2] = Dy[d0 * Dy_stride0 + d1 * Dy_stride1];\n }\n }\n ', 'select_kernel': '\n DEF_KERNEL\n void select_kernel(\n float* x, long x_dim0, long x_dim1, long x_dim2, long x_stride0, long x_stride1, long x_stride2,\n float* index, long idx_stride0, long idx_stride1,\n float* y, long y_stride0, long y_stride1\n ) {\n const long max_idx = x_dim0 * x_dim1;\n for(\n long idx = threadIdx.x + blockDim.x * blockIdx.x;\n idx < max_idx;\n idx += gridDim.x * blockDim.x)\n {\n long d0 = idx % x_dim0;\n long d1 = idx / x_dim0;\n long d2 = long(index[d0 * idx_stride0 + d1 * idx_stride1]);\n if(d2 < 0) d2 = 0;\n if(d2 >= x_dim2) d2 = x_dim2 - 1;\n y[d0 * y_stride0 + d1 * y_stride1] = x[d0 * x_stride0 + d1 * x_stride1 + d2 * x_stride2];\n }\n }\n '}[source]#
- c_fw_code: str = '\n assert_cmp(n_inputs, ==, 2);\n assert_cmp(n_outputs, ==, 1);\n Ndarray* x = inputs[0];\n Ndarray* idx = inputs[1];\n Ndarray* y = *outputs[0];\n\n assert_cmp(Ndarray_NDIM(x), ==, 3);\n assert_cmp(Ndarray_NDIM(idx), ==, 2);\n assert_cmp(Ndarray_DIMS(x)[0], ==, Ndarray_DIMS(idx)[0]);\n assert_cmp(Ndarray_DIMS(x)[1], ==, Ndarray_DIMS(idx)[1]);\n assert_cmp(Ndarray_NDIM(y), ==, 2);\n assert_cmp(Ndarray_DIMS(y)[0], ==, Ndarray_DIMS(idx)[0]);\n assert_cmp(Ndarray_DIMS(y)[1], ==, Ndarray_DIMS(idx)[1]);\n\n start_dev_kernel(select_kernel, (\n Ndarray_DEV_DATA(x),\n Ndarray_DIMS(x)[0],\n Ndarray_DIMS(x)[1],\n Ndarray_DIMS(x)[2],\n Ndarray_STRIDE(x, 0),\n Ndarray_STRIDE(x, 1),\n Ndarray_STRIDE(x, 2),\n Ndarray_DEV_DATA(idx),\n Ndarray_STRIDE(idx, 0),\n Ndarray_STRIDE(idx, 1),\n Ndarray_DEV_DATA(y),\n Ndarray_STRIDE(y, 0),\n Ndarray_STRIDE(y, 1)\n ));\n HANDLE_LAST_ERROR();\n '[source]#
- c_bw_code: str = '\n assert_cmp(n_inputs, ==, 3);\n assert_cmp(n_outputs, ==, 1);\n Ndarray* x = inputs[0];\n Ndarray* idx = inputs[1];\n Ndarray* Dy = inputs[2];\n Ndarray* Dx = *outputs[0]; // inplace on x\n\n assert_cmp(Ndarray_NDIM(x), ==, 3);\n assert_cmp(Ndarray_NDIM(idx), ==, 2);\n assert_cmp(Ndarray_DIMS(x)[0], ==, Ndarray_DIMS(idx)[0]);\n assert_cmp(Ndarray_DIMS(x)[1], ==, Ndarray_DIMS(idx)[1]);\n assert_cmp(Ndarray_NDIM(Dy), ==, 2);\n assert_cmp(Ndarray_DIMS(Dy)[0], ==, Ndarray_DIMS(idx)[0]);\n assert_cmp(Ndarray_DIMS(Dy)[1], ==, Ndarray_DIMS(idx)[1]);\n assert_cmp(Ndarray_NDIM(Dx), ==, 3);\n assert_cmp(Ndarray_DIMS(Dx)[0], ==, Ndarray_DIMS(x)[0]);\n assert_cmp(Ndarray_DIMS(Dx)[1], ==, Ndarray_DIMS(x)[1]);\n assert_cmp(Ndarray_DIMS(Dx)[2], ==, Ndarray_DIMS(x)[2]);\n\n Ndarray_set_zero(Dx);\n start_dev_kernel(select_bw_kernel, (\n Ndarray_DEV_DATA(Dx),\n Ndarray_DIMS(Dx)[0],\n Ndarray_DIMS(Dx)[1],\n Ndarray_DIMS(Dx)[2],\n Ndarray_STRIDE(Dx, 0),\n Ndarray_STRIDE(Dx, 1),\n Ndarray_STRIDE(Dx, 2),\n Ndarray_DEV_DATA(idx),\n Ndarray_STRIDE(idx, 0),\n Ndarray_STRIDE(idx, 1),\n Ndarray_DEV_DATA(Dy),\n Ndarray_STRIDE(Dy, 0),\n Ndarray_STRIDE(Dy, 1)\n ));\n HANDLE_LAST_ERROR();\n '[source]#
- class returnn.native_op.SparseToDense[source]#
Expects a sparse matrix in COOrdinate format, where W[s0[i,b],b,s1[i,b]] = weight[i,b] for all i and all batches b. Will return W (time,batch,dim).
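A NumPy reference of the accumulation (a sketch; sparse_to_dense is a hypothetical helper, and like the kernel it skips masked-out and out-of-range entries):

    import numpy as np

    def sparse_to_dense(s0, s1, weight, mask, n_time, n_dim):
        n_sparse, n_batch = s0.shape
        w = np.zeros((n_time, n_batch, n_dim), dtype="float32")
        for i in range(n_sparse):
            for b in range(n_batch):
                if mask[i, b] < 0.1:
                    continue
                t, j = int(s0[i, b]), int(s1[i, b])
                if 0 <= t < n_time and 0 <= j < n_dim:
                    w[t, b, j] += weight[i, b]
        return w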
- in_info: Tuple[Dict[str]] = ({'name': '_initial_W', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None), 'want_inplace': 0}, {'name': 's0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 's1', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'weight', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'mask', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'W', 'ndim': 3, 'shape': ((0, 0), (0, 1), (0, 2))},)[source]#
- c_extra_support_code: Dict[str, str] = {'assign_kernel': '\n DEF_KERNEL\n void assign_kernel(\n float* out, float* s0, float* s1, float* w, float* mask,\n long n_sparse_idx, long n_time, long n_batch, long n_dim)\n {\n long max_idx = n_batch * n_sparse_idx;\n for(\n long idx = threadIdx.x + blockDim.x * blockIdx.x;\n idx < max_idx;\n idx += gridDim.x * blockDim.x)\n {\n if(mask[idx] < 0.1) continue;\n long batch = idx % n_batch;\n long t = (long) s0[idx];\n long j = (long) s1[idx];\n float y = w[idx];\n if(t < 0 || t >= n_time) continue; // error somehow?\n if(j < 0 || j >= n_dim) continue; // error somehow?\n long out_idx = t * n_batch * n_dim + batch * n_dim + j;\n out[out_idx] += y;\n }\n }\n '}[source]#
- c_fw_code: str = '\n assert(n_inputs == 5);\n assert(n_outputs == 1);\n Ndarray* s0 = inputs[1];\n Ndarray* s1 = inputs[2];\n Ndarray* weight = inputs[3];\n Ndarray* mask = inputs[4];\n Ndarray* out_W = *outputs[0];\n\n assert(Ndarray_NDIM(s0) == 2);\n assert(Ndarray_NDIM(s1) == 2);\n assert(Ndarray_NDIM(weight) == 2);\n assert(Ndarray_NDIM(mask) == 2);\n assert(Ndarray_NDIM(out_W) == 3);\n int n_sparse_idx = Ndarray_DIMS(s0)[0];\n assert(n_sparse_idx == Ndarray_DIMS(s1)[0]);\n assert(n_sparse_idx == Ndarray_DIMS(weight)[0]);\n assert(n_sparse_idx == Ndarray_DIMS(mask)[0]);\n int n_batch = Ndarray_DIMS(s0)[1];\n assert(n_batch == Ndarray_DIMS(s1)[1]);\n assert(n_batch == Ndarray_DIMS(weight)[1]);\n assert(n_batch == Ndarray_DIMS(mask)[1]);\n assert(n_batch == Ndarray_DIMS(out_W)[1]);\n int n_time = Ndarray_DIMS(out_W)[0];\n int n_dim = Ndarray_DIMS(out_W)[2];\n\n start_dev_kernel(assign_kernel, (\n Ndarray_DEV_DATA(out_W),\n Ndarray_DEV_DATA(s0),\n Ndarray_DEV_DATA(s1),\n Ndarray_DEV_DATA(weight),\n Ndarray_DEV_DATA(mask),\n n_sparse_idx, n_time, n_batch, n_dim\n ));\n HANDLE_LAST_ERROR();\n '[source]#
- class returnn.native_op.MaxAndArgmaxSparse[source]#
Expects a sparse matrix in COOrdinate format, where W[s0[i,b],s1[i,b],b] = weight[i,b] for all i and all batches b. It will return the max and argmax for each W[:,:,b] over the second axis.
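The kernel amounts to a running max per (t, b) pair; a NumPy sketch (note that the real op updates the provided _out_max/_out_arg buffers in place, whereas this sketch starts from -inf):

    import numpy as np

    def max_and_argmax_sparse(s0, s1, weight, mask, n_out_time):
        n_in_time, n_batch = s0.shape
        out_max = np.full((n_out_time, n_batch), -np.inf, dtype="float32")
        out_arg = np.zeros((n_out_time, n_batch), dtype="float32")
        for i in range(n_in_time):
            for b in range(n_batch):
                if mask[i, b] < 0.1:
                    continue
                t = int(s0[i, b])
                if not 0 <= t < n_out_time:
                    continue
                if weight[i, b] > out_max[t, b]:
                    out_max[t, b] = weight[i, b]
                    out_arg[t, b] = float(s1[i, b])
        return out_max, out_arg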
- in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 's0', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 's1', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'weight', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'mask', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': '_out_max', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None), 'want_inplace': 0}, {'gradient': 'disconnected', 'name': '_out_arg', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None), 'want_inplace': 1})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'out_max', 'ndim': 2, 'shape': ((4, 0), (4, 1))}, {'name': 'out_arg', 'ndim': 2, 'shape': ((5, 0), (5, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'doit_kernel': '\n DEF_KERNEL\n void doit_kernel(\n long n_batch, long n_in_time, long n_out_time,\n float* s0, float* s1, float* weight, float* mask,\n float* out_max, float* out_arg) {\n long batch_idx = threadIdx.x + blockDim.x * blockIdx.x;\n while(batch_idx < n_batch) {\n for(long i = 0; i < n_in_time; ++i) {\n long idx = i * n_batch + batch_idx;\n if(mask[idx] < 0.1) continue;\n long t = (long) s0[idx];\n long j = (long) s1[idx];\n float w = weight[idx];\n if(t < 0 || t >= n_out_time) continue; // error somehow?\n long out_idx = t * n_batch + batch_idx;\n if(w > out_max[out_idx]) {\n out_max[out_idx] = w;\n out_arg[out_idx] = (float) j;\n }\n }\n batch_idx += gridDim.x * blockDim.x;\n }\n }\n '}[source]#
- c_fw_code: str = '\n assert(n_inputs == 6);\n assert(n_outputs == 2);\n Ndarray* s0 = inputs[0];\n Ndarray* s1 = inputs[1];\n Ndarray* weight = inputs[2];\n Ndarray* mask = inputs[3];\n Ndarray* out_max = *outputs[0];\n Ndarray* out_arg = *outputs[1];\n\n assert(Ndarray_NDIM(s0) == 2);\n assert(Ndarray_NDIM(s1) == 2);\n assert(Ndarray_NDIM(weight) == 2);\n assert(Ndarray_NDIM(mask) == 2);\n assert(Ndarray_NDIM(out_max) == 2);\n assert(Ndarray_NDIM(out_arg) == 2);\n int n_in_time = Ndarray_DIMS(s0)[0];\n assert(n_in_time == Ndarray_DIMS(s1)[0]);\n assert(n_in_time == Ndarray_DIMS(weight)[0]);\n assert(n_in_time == Ndarray_DIMS(mask)[0]);\n int n_batch = Ndarray_DIMS(s0)[1];\n assert(n_batch == Ndarray_DIMS(s1)[1]);\n assert(n_batch == Ndarray_DIMS(weight)[1]);\n assert(n_batch == Ndarray_DIMS(mask)[1]);\n assert(n_batch == Ndarray_DIMS(out_arg)[1]);\n assert(n_batch == Ndarray_DIMS(out_max)[1]);\n int n_out_time = Ndarray_DIMS(out_arg)[0];\n assert(n_out_time == Ndarray_DIMS(out_max)[0]);\n assert(out_max != out_arg); // earlier bug in NativeOp\n\n start_dev_kernel(doit_kernel, (\n n_batch, n_in_time, n_out_time,\n Ndarray_DEV_DATA(s0),\n Ndarray_DEV_DATA(s1),\n Ndarray_DEV_DATA(weight),\n Ndarray_DEV_DATA(mask),\n Ndarray_DEV_DATA(out_max),\n Ndarray_DEV_DATA(out_arg)\n ));\n HANDLE_LAST_ERROR();\n '[source]#
- class returnn.native_op.CrossEntropySoftmaxAndGradientZSparse[source]#
y_target is given in sparse COOrdinate format. We calculate CE[t,b] = -sum_i y_target[t,b,i] * log(softmax(z[t,b])[i]) for every time frame t and batch b, and grad(CE[t,b], z[t,b]) = softmax(z[t,b]) - y_target[t,b]. We also support an index-mask for z, i.e. for the valid [t,b] positions.
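For a single (t, b), the dense equivalent of what the kernels compute is the following (a sketch with the same max-shift and clipping as the max/softmax kernels):

    import numpy as np

    def ce_and_grad(z, y_target):
        # z, y_target: 1d (n_dim); CE = -sum_i y_i * log(softmax(z)_i), grad = softmax(z) - y
        z = z - z.max()                      # max-shift, as in max_kernel
        sm = np.exp(z) / max(np.exp(z).sum(), 1e-16)
        ce = -np.sum(y_target * np.log(np.maximum(sm, 1e-30)))
        grad = sm - y_target
        return ce, grad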
- in_info: Tuple[Dict[str]] = ({'name': 'z', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'name': 'z_mask', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y_target_t', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y_target_i', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y_target_w', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'name': 'y_target_mask', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'out_ce', 'ndim': 2, 'shape': ((0, 0), (0, 1))}, {'name': 'out_grad_z', 'ndim': 3, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': '_out_max_z', 'ndim': 2, 'shape': ((0, 0), (0, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'ce_sm_grad_kernel': '\n DEF_KERNEL\n void ce_sm_grad_kernel(\n float* out_ce, float* out_grad_z,\n float* z, float* max_z, float* z_mask,\n float* s0, float* s1, float* w, float* s_mask,\n long n_time, long n_batch, long n_dim, long n_sparse_index)\n {\n long max_idx = n_batch * n_sparse_index;\n for(\n long idx = threadIdx.x + blockDim.x * blockIdx.x;\n idx < max_idx;\n idx += gridDim.x * blockDim.x)\n {\n if(s_mask[idx] < 0.1) continue;\n long batch = idx % n_batch;\n long t = (long) s0[idx];\n long j = (long) s1[idx];\n float y_target = w[idx];\n if(t < 0 || t >= n_time) continue; // error somehow?\n if(j < 0 || j >= n_dim) continue; // error somehow?\n long out_ce_idx = t * n_batch + batch;\n long out_y_idx = t * n_batch * n_dim + batch * n_dim + j;\n // This assumes that out_grad_z is still softmax(z).\n // This also assumes that every [t,j] is only represented once in the sparse data.\n out_ce[out_ce_idx] -= y_target * log(fmax(out_grad_z[out_y_idx], 1e-30f));\n out_grad_z[out_y_idx] -= y_target;\n }\n }\n ', 'max_kernel': '\n DEF_KERNEL\n void max_kernel(float* out, float* v, float* mask, long stride, long max_idx) {\n for(\n long idx = threadIdx.x + blockDim.x * blockIdx.x;\n idx < max_idx;\n idx += gridDim.x * blockDim.x)\n {\n if(mask[idx] < 0.1)\n continue;\n long start = idx * stride;\n float last_max = v[start];\n out[idx] = last_max;\n for(long i = 1; i < stride; ++i) {\n float cur = v[start + i];\n if(cur > last_max) {\n last_max = cur;\n out[idx] = cur;\n }\n }\n }\n }\n ', 'softmax_kernel': '\n DEF_KERNEL\n void softmax_kernel(\n float* out_softmax,\n float* z, float* max_z, float* mask,\n long stride, long max_idx)\n {\n for(\n long idx = threadIdx.x + blockDim.x * blockIdx.x;\n idx < max_idx;\n idx += gridDim.x * blockDim.x)\n {\n long start = idx * stride;\n float s = 0;\n for(long i = 0; i < stride; ++i) {\n s += exp(z[start + i] - max_z[idx]);\n }\n if(s < 1e-16) s = 1e-16;\n for(long i = 0; i < stride; ++i) {\n float y = exp(z[start + i] - max_z[idx]) / s;\n out_softmax[start + i] = (mask[idx] > 0.5) ? y : 0;\n }\n }\n }\n '}[source]#
- c_fw_code: str = '\n assert(n_inputs == 6);\n assert(n_outputs == 3);\n Ndarray* z = inputs[0];\n Ndarray* z_mask = inputs[1];\n Ndarray* s0 = inputs[2];\n Ndarray* s1 = inputs[3];\n Ndarray* w = inputs[4];\n Ndarray* s_mask = inputs[5];\n Ndarray* out_ce = *outputs[0];\n Ndarray* out_grad_z = *outputs[1];\n Ndarray* out_max_z = *outputs[2];\n\n assert(Ndarray_NDIM(z) == 3);\n assert(Ndarray_NDIM(z_mask) == 2);\n assert(Ndarray_NDIM(out_ce) == 2);\n assert(Ndarray_NDIM(out_grad_z) == 3);\n assert(Ndarray_NDIM(out_max_z) == 2);\n assert(Ndarray_NDIM(s0) == 2);\n assert(Ndarray_NDIM(s1) == 2);\n assert(Ndarray_NDIM(w) == 2);\n assert(Ndarray_NDIM(out_ce) == 2);\n int n_time = Ndarray_DIMS(z)[0];\n int n_batch = Ndarray_DIMS(z)[1];\n int n_dim = Ndarray_DIMS(z)[2];\n assert(n_time == Ndarray_DIMS(z_mask)[0]);\n assert(n_time == Ndarray_DIMS(out_ce)[0]);\n assert(n_time == Ndarray_DIMS(out_grad_z)[0]);\n assert(n_time == Ndarray_DIMS(out_max_z)[0]);\n assert(n_batch == Ndarray_DIMS(z_mask)[1]);\n assert(n_batch == Ndarray_DIMS(out_ce)[1]);\n assert(n_batch == Ndarray_DIMS(out_grad_z)[1]);\n assert(n_batch == Ndarray_DIMS(out_max_z)[1]);\n assert(n_batch == Ndarray_DIMS(s0)[1]);\n assert(n_batch == Ndarray_DIMS(s1)[1]);\n assert(n_batch == Ndarray_DIMS(w)[1]);\n assert(n_batch == Ndarray_DIMS(s_mask)[1]);\n assert(n_dim == Ndarray_DIMS(out_grad_z)[2]);\n int n_sparse_index = Ndarray_DIMS(s0)[0];\n assert(n_sparse_index == Ndarray_DIMS(s1)[0]);\n assert(n_sparse_index == Ndarray_DIMS(w)[0]);\n assert(n_sparse_index == Ndarray_DIMS(s_mask)[0]);\n\n start_dev_kernel(max_kernel, (\n Ndarray_DEV_DATA(out_max_z), Ndarray_DEV_DATA(z), Ndarray_DEV_DATA(z_mask),\n n_dim, n_time * n_batch\n ));\n HANDLE_LAST_ERROR();\n Ndarray_set_zero(out_ce);\n start_dev_kernel(softmax_kernel, (\n Ndarray_DEV_DATA(out_grad_z),\n Ndarray_DEV_DATA(z), Ndarray_DEV_DATA(out_max_z), Ndarray_DEV_DATA(z_mask),\n n_dim, n_time * n_batch\n ));\n HANDLE_LAST_ERROR();\n start_dev_kernel(ce_sm_grad_kernel, (\n Ndarray_DEV_DATA(out_ce), Ndarray_DEV_DATA(out_grad_z),\n Ndarray_DEV_DATA(z), Ndarray_DEV_DATA(out_max_z), Ndarray_DEV_DATA(z_mask),\n Ndarray_DEV_DATA(s0), Ndarray_DEV_DATA(s1), Ndarray_DEV_DATA(w), Ndarray_DEV_DATA(s_mask),\n n_time, n_batch, n_dim, n_sparse_index\n ));\n HANDLE_LAST_ERROR();\n '[source]#
- class returnn.native_op.FastBaumWelchOp[source]#
- inputs:
- param am_scores:
scores in -log space. 3d (time,batch,dim)
- param edges:
edges of the graph (from,to,emission_idx,sequence_idx)
- param weights:
weights of the edges
- outputs:
- param output:
Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores
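Since all scores live in -log space, summing probabilities turns into the stable log-sum-exp style operation that the prob_add device function below implements; a Python sketch of the same operation:

    import math

    def neg_log_add(a, b):
        # -log(exp(-a) + exp(-b)), where +inf encodes probability 0
        if math.isinf(a) and math.isinf(b):
            return math.inf
        return min(a, b) - math.log1p(math.exp(-abs(a - b)))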
- in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 'am_scores', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'start_end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (2, None)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'gradient': 'disconnected', 'name': 'state_buffer', 'ndim': 2, 'need_contiguous': True, 'shape': (2, None)})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': 'sums', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'001_set_start_states': '\n DEF_KERNEL\n void set_start_states(float* states, unsigned* start_states) {\n unsigned state_idx = start_states[blockIdx.x * blockDim.x + threadIdx.x];\n states[state_idx] = 0.0;\n }\n ', '010_fill_array': '\n DEF_KERNEL\n void fill_array(float* array, float value, unsigned size) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx < size) {\n array[idx] = value;\n }\n }\n ', '011_remove_inf': '\n DEF_KERNEL\n void remove_inf(float* array, unsigned size) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx < size) {\n array[idx] = fminf(array[idx], 1e32);\n }\n }\n ', '012_prob_add': '\n DEV_FUNC\n float prob_add(float a, float b) {\n float diff = a - b;\n if (isnan(diff)) {\n return INF_F;\n }\n else {\n return -log1pf(expf(-fabsf(diff))) + fminf(a, b);\n }\n }\n ', '013_atomic_prob_add': '\n DEV_FUNC\n void atomic_prob_add(float* a, float b) {\n int* addr = (int*)a;\n int old = float_as_int(*a);\n int assumed;\n do {\n assumed = old;\n old = elem_atomic_cas(addr, assumed, float_as_int(prob_add(int_as_float(old), b)));\n } while (old != assumed);\n }\n ', '020_dump_to_file': "\n template<typename T>\n void dump_to_file_1d(T* d_mem, unsigned n_d1, std::string const& path) {\n std::vector<T> buffer(n_d1);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n T val = buffer[i1];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << val << '\\n';\n }\n }\n }\n\n template<typename T>\n void dump_to_file_2d(T* d_mem, unsigned n_d1, unsigned n_d2, std::string const& path) {\n std::vector<T> buffer(n_d1 * n_d2);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n T val = buffer[i1 * n_d2 + i2];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << i2 << ' ' << val << '\\n';\n }\n }\n }\n }\n\n template<typename T>\n void dump_to_file_3d(T* d_mem, unsigned n_d1, unsigned n_d2, unsigned n_d3, std::string const& path) {\n std::vector<T> buffer(n_d1 * n_d2 * n_d3);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n for (size_t i3 = 0ul; i3 < n_d3; i3++) {\n T val = buffer[i1 * n_d2 * n_d3 + i2 * n_d3 + i3];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << i2 << ' ' << i3 << ' ' << val << '\\n';\n }\n }\n }\n }\n }\n ", '100_init_bwd_state_buffer': '\n DEF_KERNEL\n void init_bwd_state_buffer(\n float* states, unsigned* end_states, unsigned t, unsigned max_t, float* index, unsigned index_stride) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (index[t * index_stride + idx] == 1.0 && (t == max_t || index[(t + 1) * index_stride + idx] == 0.0)) {\n unsigned state_idx = end_states[idx];\n states[state_idx] = 0.0;\n }\n }\n ', '101_next_frame': '\n DEF_KERNEL\n void next_frame(bool fwd, unsigned num_edges, unsigned num_emissions,\n unsigned* sequence_idxs, unsigned* from_buffer, unsigned* to_buffer, float* weight_buffer,\n unsigned* emission_idxs,\n 
float* prev_frame, float* next_frame, float* am_scores, float* edge_buffer) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_edges) {\n return;\n }\n\n unsigned from = from_buffer [idx];\n float prev_val = prev_frame[from];\n if (isinf(prev_val)) {\n edge_buffer[idx] = INF_F;\n return;\n }\n\n unsigned to = to_buffer [idx];\n unsigned emission_idx = emission_idxs[idx];\n float edge_weight = weight_buffer[idx];\n unsigned sequence_idx = sequence_idxs[idx];\n\n float val = prev_val + edge_weight + am_scores[sequence_idx * num_emissions + emission_idx];\n\n if (fwd) {\n edge_buffer[idx] += val;\n }\n else {\n edge_buffer[idx] += prev_val;\n }\n atomic_prob_add(next_frame + to, val);\n }\n ', '102_normalize': '\n DEF_KERNEL\n void normalize(float* buffer, unsigned* sequence_idxs, unsigned num_edges, unsigned num_seqs, float* sum_output) {\n DEF_SHARED(float, sum);\n\n buffer += blockIdx.x * num_edges;\n\n for (unsigned s = 0u; s < num_seqs; s++) {\n sum[s] = INF_F;\n }\n\n for (unsigned e = 0u; e < num_edges; e++) {\n unsigned s = sequence_idxs[e];\n sum[s] = prob_add(sum[s], buffer[e]);\n }\n\n for (unsigned s = 0ul; s < num_seqs; s++) {\n if (isinf(sum[s])) {\n // if the frame is empty (happens due to batching of seqs with unequal length), set it to 0\n sum_output[blockIdx.x * num_seqs + s] = 0.0;\n }\n else {\n sum_output[blockIdx.x * num_seqs + s] = sum[s];\n }\n }\n\n for (unsigned e = 0u; e < num_edges; e++) {\n unsigned s = sequence_idxs[e];\n buffer[e] -= sum[s];\n }\n }\n ', '103_compute_result': '\n DEF_KERNEL\n void compute_result(float* edge_buffer, float* out, unsigned* emission_idxs, unsigned* sequence_idxs,\n unsigned frame_stride, unsigned seq_stride,\n unsigned num_frames, unsigned num_seqs, unsigned num_edges) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_frames * num_edges) {\n return;\n }\n\n unsigned e_idx = idx % num_edges;\n unsigned frame = idx / num_edges;\n unsigned emission_idx = emission_idxs[e_idx];\n unsigned seq_idx = sequence_idxs[e_idx];\n float score = edge_buffer[idx];\n\n atomic_prob_add(out + frame * frame_stride + seq_idx * seq_stride + emission_idx, score);\n }\n ', '110_write_alignment_to_file': '\n void write_alignment_to_file(float* d_state_buffer, float* d_index, unsigned index_stride,\n unsigned* d_start_states, unsigned* d_end_states,\n float pruning, unsigned n_frames, unsigned n_seqs, unsigned n_states,\n unsigned batch_idx) {\n std::vector<float> state_buffer((n_frames + 1u) * n_states);\n std::vector<float> index (n_frames * index_stride);\n std::vector<unsigned> start_states(n_seqs);\n std::vector<unsigned> end_states (n_seqs);\n\n //HANDLE_ERROR(cudaMemcpy(\n // state_buffer.data(), d_state_buffer, state_buffer.size() * sizeof(float), cudaMemcpyDeviceToHost));\n //HANDLE_ERROR(cudaMemcpy(\n // index.data(), d_index, index.size() * sizeof(float), cudaMemcpyDeviceToHost));\n //HANDLE_ERROR(cudaMemcpy(\n // start_states.data(), d_start_states, start_states.size() * sizeof(float), cudaMemcpyDeviceToHost));\n //HANDLE_ERROR(cudaMemcpy(\n // end_states.data(), d_end_states, end_states.size() * sizeof(float), cudaMemcpyDeviceToHost));\n\n for (unsigned seq = 0u; seq < n_seqs; seq++) {\n std::stringstream filename;\n filename << "alignment.dump." 
<< batch_idx << \'.\' << seq;\n std::ofstream out(filename.str().c_str(), std::ios::out | std::ios::trunc);\n for (unsigned t = 0u; t < n_frames; t++) {\n if (t > 0u && index[seq * index_stride + t] <= 0.0) {\n break;\n }\n float sum = std::numeric_limits<float>::infinity();\n for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {\n const float val = state_buffer[t * n_states + s];\n float diff = val - sum;\n if (!isnan(diff)) {\n sum = -log1p(exp(-abs(diff))) + fminf(sum, val);\n }\n }\n for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {\n const float val = state_buffer[t * n_states + s] - sum;\n if (val <= pruning) {\n out << t << \' \' << (s - start_states[seq]) << \' \' << val << \'\\n\';\n }\n }\n }\n }\n }\n ', '111_write_output_to_file': '\n void write_output_to_file(float* d_out, float* d_index, unsigned index_stride,\n float pruning, unsigned n_frames, unsigned n_seqs, unsigned n_emissions,\n unsigned batch_idx) {\n std::vector<float> buffer(n_frames * n_seqs * n_emissions);\n std::vector<float> index (n_frames * index_stride);\n\n //HANDLE_ERROR(cudaMemcpy(buffer.data(), d_out, buffer.size() * sizeof(float), cudaMemcpyDeviceToHost));\n //HANDLE_ERROR(cudaMemcpy(index.data(), d_index, index.size() * sizeof(float), cudaMemcpyDeviceToHost));\n\n for (unsigned seq = 0u; seq < n_seqs; seq++) {\n std::stringstream filename;\n filename << "target.dump." << batch_idx << \'.\' << seq;\n std::ofstream out(filename.str().c_str(), std::ios::out | std::ios::trunc);\n for (unsigned t = 0u; t < n_frames; t++) {\n if (t > 0u && index[seq * index_stride + t] <= 0.0) {\n break;\n }\n for (unsigned e = 0u; e < n_emissions; e++) {\n const float val = buffer[t * n_seqs * n_emissions + seq * n_emissions + e];\n if (val <= pruning) {\n out << t << \' \' << e << \' \' << val << \'\\n\';\n }\n }\n }\n }\n }\n '}[source]#
- c_fw_code: str = '\n // am_scores, edges, weights, start_end_states, index, state_buffer* = input_names (*: inplace)\n // output = output_names\n assert(n_inputs == 6);\n assert(n_outputs == 2);\n Ndarray* am_scores = inputs[0];\n Ndarray* edges = inputs[1];\n Ndarray* weights = inputs[2];\n Ndarray* start_end_states = inputs[3];\n Ndarray* index = inputs[4];\n Ndarray* state_buffer = inputs[5];\n Ndarray* out = *outputs[0];\n Ndarray* sum_output = *outputs[1];\n\n /*\n debug_print(context, am_scores, "am_scores");\n debug_print(context, edges, "edges");\n debug_print(context, weights, "weights");\n debug_print(context, start_end_states, "start_end_states");\n debug_print(context, index, "index");\n debug_print(context, state_buffer, "state_buffer");\n */\n\n assert_cmp(Ndarray_DIMS(am_scores)[0], ==, Ndarray_DIMS(out)[0]);\n assert_cmp(Ndarray_DIMS(am_scores)[1], ==, Ndarray_DIMS(out)[1]);\n assert_cmp(Ndarray_DIMS(am_scores)[2], ==, Ndarray_DIMS(out)[2]);\n assert_cmp(Ndarray_DIMS(am_scores)[1], ==, Ndarray_DIMS(start_end_states)[1]);\n\n assert_cmp(Ndarray_DIMS(sum_output)[0], ==, Ndarray_DIMS(am_scores)[0]);\n assert_cmp(Ndarray_DIMS(sum_output)[1], ==, Ndarray_DIMS(am_scores)[1]);\n\n bool dump_alignment = false;\n bool dump_output = false;\n unsigned dump_every = 40u;\n static unsigned batch_idx = 0u;\n float pruning = 10.f;\n\n unsigned* d_from = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n + 0 * Ndarray_STRIDE(edges, 0));\n unsigned* d_to = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n + 1 * Ndarray_STRIDE(edges, 0));\n unsigned* d_emission_idxs = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n + 2 * Ndarray_STRIDE(edges, 0));\n unsigned* d_sequence_idxs = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n + 3 * Ndarray_STRIDE(edges, 0));\n float* d_weights = Ndarray_DEV_DATA(weights);\n float* d_am_scores = Ndarray_DEV_DATA(am_scores);\n unsigned* d_start_states = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(start_end_states)\n + 0 * Ndarray_STRIDE(start_end_states, 0));\n unsigned* d_end_states = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(start_end_states)\n + 1 * Ndarray_STRIDE(start_end_states, 0));\n float* d_index = Ndarray_DEV_DATA(index);\n float* d_state_buffer_prev = Ndarray_DEV_DATA(state_buffer) + 0 * Ndarray_STRIDE(state_buffer, 0);\n float* d_state_buffer_next = Ndarray_DEV_DATA(state_buffer) + 1 * Ndarray_STRIDE(state_buffer, 0);\n float* d_out = Ndarray_DEV_DATA(out);\n float* d_sum_output = Ndarray_DEV_DATA(sum_output);\n\n unsigned n_frames = Ndarray_DIMS(am_scores)[0];\n unsigned n_seqs = Ndarray_DIMS(am_scores)[1];\n unsigned n_emissions = Ndarray_DIMS(am_scores)[2];\n unsigned n_states = Ndarray_DIMS(state_buffer)[1];\n unsigned n_edges = Ndarray_DIMS(edges)[1];\n unsigned n_threads = 1024u;\n unsigned n_blocks = (n_edges + n_threads - 1) / n_threads;\n\n unsigned frame_stride = Ndarray_STRIDE(am_scores, 0);\n unsigned sequence_stride = Ndarray_STRIDE(am_scores, 1);\n unsigned index_stride = Ndarray_STRIDE(index, 0);\n\n assert_cmp(n_frames, >, 0);\n assert_cmp(n_states, >, 0);\n //std::cerr << "n_frames: " << n_frames << std::endl;\n //std::cerr << "n_seqs: " << n_seqs << std::endl;\n //std::cerr << "n_emissions: " << n_emissions << std::endl;\n //std::cerr << "n_states: " << n_states << std::endl;\n //std::cerr << "n_edges: " << n_edges << std::endl;\n //std::cerr << "n_threads: " << n_threads << std::endl;\n //std::cerr << "n_blocks: " << n_blocks << std::endl;\n\n //std::cerr << 
"frame_stride: " << frame_stride << std::endl;\n //std::cerr << "sequnence_stride: " << sequence_stride << std::endl;\n //std::cerr << "index_stride: " << index_stride << std::endl;\n\n // initialize edge buffer\n float* d_edge_buffer = reinterpret_cast<float*>(device_malloc(n_edges * n_frames * sizeof(float)));\n if(!d_edge_buffer) { HANDLE_LAST_ERROR(); abort(); } // error should have been set in device_malloc\n unsigned n_fill_blocks = (n_edges * n_frames + n_threads - 1u) / n_threads;\n start_dev_kernel2(fill_array, n_fill_blocks, n_threads, 0, (d_edge_buffer, 0.0, n_edges * n_frames));\n HANDLE_LAST_ERROR();\n\n // initialize the state buffer\n n_fill_blocks = (n_states + n_threads - 1u) / n_threads;\n start_dev_kernel2(\n fill_array, n_fill_blocks, n_threads, 0,\n (d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states));\n HANDLE_LAST_ERROR();\n start_dev_kernel2(set_start_states, 1, n_seqs, 0, (d_state_buffer_prev, d_start_states));\n HANDLE_LAST_ERROR();\n\n // initialize full state buffer (only used to dump the alignment)\n float* d_state_buffer_all = NULL;\n if (dump_alignment && batch_idx %% dump_every == 0) {\n d_state_buffer_all = reinterpret_cast<float*>(device_malloc(n_states * (n_frames + 1u) * sizeof(float)));\n if(!d_state_buffer_all) { HANDLE_LAST_ERROR(); abort(); } // error should have been set in device_malloc\n Ndarray_memcpy(d_state_buffer_all, d_state_buffer_prev, n_states * sizeof(float));\n HANDLE_LAST_ERROR();\n }\n\n // fwd pass\n for (unsigned t = 0u; t < n_frames; t++) {\n start_dev_kernel2(\n fill_array, n_fill_blocks, n_threads, 0,\n (d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states));\n HANDLE_LAST_ERROR();\n start_dev_kernel2(next_frame, n_blocks, n_threads, 0,\n (true, n_edges, sequence_stride,\n d_sequence_idxs, d_from, d_to, d_weights, d_emission_idxs,\n d_state_buffer_prev, d_state_buffer_next, d_am_scores + t * frame_stride, d_edge_buffer + t * n_edges));\n HANDLE_LAST_ERROR();\n if (dump_alignment && batch_idx %% dump_every == 0) {\n Ndarray_memcpy(d_state_buffer_all + (t + 1u) * n_states, d_state_buffer_next, n_states * sizeof(float));\n HANDLE_LAST_ERROR();\n }\n std::swap(d_state_buffer_prev, d_state_buffer_next);\n }\n\n // bwd pass\n start_dev_kernel2(\n fill_array, n_fill_blocks, n_threads, 0,\n (d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states));\n HANDLE_LAST_ERROR();\n for (unsigned t = n_frames; t > 0; t--) {\n start_dev_kernel2(init_bwd_state_buffer, 1, n_seqs, 0,\n (d_state_buffer_prev, d_end_states, t - 1, n_frames - 1, d_index, index_stride));\n HANDLE_LAST_ERROR();\n if (dump_alignment && batch_idx %% dump_every == 0) {\n float alpha = 1.0f;\n //HANDLE_ERROR(cublasSaxpy(\n // handle, n_states, &alpha, d_state_buffer_prev, 1, d_state_buffer_all + t * n_states, 1));\n }\n start_dev_kernel2(\n fill_array, n_fill_blocks, n_threads, 0,\n (d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states));\n HANDLE_LAST_ERROR();\n start_dev_kernel2(next_frame, n_blocks, n_threads, 0,\n (false, n_edges, sequence_stride,\n d_sequence_idxs, d_to, d_from, d_weights, d_emission_idxs,\n d_state_buffer_prev, d_state_buffer_next, d_am_scores + (t - 1) * frame_stride,\n d_edge_buffer + (t - 1) * n_edges));\n HANDLE_LAST_ERROR();\n std::swap(d_state_buffer_prev, d_state_buffer_next);\n }\n if (dump_alignment && batch_idx %% dump_every == 0) {\n float alpha = 1.0f;\n //HANDLE_ERROR(cublasSaxpy(handle, n_states, &alpha, d_state_buffer_prev, 1, d_state_buffer_all, 1));\n }\n\n // normalize 
at each time frame\n start_dev_kernel2(normalize, n_frames, 1, n_seqs * sizeof(float),\n (d_edge_buffer, d_sequence_idxs, n_edges, n_seqs, d_sum_output));\n HANDLE_LAST_ERROR();\n\n // dump alignment\n if (dump_alignment && batch_idx %% dump_every == 0) {\n write_alignment_to_file(d_state_buffer_all, d_index, index_stride, d_start_states, d_end_states,\n pruning, n_frames, n_seqs, n_states, batch_idx);\n }\n\n n_fill_blocks = (n_frames * n_seqs * n_emissions + n_threads - 1u) / n_threads;\n start_dev_kernel2(\n fill_array, n_fill_blocks, n_threads, 0,\n (d_out, std::numeric_limits<float>::infinity(), n_frames * n_seqs * n_emissions));\n HANDLE_LAST_ERROR();\n\n frame_stride = Ndarray_STRIDE(out, 0);\n sequence_stride = Ndarray_STRIDE(out, 1);\n n_blocks = (n_frames * n_edges + n_threads - 1u) / n_threads;\n start_dev_kernel2(compute_result, n_blocks, n_threads, 0,\n (d_edge_buffer, d_out, d_emission_idxs, d_sequence_idxs,\n frame_stride, sequence_stride, n_frames, n_seqs, n_edges));\n HANDLE_LAST_ERROR();\n\n #if TENSORFLOW\n // Certain TensorFlow code doesn\'t like inf, even if it is just the CheckNumerics,\n // which is helpful for debugging.\n // We replace it by a very high number, so that tf.exp(-out) will still result in 0.0.\n n_blocks = (n_frames * n_seqs * n_emissions + n_threads - 1u) / n_threads;\n start_dev_kernel2(remove_inf, n_blocks, n_threads, 0, (d_out, n_frames * n_seqs * n_emissions));\n //debug_print(context, out, "out");\n #endif\n if (dump_output && batch_idx %% dump_every == 0) {\n write_output_to_file(d_out, d_index, index_stride, pruning, n_frames, n_seqs, n_emissions, batch_idx);\n }\n\n device_free(d_edge_buffer);\n if (d_state_buffer_all != NULL) {\n device_free(d_state_buffer_all);\n }\n batch_idx++;\n '[source]#
- class returnn.native_op.MultiEndFastBaumWelchOp[source]#
- inputs:
- param am_scores:
scores in -log space. 3d (time,batch,dim)
- param edges:
edges of the graph (from,to,emission_idx,sequence_idx)
- param weights:
weights of the edges
- outputs:
- param output:
Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores
- in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 'am_scores', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'start_states', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (None, 2)}, {'gradient': 'disconnected', 'name': 'end_state_weights', 'ndim': 1, 'need_contiguous': True, 'shape': ((4, 0),)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'gradient': 'disconnected', 'name': 'state_buffer', 'ndim': 2, 'need_contiguous': True, 'shape': (2, None)})[source]#
- out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': 'sums', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'001_set_start_states': '\n DEF_KERNEL\n void set_start_states(float* states, unsigned* start_states) {\n unsigned state_idx = start_states[blockIdx.x * blockDim.x + threadIdx.x];\n states[state_idx] = 0.0;\n }\n ', '010_fill_array': '\n DEF_KERNEL\n void fill_array(float* array, float value, unsigned size) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx < size) {\n array[idx] = value;\n }\n }\n ', '011_remove_inf': '\n DEF_KERNEL\n void remove_inf(float* array, unsigned size) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx < size) {\n array[idx] = fminf(array[idx], 1e32);\n }\n }\n ', '012_prob_add': '\n DEV_FUNC\n float prob_add(float a, float b) {\n float diff = a - b;\n if (isnan(diff)) {\n return INF_F;\n }\n else {\n return -log1pf(expf(-fabsf(diff))) + fminf(a, b);\n }\n }\n ', '013_atomic_prob_add': '\n DEV_FUNC\n void atomic_prob_add(float* a, float b) {\n int* addr = (int*)a;\n int old = float_as_int(*a);\n int assumed;\n do {\n assumed = old;\n old = elem_atomic_cas(addr, assumed, float_as_int(prob_add(int_as_float(old), b)));\n } while (old != assumed);\n }\n ', '020_dump_to_file': "\n template<typename T>\n void dump_to_file_1d(T* d_mem, unsigned n_d1, std::string const& path) {\n std::vector<T> buffer(n_d1);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n T val = buffer[i1];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << val << '\\n';\n }\n }\n }\n\n template<typename T>\n void dump_to_file_2d(T* d_mem, unsigned n_d1, unsigned n_d2, std::string const& path) {\n std::vector<T> buffer(n_d1 * n_d2);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n T val = buffer[i1 * n_d2 + i2];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << i2 << ' ' << val << '\\n';\n }\n }\n }\n }\n\n template<typename T>\n void dump_to_file_3d(T* d_mem, unsigned n_d1, unsigned n_d2, unsigned n_d3, std::string const& path) {\n std::vector<T> buffer(n_d1 * n_d2 * n_d3);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n for (size_t i3 = 0ul; i3 < n_d3; i3++) {\n T val = buffer[i1 * n_d2 * n_d3 + i2 * n_d3 + i3];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << i2 << ' ' << i3 << ' ' << val << '\\n';\n }\n }\n }\n }\n }\n ", '100_init_bwd_state_buffer': '\n __global__\n void init_bwd_state_buffer(unsigned t, unsigned max_t, unsigned num_endstates, unsigned index_stride,\n float* states, unsigned const* end_states, float const* end_state_weights,\n float const* index) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_endstates) {\n return;\n }\n\n unsigned seq_idx = end_states[idx * 2u + 0u];\n if (index[t * index_stride + seq_idx] == 1.0\n && (t == max_t || index[(t + 1) * index_stride + seq_idx] == 0.0)) {\n unsigned state_idx = end_states[idx * 2u + 1u];\n float weight = end_state_weights[idx];\n states[state_idx] = weight;\n }\n }\n ', 
'101_next_frame': '\n DEF_KERNEL\n void next_frame(bool fwd, unsigned num_edges, unsigned num_emissions,\n unsigned* sequence_idxs, unsigned* from_buffer, unsigned* to_buffer, float* weight_buffer,\n unsigned* emission_idxs,\n float* prev_frame, float* next_frame, float* am_scores, float* edge_buffer) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_edges) {\n return;\n }\n\n unsigned from = from_buffer [idx];\n float prev_val = prev_frame[from];\n if (isinf(prev_val)) {\n edge_buffer[idx] = INF_F;\n return;\n }\n\n unsigned to = to_buffer [idx];\n unsigned emission_idx = emission_idxs[idx];\n float edge_weight = weight_buffer[idx];\n unsigned sequence_idx = sequence_idxs[idx];\n\n float val = prev_val + edge_weight + am_scores[sequence_idx * num_emissions + emission_idx];\n\n if (fwd) {\n edge_buffer[idx] += val;\n }\n else {\n edge_buffer[idx] += prev_val;\n }\n atomic_prob_add(next_frame + to, val);\n }\n ', '102_normalize': '\n DEF_KERNEL\n void normalize(float* buffer, unsigned* sequence_idxs, unsigned num_edges, unsigned num_seqs, float* sum_output) {\n DEF_SHARED(float, sum);\n\n buffer += blockIdx.x * num_edges;\n\n for (unsigned s = 0u; s < num_seqs; s++) {\n sum[s] = INF_F;\n }\n\n for (unsigned e = 0u; e < num_edges; e++) {\n unsigned s = sequence_idxs[e];\n sum[s] = prob_add(sum[s], buffer[e]);\n }\n\n for (unsigned s = 0ul; s < num_seqs; s++) {\n if (isinf(sum[s])) {\n // if the frame is empty (happens due to batching of seqs with unequal length), set it to 0\n sum_output[blockIdx.x * num_seqs + s] = 0.0;\n }\n else {\n sum_output[blockIdx.x * num_seqs + s] = sum[s];\n }\n }\n\n for (unsigned e = 0u; e < num_edges; e++) {\n unsigned s = sequence_idxs[e];\n buffer[e] -= sum[s];\n }\n }\n ', '103_compute_result': '\n DEF_KERNEL\n void compute_result(float* edge_buffer, float* out, unsigned* emission_idxs, unsigned* sequence_idxs,\n unsigned frame_stride, unsigned seq_stride,\n unsigned num_frames, unsigned num_seqs, unsigned num_edges) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_frames * num_edges) {\n return;\n }\n\n unsigned e_idx = idx % num_edges;\n unsigned frame = idx / num_edges;\n unsigned emission_idx = emission_idxs[e_idx];\n unsigned seq_idx = sequence_idxs[e_idx];\n float score = edge_buffer[idx];\n\n atomic_prob_add(out + frame * frame_stride + seq_idx * seq_stride + emission_idx, score);\n }\n ', '110_write_alignment_to_file': '\n void write_alignment_to_file(float* d_state_buffer, float* d_index, unsigned index_stride,\n unsigned* d_start_states, unsigned* d_end_states,\n float pruning, unsigned n_frames, unsigned n_seqs, unsigned n_states,\n unsigned batch_idx) {\n std::vector<float> state_buffer((n_frames + 1u) * n_states);\n std::vector<float> index (n_frames * index_stride);\n std::vector<unsigned> start_states(n_seqs);\n std::vector<unsigned> end_states (n_seqs);\n\n //HANDLE_ERROR(cudaMemcpy(\n // state_buffer.data(), d_state_buffer, state_buffer.size() * sizeof(float), cudaMemcpyDeviceToHost));\n //HANDLE_ERROR(cudaMemcpy(\n // index.data(), d_index, index.size() * sizeof(float), cudaMemcpyDeviceToHost));\n //HANDLE_ERROR(cudaMemcpy(\n // start_states.data(), d_start_states, start_states.size() * sizeof(float), cudaMemcpyDeviceToHost));\n //HANDLE_ERROR(cudaMemcpy(\n // end_states.data(), d_end_states, end_states.size() * sizeof(float), cudaMemcpyDeviceToHost));\n\n for (unsigned seq = 0u; seq < n_seqs; seq++) {\n std::stringstream filename;\n filename << "alignment.dump." 
<< batch_idx << \'.\' << seq;\n std::ofstream out(filename.str().c_str(), std::ios::out | std::ios::trunc);\n for (unsigned t = 0u; t < n_frames; t++) {\n if (t > 0u && index[seq * index_stride + t] <= 0.0) {\n break;\n }\n float sum = std::numeric_limits<float>::infinity();\n for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {\n const float val = state_buffer[t * n_states + s];\n float diff = val - sum;\n if (!isnan(diff)) {\n sum = -log1p(exp(-abs(diff))) + fminf(sum, val);\n }\n }\n for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {\n const float val = state_buffer[t * n_states + s] - sum;\n if (val <= pruning) {\n out << t << \' \' << (s - start_states[seq]) << \' \' << val << \'\\n\';\n }\n }\n }\n }\n }\n ', '111_write_output_to_file': '\n void write_output_to_file(float* d_out, float* d_index, unsigned index_stride,\n float pruning, unsigned n_frames, unsigned n_seqs, unsigned n_emissions,\n unsigned batch_idx) {\n std::vector<float> buffer(n_frames * n_seqs * n_emissions);\n std::vector<float> index (n_frames * index_stride);\n\n //HANDLE_ERROR(cudaMemcpy(buffer.data(), d_out, buffer.size() * sizeof(float), cudaMemcpyDeviceToHost));\n //HANDLE_ERROR(cudaMemcpy(index.data(), d_index, index.size() * sizeof(float), cudaMemcpyDeviceToHost));\n\n for (unsigned seq = 0u; seq < n_seqs; seq++) {\n std::stringstream filename;\n filename << "target.dump." << batch_idx << \'.\' << seq;\n std::ofstream out(filename.str().c_str(), std::ios::out | std::ios::trunc);\n for (unsigned t = 0u; t < n_frames; t++) {\n if (t > 0u && index[seq * index_stride + t] <= 0.0) {\n break;\n }\n for (unsigned e = 0u; e < n_emissions; e++) {\n const float val = buffer[t * n_seqs * n_emissions + seq * n_emissions + e];\n if (val <= pruning) {\n out << t << \' \' << e << \' \' << val << \'\\n\';\n }\n }\n }\n }\n }\n '}[source]#
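The prob_add and atomic_prob_add helpers above add two probabilities that are stored as -log scores, i.e. they compute -log(exp(-a) + exp(-b)) in a numerically stable way (two infinite scores stay infinite). A pure-Python sketch of the same arithmetic, for reference only:

    import math

    def prob_add(a, b):
        # add two probabilities given as -log scores: -log(exp(-a) + exp(-b))
        diff = a - b
        if math.isnan(diff):       # both scores are +inf, i.e. probability 0
            return math.inf
        return -math.log1p(math.exp(-abs(diff))) + min(a, b)

    # adding two equal probabilities lowers the -log score by log(2)
    assert abs(prob_add(2.0, 2.0) - (2.0 - math.log(2.0))) < 1e-12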
- c_fw_code: str = '\n // am_scores, edges, weights, start_states, end_states, end_state_weights,\n // index, state_buffer* = input_names (*: inplace)\n // output = output_names\n assert(n_inputs == 8);\n assert(n_outputs == 2);\n Ndarray* am_scores = inputs[0];\n Ndarray* edges = inputs[1];\n Ndarray* weights = inputs[2];\n Ndarray* start_states = inputs[3];\n Ndarray* end_states = inputs[4];\n Ndarray* end_state_weights = inputs[5];\n Ndarray* index = inputs[6];\n Ndarray* state_buffer = inputs[7];\n Ndarray* out = *outputs[0];\n Ndarray* sum_output = *outputs[1];\n\n assert(Ndarray_DIMS(am_scores)[0] == Ndarray_DIMS(out)[0]);\n assert(Ndarray_DIMS(am_scores)[1] == Ndarray_DIMS(out)[1]);\n assert(Ndarray_DIMS(am_scores)[2] == Ndarray_DIMS(out)[2]);\n// assert(Ndarray_DIMS(am_scores)[1] == Ndarray_DIMS(end_states)[0]);\n\n assert(Ndarray_DIMS(sum_output)[0] == Ndarray_DIMS(am_scores)[0]);\n assert(Ndarray_DIMS(sum_output)[1] == Ndarray_DIMS(am_scores)[1]);\n\n bool dump_alignment = false;\n bool dump_output = false;\n unsigned dump_every = 40u;\n static unsigned batch_idx = 0u;\n float pruning = 10.f;\n\n unsigned* d_from = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n + 0 * Ndarray_STRIDE(edges, 0));\n unsigned* d_to = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n + 1 * Ndarray_STRIDE(edges, 0));\n unsigned* d_emission_idxs = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n + 2 * Ndarray_STRIDE(edges, 0));\n unsigned* d_sequence_idxs = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(edges)\n + 3 * Ndarray_STRIDE(edges, 0));\n float* d_weights = Ndarray_DEV_DATA(weights);\n float* d_am_scores = Ndarray_DEV_DATA(am_scores);\n unsigned* d_start_states = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(start_states));\n unsigned* d_end_states = reinterpret_cast<unsigned*>(Ndarray_DEV_DATA_int32(end_states));\n float* d_end_state_weights = Ndarray_DEV_DATA(end_state_weights);\n float* d_index = Ndarray_DEV_DATA(index);\n float* d_state_buffer_prev = Ndarray_DEV_DATA(state_buffer) + 0 * Ndarray_STRIDE(state_buffer, 0);\n float* d_state_buffer_next = Ndarray_DEV_DATA(state_buffer) + 1 * Ndarray_STRIDE(state_buffer, 0);\n float* d_out = Ndarray_DEV_DATA(out);\n float* d_sum_output = Ndarray_DEV_DATA(sum_output);\n\n unsigned n_frames = Ndarray_DIMS(am_scores)[0];\n unsigned n_seqs = Ndarray_DIMS(am_scores)[1];\n unsigned n_emissions = Ndarray_DIMS(am_scores)[2];\n unsigned n_states = Ndarray_DIMS(state_buffer)[1];\n unsigned n_edges = Ndarray_DIMS(edges)[1];\n unsigned n_start_states = Ndarray_DIMS(start_states)[0];\n unsigned n_end_states = Ndarray_DIMS(end_states)[0];\n unsigned n_threads = 1024u;\n unsigned n_blocks = (n_edges + n_threads - 1) / n_threads;\n\n unsigned frame_stride = Ndarray_STRIDE(am_scores, 0);\n unsigned sequence_stride = Ndarray_STRIDE(am_scores, 1);\n unsigned index_stride = Ndarray_STRIDE(index, 0);\n\n assert(n_frames > 0);\n\n// std::cerr << "n_frames: " << n_frames << std::endl;\n// std::cerr << "n_seqs: " << n_seqs << std::endl;\n// std::cerr << "n_emissions: " << n_emissions << std::endl;\n// std::cerr << "n_states: " << n_states << std::endl;\n// std::cerr << "n_edges: " << n_edges << std::endl;\n// std::cerr << "n_start_states: " << n_start_states << std::endl;\n// std::cerr << "n_end_states: " << n_end_states << std::endl;\n// std::cerr << "n_threads: " << n_threads << std::endl;\n// std::cerr << "n_blocks: " << n_blocks << std::endl;\n\n// std::cerr << "frame_stride: " << frame_stride << std::endl;\n// std::cerr << 
"sequence_stride: " << sequence_stride << std::endl;\n// std::cerr << "index_stride: " << index_stride << std::endl;\n\n // initialize edge buffer\n float* d_edge_buffer = reinterpret_cast<float*>(device_malloc(n_edges * n_frames * sizeof(float)));\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n unsigned n_fill_blocks = (n_edges * n_frames + n_threads - 1u) / n_threads;\n fill_array<<<n_fill_blocks, n_threads>>>(d_edge_buffer, 0.0, n_edges * n_frames);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n\n // initialize the state buffer\n n_fill_blocks = (n_states + n_threads - 1u) / n_threads;\n fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n set_start_states<<<1, n_start_states>>>(d_state_buffer_prev, d_start_states);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n\n // initialize full state buffer (only used to dump the alignment)\n float* d_state_buffer_all = NULL;\n if (dump_alignment and batch_idx %% dump_every == 0) {\n d_state_buffer_all = reinterpret_cast<float*>(device_malloc(n_states * (n_frames + 1u) * sizeof(float)));\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n cudaMemcpy(d_state_buffer_all, d_state_buffer_prev, n_states * sizeof(float), cudaMemcpyDeviceToDevice);\n// HANDLE_LAST_ERROR();\n }\n\n // fwd pass\n for (unsigned t = 0u; t < n_frames; t++) {\n fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n// std::cerr << "frame " << t << std::endl;\n next_frame<<<n_blocks, n_threads>>>(true, n_edges, sequence_stride,\n d_sequence_idxs, d_from, d_to, d_weights, d_emission_idxs,\n d_state_buffer_prev, d_state_buffer_next, d_am_scores + t * frame_stride,\n d_edge_buffer + t * n_edges);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n if (dump_alignment and batch_idx %% dump_every == 0) {\n cudaMemcpy(\n d_state_buffer_all + (t + 1u) * n_states, d_state_buffer_next, n_states * sizeof(float),\n cudaMemcpyDeviceToDevice);\n }\n std::swap(d_state_buffer_prev, d_state_buffer_next);\n }\n\n // bwd pass\n const unsigned n_end_state_blocks = (n_end_states + n_threads - 1u) / n_threads;\n const unsigned n_end_state_threads = min(n_threads, n_end_states);\n fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n for (unsigned t = n_frames; t > 0; t--) {\n init_bwd_state_buffer<<<n_end_state_blocks, n_end_state_threads>>>(\n t - 1, n_frames - 1, n_end_states, index_stride,\n d_state_buffer_prev, d_end_states, d_end_state_weights, d_index);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n if (dump_alignment and batch_idx %% dump_every == 0) {\n float alpha = 1.0f;\n// HANDLE_ERROR(cublasSaxpy(\n// handle, n_states, &alpha, d_state_buffer_prev, 1, d_state_buffer_all + t * n_states, 1));\n }\n fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n next_frame<<<n_blocks, n_threads>>>(false, n_edges, sequence_stride,\n d_sequence_idxs, d_to, d_from, d_weights, d_emission_idxs,\n d_state_buffer_prev, d_state_buffer_next,\n d_am_scores + (t - 1) * frame_stride,\n d_edge_buffer + (t - 1) * n_edges);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n std::swap(d_state_buffer_prev, 
d_state_buffer_next);\n }\n if (dump_alignment and batch_idx %% dump_every == 0) {\n float alpha = 1.0f;\n// HANDLE_ERROR(cublasSaxpy(handle, n_states, &alpha, d_state_buffer_prev, 1, d_state_buffer_all, 1));\n }\n\n // normalize at each time frame\n normalize<<<n_frames, 1, n_seqs * sizeof(float)>>>(d_edge_buffer, d_sequence_idxs, n_edges, n_seqs, d_sum_output);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n\n // dump alignment\n if (dump_alignment and batch_idx %% dump_every == 0) {\n write_alignment_to_file(d_state_buffer_all, d_index, index_stride, d_start_states, d_end_states,\n pruning, n_frames, n_seqs, n_states, batch_idx);\n }\n\n n_fill_blocks = (n_frames * n_seqs * n_emissions + n_threads - 1u) / n_threads;\n fill_array<<<n_fill_blocks, n_threads>>>(\n d_out, std::numeric_limits<float>::infinity(), n_frames * n_seqs * n_emissions);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n\n frame_stride = Ndarray_STRIDE(out, 0);\n sequence_stride = Ndarray_STRIDE(out, 1);\n n_blocks = (n_frames * n_edges + n_threads - 1u) / n_threads;\n compute_result<<<n_blocks, n_threads>>>(d_edge_buffer, d_out, d_emission_idxs, d_sequence_idxs,\n frame_stride, sequence_stride, n_frames, n_seqs, n_edges);\n// cudaDeviceSynchronize();\n// HANDLE_LAST_ERROR();\n\n #if TENSORFLOW\n // Certain TensorFlow code doesn\'t like inf, even if it is just the CheckNumerics,\n // which is helpful for debugging.\n // We replace it by a very high number, so that tf.exp(-out) will still result in 0.0.\n n_blocks = (n_frames * n_seqs * n_emissions + n_threads - 1u) / n_threads;\n remove_inf<<<n_blocks, n_threads>>>(d_out, n_frames * n_seqs * n_emissions);\n //debug_print(context, out, "out");\n #endif\n if (dump_output and batch_idx %% dump_every == 0) {\n write_output_to_file(d_out, d_index, index_stride, pruning, n_frames, n_seqs, n_emissions, batch_idx);\n }\n\n device_free(d_edge_buffer);\n if (d_state_buffer_all != NULL) {\n device_free(d_state_buffer_all);\n }\n batch_idx++;\n '[source]#
- class returnn.native_op.SegmentFastBaumWelchOp(segmentwise_normalization=False, dump_targets_interval=None, new_batch_idxs_format=False)[source]#
Segmental Baum-Welch…
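A minimal sketch of instantiating this op description with the constructor arguments shown above (the argument values here are only illustrative):

    from returnn.native_op import SegmentFastBaumWelchOp

    op = SegmentFastBaumWelchOp(
        segmentwise_normalization=True,
        dump_targets_interval=None,
        new_batch_idxs_format=True)
    print(op.in_info)   # the input descriptions documented below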
- out_info: Tuple[Dict[str]] = ({'name': 'output', 'ndim': 3, 'need_contiguous': True, 'shape': ((0, 0), (0, 1), (0, 2))}, {'name': 'normalization_factors', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'name': 'posterior_weigths', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))})[source]#
- c_extra_support_code: Dict[str, str] = {'001_set_start_states': '\n DEF_KERNEL\n void set_start_states(float* states, unsigned* start_states) {\n unsigned state_idx = start_states[blockIdx.x * blockDim.x + threadIdx.x];\n states[state_idx] = 0.0;\n }\n ', '010_fill_array': '\n DEF_KERNEL\n void fill_array(float* array, float value, unsigned size) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx < size) {\n array[idx] = value;\n }\n }\n ', '011_remove_inf': '\n DEF_KERNEL\n void remove_inf(float* array, unsigned size) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx < size) {\n array[idx] = fminf(array[idx], 1e32);\n }\n }\n ', '012_prob_add': '\n DEV_FUNC\n float prob_add(float a, float b) {\n float diff = a - b;\n if (isnan(diff)) {\n return INF_F;\n }\n else {\n return -log1pf(expf(-fabsf(diff))) + fminf(a, b);\n }\n }\n ', '013_atomic_prob_add': '\n DEV_FUNC\n void atomic_prob_add(float* a, float b) {\n int* addr = (int*)a;\n int old = float_as_int(*a);\n int assumed;\n do {\n assumed = old;\n old = elem_atomic_cas(addr, assumed, float_as_int(prob_add(int_as_float(old), b)));\n } while (old != assumed);\n }\n ', '020_dump_to_file': "\n template<typename T>\n void dump_to_file_1d(T* d_mem, unsigned n_d1, std::string const& path) {\n std::vector<T> buffer(n_d1);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n T val = buffer[i1];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << val << '\\n';\n }\n }\n }\n\n template<typename T>\n void dump_to_file_2d(T* d_mem, unsigned n_d1, unsigned n_d2, std::string const& path) {\n std::vector<T> buffer(n_d1 * n_d2);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n T val = buffer[i1 * n_d2 + i2];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << i2 << ' ' << val << '\\n';\n }\n }\n }\n }\n\n template<typename T>\n void dump_to_file_3d(T* d_mem, unsigned n_d1, unsigned n_d2, unsigned n_d3, std::string const& path) {\n std::vector<T> buffer(n_d1 * n_d2 * n_d3);\n //cudaMemcpy(buffer.data(), d_mem, buffer.size() * sizeof(T), cudaMemcpyDeviceToHost);\n\n std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);\n for (size_t i1 = 0ul; i1 < n_d1; i1++) {\n for (size_t i2 = 0ul; i2 < n_d2; i2++) {\n for (size_t i3 = 0ul; i3 < n_d3; i3++) {\n T val = buffer[i1 * n_d2 * n_d3 + i2 * n_d3 + i3];\n if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {\n output << i1 << ' ' << i2 << ' ' << i3 << ' ' << val << '\\n';\n }\n }\n }\n }\n }\n ", '100_get_batch_idx': '\n __device__\n int get_batch_idx(int const* batch_idxs, unsigned num_seqs, unsigned t, unsigned seq_idx) {\n if (NEW_BATCH_IDX_FORMAT) {\n int res = batch_idxs[seq_idx] + t;\n if (res >= batch_idxs[seq_idx + 1]) {\n return -1;\n }\n return res;\n }\n else {\n return batch_idxs[t * num_seqs + seq_idx];\n }\n }\n ', '101_init_bwd_state_buffer': '\n __global__\n void init_bwd_state_buffer(unsigned t, unsigned num_batches, unsigned num_seqs,\n int* batch_idxs, float* index, float* states, unsigned* end_states) {\n unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n int batch_idx = get_batch_idx(batch_idxs, num_seqs, t, 
idx);\n if (batch_idx < 0) {\n return;\n }\n float* batch_first_frame = index + batch_idx;\n //if (*batch_first_frame != 0.0 && (t == max_t || *(batch_first_frame + 1) == 0.0)) {\n if (batch_first_frame[0] != 0.0 && batch_first_frame[num_batches] == 0.0) {\n unsigned state_idx = end_states[idx];\n states[state_idx] = 0.0;\n }\n }\n ', '102_next_frame_fwd': '\n __global__\n void next_frame_fwd(unsigned time, unsigned num_states, unsigned num_edges, unsigned num_emissions,\n unsigned num_seg_frames,\n unsigned num_tot_frames, unsigned num_seqs, unsigned num_am_score_scales,\n unsigned const* sequence_idxs, unsigned const* from_buffer, unsigned const* to_buffer,\n float const* weight_buffer,\n unsigned const* emission_idxs, unsigned const* lenmod_idxs, int const* batch_idxs,\n float const* am_scores, float const* length_models, float const* am_score_scales,\n float const* epoch,\n float* state_buffer, float* edge_buffer) {\n const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_edges) {\n return;\n }\n\n const unsigned num_ringbuffer_frames = num_seg_frames + 1;\n const unsigned max_seg_frames = min(num_seg_frames, num_tot_frames - time);\n\n const unsigned prev_frame_idx = time % num_ringbuffer_frames;\n const unsigned prev_frame_start = prev_frame_idx * num_states;\n\n const unsigned from = from_buffer [idx];\n const float prev_val = state_buffer[prev_frame_start + from];\n if (isinf(prev_val)) {\n return;\n }\n\n const unsigned sequence_idx = sequence_idxs[idx];\n const int batch_idx = get_batch_idx(batch_idxs, num_seqs, time, sequence_idx);\n if (batch_idx == -1) {\n return;\n }\n\n const unsigned amss_idx = min(static_cast<unsigned>(*epoch), num_am_score_scales - 1);\n const float am_score_scale = am_score_scales[amss_idx];\n\n const unsigned to = to_buffer [idx];\n const unsigned emission_idx = emission_idxs[idx];\n const unsigned lenmod_idx = lenmod_idxs [idx];\n const float edge_weight = weight_buffer[idx];\n const float prev_plus_edge = prev_val + edge_weight;\n\n float const* am_buffer_in = am_scores + batch_idx * num_seg_frames * num_emissions + emission_idx;\n float const* length_scores = length_models + lenmod_idx * num_seg_frames;\n float* edge_buffer_out = edge_buffer + idx;\n\n for (unsigned i = 0u; i < max_seg_frames; i++) {\n const float val = prev_plus_edge + am_score_scale * am_buffer_in[i * num_emissions] + length_scores[i];\n edge_buffer_out[i * num_edges] = val;\n const unsigned next_frame = (prev_frame_idx + 1 + i) % num_ringbuffer_frames;\n atomic_prob_add(state_buffer + (next_frame * num_states + to), val);\n }\n }\n ', '103_next_frame_bwd': '\n __global__\n void next_frame_bwd(unsigned time, unsigned num_states, unsigned num_edges, unsigned num_emissions,\n unsigned num_seg_frames,\n unsigned num_tot_frames, unsigned num_seqs, unsigned num_am_score_scales,\n unsigned const* sequence_idxs, unsigned const* from_buffer, unsigned const* to_buffer,\n float const* weight_buffer,\n unsigned const* emission_idxs, unsigned const* lenmod_idxs, int const* batch_idxs,\n float const* am_scores, float const* length_models, float const* am_score_scales,\n float const* epoch,\n float* state_buffer, float* edge_buffer) {\n const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_edges) {\n return;\n }\n\n const unsigned num_ringbuffer_frames = num_seg_frames + 1;\n const unsigned max_seg_frames = min(num_seg_frames, num_tot_frames - time);\n\n const unsigned sequence_idx = sequence_idxs[idx];\n const int batch_idx = 
get_batch_idx(batch_idxs, num_seqs, time, sequence_idx);\n if (batch_idx == -1) {\n return;\n }\n\n const unsigned amss_idx = min(static_cast<unsigned>(*epoch), num_am_score_scales - 1);\n const float am_score_scale = am_score_scales[amss_idx];\n\n const unsigned from = from_buffer [idx];\n const unsigned to = to_buffer [idx];\n const unsigned emission_idx = emission_idxs[idx];\n const unsigned lenmod_idx = lenmod_idxs [idx];\n const float edge_weight = weight_buffer[idx];\n const unsigned next_frame_idx = time % num_ringbuffer_frames;\n\n float const* am_buffer_in = am_scores + batch_idx * num_seg_frames * num_emissions + emission_idx;\n float const* length_scores = length_models + lenmod_idx * num_seg_frames;\n float* edge_buffer_out = edge_buffer + idx;\n\n float acc_val = CUDART_INF_F;\n\n for (unsigned i = 0u; i < max_seg_frames; i++) {\n const unsigned prev_frame_idx = (next_frame_idx + i + 1) % num_ringbuffer_frames;\n const float prev_val = state_buffer[prev_frame_idx * num_states + from];\n if (isinf(prev_val)) {\n edge_buffer_out[i * num_edges] = CUDART_INF_F;\n }\n else {\n const float val =\n prev_val + edge_weight + am_score_scale * am_buffer_in[i * num_emissions] + length_scores[i];\n edge_buffer_out[i * num_edges] += prev_val;\n acc_val = prob_add(acc_val, val);\n }\n }\n\n atomic_prob_add(state_buffer + next_frame_idx * num_states + to, acc_val);\n }\n ', '104_compute_framewise_sum': '\n __global__\n void compute_framewise_sum(unsigned num_tot_frames, unsigned num_seqs, unsigned num_seg_frames,\n unsigned num_batches, unsigned num_edges,\n unsigned const* sequence_idxs, int const* batch_idxs, float const* index,\n float const* edge_buffer,\n float* output_buffer) {\n extern __shared__ float sum[];\n\n const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_tot_frames * num_seg_frames) {\n return;\n }\n\n float* sum_buffer = sum + threadIdx.x * num_seqs;\n edge_buffer += idx * num_edges;\n\n for (unsigned s = 0u; s < num_seqs; s++) {\n sum_buffer[s] = CUDART_INF_F;\n }\n\n for (unsigned i = 0; i < num_edges; i++) {\n const unsigned seq_idx = sequence_idxs[i];\n sum_buffer[seq_idx] = prob_add(sum_buffer[seq_idx], edge_buffer[i]);\n }\n\n const unsigned time = idx / num_seg_frames;\n const unsigned seg_size = idx % num_seg_frames;\n for (unsigned s = 0u; s < num_seqs; s++) {\n const int batch_idx = get_batch_idx(batch_idxs, num_seqs, time, s);\n if (batch_idx >= 0) {\n const unsigned output_idx = seg_size * num_batches + batch_idx;\n if (isinf(sum_buffer[s]) or index[output_idx] == 0.0) {\n output_buffer[output_idx] = 0.0;\n }\n else {\n output_buffer[output_idx] = sum_buffer[s];\n }\n }\n }\n }\n ', '105_merge_framewise_sums': '\n __global__\n void merge_framewise_sum(unsigned num_seg_frames, unsigned num_batches, float const* index, float* sum_buffer) {\n const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_batches) {\n return;\n }\n\n sum_buffer += idx;\n index += idx;\n\n float sum = sum_buffer[0];\n for (unsigned s = 1; s < num_seg_frames; s++) {\n if (index[s * num_batches] != 0.0f) {\n sum = prob_add(sum, sum_buffer[s * num_batches]);\n }\n }\n\n for (unsigned s = 0; s < num_seg_frames; s++) {\n if (index[s * num_batches] != 0.0f) {\n sum_buffer[s * num_batches] = sum;\n }\n }\n }\n ', '106_compute_targets': '\n __global__\n void compute_targets(unsigned num_tot_frames, unsigned num_seg_frames, unsigned num_edges, unsigned num_batches,\n unsigned num_seqs, unsigned num_emissions,\n unsigned const* sequence_idxs, unsigned 
const* emission_idxs, int const* batch_idxs,\n float const* index,\n float const* edge_buffer, float const* normalization_buffer, float* output_buffer) {\n const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_tot_frames * num_seg_frames * num_edges) {\n return;\n }\n\n const unsigned edge_idx = idx % num_edges;\n const unsigned time = idx / (num_edges * num_seg_frames);\n const unsigned seq_idx = sequence_idxs[edge_idx];\n const int batch_idx = get_batch_idx(batch_idxs, num_seqs, time, seq_idx);\n\n if (batch_idx < 0) {\n return;\n }\n\n const unsigned seg_length = (idx / num_edges) % num_seg_frames;\n\n if (index[seg_length * num_batches + batch_idx] == 0.0) {\n return;\n }\n\n const unsigned emission_idx = emission_idxs[edge_idx];\n const float normalization = normalization_buffer[seg_length * num_batches + batch_idx];\n\n atomic_prob_add(\n output_buffer + seg_length * num_batches * num_emissions + batch_idx * num_emissions + emission_idx,\n edge_buffer[idx] - normalization);\n }\n ', '107_compute_posterior_weights': '\n __global__\n void compute_posterior_weights(unsigned num_tot_frames, unsigned num_seg_frames, unsigned num_seqs,\n unsigned num_batches,\n float const* state_buffer, unsigned const* start_states, int const* batch_idxs,\n float const* index, float const* normalization_factors, float* posterior_weigths) {\n const unsigned idx = blockIdx.x * blockDim.x + threadIdx.x;\n if (idx >= num_tot_frames * num_seqs) {\n return;\n }\n\n const unsigned time = idx / num_seqs;\n const unsigned seq_idx = idx % num_seqs;\n\n const int batch_idx = get_batch_idx(batch_idxs, num_seqs, time, seq_idx);\n if (batch_idx < 0) {\n return;\n }\n\n const float seq_sum = state_buffer[start_states[seq_idx]];\n for (unsigned s = 0u; s < num_seg_frames; s++) {\n const unsigned i = s * num_batches + batch_idx;\n if (index[i] == 0.0) {\n return;\n }\n posterior_weigths[i] = exp(-(normalization_factors[i] - seq_sum));\n }\n }\n '}[source]#
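The get_batch_idx device helper above maps a (time, sequence) pair to a batch entry and supports two batch_idxs layouts, selected by NEW_BATCH_IDX_FORMAT. A pure-Python sketch of the same lookup (for reference only; names follow the kernel):

    def get_batch_idx(batch_idxs, num_seqs, t, seq_idx, new_batch_idx_format):
        if new_batch_idx_format:
            # batch_idxs holds per-sequence start offsets (length num_seqs + 1)
            res = batch_idxs[seq_idx] + t
            if res >= batch_idxs[seq_idx + 1]:
                return -1   # frame t is past the end of this sequence
            return res
        # old format: dense (time, seq) -> batch index table, flattened row-major
        return batch_idxs[t * num_seqs + seq_idx]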
- in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 'am_scores', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'gradient': 'disconnected', 'name': 'batch_idxs', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': ((2, 1),)}, {'gradient': 'disconnected', 'name': 'length_models', 'ndim': 2, 'need_contiguous': True, 'shape': (None, (0, 0))}, {'gradient': 'disconnected', 'name': 'start_end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (2, None)}, {'gradient': 'disconnected', 'name': 'index', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'gradient': 'disconnected', 'name': 'am_score_scales', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'gradient': 'disconnected', 'name': 'epoch', 'ndim': 0, 'need_contiguous': True, 'shape': ()})[source]#
- c_fw_code: str = '\n // inputs: am_scores, batch_idxs, edges, weights, length_models, start_end_states, index, am_score_scales, epoch\n // outputs: output, normalization_factors, posterior_weigths\n assert(n_inputs == 9);\n assert(n_outputs == 3);\n Ndarray* ary_am_scores = inputs[0];\n Ndarray* ary_batch_idxs = inputs[1];\n Ndarray* ary_edges = inputs[2];\n Ndarray* ary_weights = inputs[3];\n Ndarray* ary_start_end_states = inputs[4];\n Ndarray* ary_length_models = inputs[5];\n Ndarray* ary_index = inputs[6];\n Ndarray* ary_am_score_scales = inputs[7];\n Ndarray* ary_epoch = inputs[8];\n Ndarray* ary_out = *outputs[0];\n Ndarray* ary_norm_factors = *outputs[1];\n Ndarray* ary_posterior_weights = *outputs[2];\n\n assert(Ndarray_DIMS(ary_edges)[1] == Ndarray_DIMS(ary_weights)[0]);\n\n static unsigned iter = 0u; // used for debug output\n\n float* d_am_scores = Ndarray_DEV_DATA(ary_am_scores);\n int* d_batch_idxs = reinterpret_cast<int*>(Ndarray_DEV_DATA(ary_batch_idxs));\n unsigned* d_from =\n reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 0 * Ndarray_STRIDE(ary_edges, 0));\n unsigned* d_to =\n reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 1 * Ndarray_STRIDE(ary_edges, 0));\n unsigned* d_emission_idxs =\n reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 2 * Ndarray_STRIDE(ary_edges, 0));\n unsigned* d_lenmod_idxs =\n reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 3 * Ndarray_STRIDE(ary_edges, 0));\n unsigned* d_sequence_idxs =\n reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_edges) + 4 * Ndarray_STRIDE(ary_edges, 0));\n float* d_weights = Ndarray_DEV_DATA(ary_weights);\n float* d_length_models = Ndarray_DEV_DATA(ary_length_models);\n unsigned* d_start_states =\n reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_start_end_states) + 0 * Ndarray_STRIDE(ary_start_end_states, 0));\n unsigned* d_end_states =\n reinterpret_cast<unsigned*>(Ndarray_DEV_DATA(ary_start_end_states) + 1 * Ndarray_STRIDE(ary_start_end_states, 0));\n float* d_index = Ndarray_DEV_DATA(ary_index);\n float* d_am_score_scales = Ndarray_DEV_DATA(ary_am_score_scales);\n float* d_epoch = Ndarray_DEV_DATA(ary_epoch);\n float* d_out = Ndarray_DEV_DATA(ary_out);\n float* d_norm_factors = Ndarray_DEV_DATA(ary_norm_factors);\n float* d_posterior_weights = Ndarray_DEV_DATA(ary_posterior_weights);\n\n std::vector<int> seq_lengths;\n if (NEW_BATCH_IDX_FORMAT) {\n seq_lengths.resize(Ndarray_DIMS(ary_batch_idxs)[0]);\n HANDLE_ERROR(cudaMemcpy(\n seq_lengths.data(), d_batch_idxs, seq_lengths.size() * sizeof(int), cudaMemcpyDeviceToHost));\n }\n\n const unsigned n_seg_frames = Ndarray_DIMS(ary_am_scores)[0];\n const unsigned n_batches = Ndarray_DIMS(ary_am_scores)[1];\n const unsigned n_emissions = Ndarray_DIMS(ary_am_scores)[2];\n const unsigned n_seqs =\n NEW_BATCH_IDX_FORMAT ? (Ndarray_DIMS(ary_batch_idxs)[0] - 1) : Ndarray_DIMS(ary_batch_idxs)[1];\n const unsigned n_tot_frames =\n NEW_BATCH_IDX_FORMAT ? 
seq_lengths.back() : Ndarray_DIMS(ary_batch_idxs)[0];\n const unsigned n_edges = Ndarray_DIMS(ary_edges)[1];\n const unsigned n_length_models = Ndarray_DIMS(ary_length_models)[1];\n const unsigned n_am_score_scales = Ndarray_DIMS(ary_am_score_scales)[0];\n const unsigned n_threads = 1024u;\n unsigned n_blocks = (n_edges + n_threads - 1) / n_threads;\n\n unsigned tmp;\n HANDLE_ERROR(cudaMemcpy(&tmp, d_end_states + n_seqs - 1, sizeof(float), cudaMemcpyDeviceToHost));\n\n const unsigned n_states = tmp + 1;\n\n /*std::cerr << "seg frames: " << n_seg_frames << std::endl;\n std::cerr << "batches: " << n_batches << std::endl;\n std::cerr << "emissions: " << n_emissions << std::endl;\n std::cerr << "tot frames: " << n_tot_frames << std::endl;\n std::cerr << "seqs: " << n_seqs << std::endl;\n std::cerr << "edges: " << n_edges << std::endl;\n std::cerr << "length models: " << n_length_models << std::endl;\n std::cerr << "threads: " << n_threads << std::endl;\n std::cerr << "blocks: " << n_blocks << std::endl;\n std::cerr << "num states: " << n_states << std::endl;*/\n\n // initialize edge buffer\n const unsigned edge_buffer_size = n_tot_frames * n_seg_frames * n_edges;\n float* d_edge_buffer = reinterpret_cast<float*>(device_malloc(edge_buffer_size * sizeof(float)));\n HANDLE_LAST_ERROR();\n unsigned n_fill_blocks = (edge_buffer_size + n_threads - 1u) / n_threads;\n fill_array<<<n_fill_blocks, n_threads>>>(d_edge_buffer, std::numeric_limits<float>::infinity(), edge_buffer_size);\n HANDLE_LAST_ERROR();\n\n // initialize the state buffer\n const unsigned n_ringbuffer_frames = n_seg_frames + 1;\n float* d_state_buffer = reinterpret_cast<float*>(device_malloc(n_states * n_ringbuffer_frames * sizeof(float)));\n HANDLE_LAST_ERROR();\n n_fill_blocks = (n_states * n_ringbuffer_frames + n_threads - 1u) / n_threads;\n fill_array<<<n_fill_blocks, n_threads>>>(\n d_state_buffer, std::numeric_limits<float>::infinity(), n_states * n_ringbuffer_frames);\n HANDLE_LAST_ERROR();\n\n // initialize sum buffer and posterior weigths\n n_fill_blocks = (n_batches * n_seg_frames + n_threads - 1u) / n_threads;\n fill_array<<<n_fill_blocks, n_threads>>>(d_norm_factors, 0.0f, n_batches * n_seg_frames);\n HANDLE_LAST_ERROR();\n fill_array<<<n_fill_blocks, n_threads>>>(d_posterior_weights, 0.0f, n_batches * n_seg_frames);\n HANDLE_LAST_ERROR();\n\n set_start_states<<<1, n_seqs>>>(d_state_buffer, d_start_states);\n HANDLE_LAST_ERROR();\n\n // fwd pass\n for (unsigned t = 0u; t < n_tot_frames; t++) {\n //std::cerr << "fwd t: " << t << " " << n_tot_frames << std::endl;\n float* d_state_buffer_prev = d_state_buffer + ((t - 1) %% n_ringbuffer_frames) * n_states;\n fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_prev, std::numeric_limits<float>::infinity(), n_states);\n HANDLE_LAST_ERROR();\n next_frame_fwd<<<n_blocks, n_threads>>>(t, n_states, n_edges, n_emissions, n_seg_frames, n_tot_frames, n_seqs,\n n_am_score_scales,\n d_sequence_idxs, d_from, d_to, d_weights, d_emission_idxs, d_lenmod_idxs,\n d_batch_idxs,\n d_am_scores, d_length_models, d_am_score_scales, d_epoch,\n d_state_buffer, d_edge_buffer + t * n_seg_frames * n_edges);\n HANDLE_LAST_ERROR();\n\n //std::stringstream ss;\n //ss << "dump/fwd_state_buffer." 
<< t << ".dump";\n //dump_to_file_2d(d_state_buffer, n_ringbuffer_frames, n_states, ss.str());\n }\n\n //dump_to_file_3d(d_edge_buffer, n_tot_frames, n_seg_frames, n_edges, "dump/fwd_edges.dump");\n\n // bwd pass\n n_fill_blocks = (n_states * n_ringbuffer_frames + n_threads - 1u) / n_threads;\n fill_array<<<n_fill_blocks, n_threads>>>(\n d_state_buffer, std::numeric_limits<float>::infinity(), n_states * n_ringbuffer_frames);\n HANDLE_LAST_ERROR();\n n_fill_blocks = (n_states + n_threads - 1u) / n_threads;\n for (unsigned t = n_tot_frames; t > 0; t--) {\n //std::cerr <<\n //"bwd t: " << t << " " << n_tot_frames << " buffer next: " << ((t-1) %% n_ringbuffer_frames) << std::endl;\n float* d_state_buffer_next = d_state_buffer + ((t - 1) %% n_ringbuffer_frames) * n_states;\n float* d_state_buffer_prev = d_state_buffer + ( t %% n_ringbuffer_frames) * n_states;\n fill_array<<<n_fill_blocks, n_threads>>>(d_state_buffer_next, std::numeric_limits<float>::infinity(), n_states);\n HANDLE_LAST_ERROR();\n init_bwd_state_buffer<<<1, n_seqs>>>(\n t - 1, n_batches, n_seqs, d_batch_idxs, d_index, d_state_buffer_prev, d_end_states);\n HANDLE_LAST_ERROR();\n next_frame_bwd<<<n_blocks, n_threads>>>(\n t - 1, n_states, n_edges, n_emissions, n_seg_frames, n_tot_frames, n_seqs, n_am_score_scales,\n d_sequence_idxs, d_to, d_from, d_weights, d_emission_idxs, d_lenmod_idxs, d_batch_idxs,\n d_am_scores, d_length_models, d_am_score_scales, d_epoch,\n d_state_buffer, d_edge_buffer + (t - 1) * n_seg_frames * n_edges);\n HANDLE_LAST_ERROR();\n\n //std::stringstream ss;\n //ss << "dump/bwd_state_buffer." << t << ".dump";\n //dump_to_file_2d(d_state_buffer, n_ringbuffer_frames, n_states, ss.str());\n }\n\n n_blocks = (n_tot_frames * n_seg_frames + n_threads - 1) / n_threads;\n compute_framewise_sum<<<n_blocks, n_threads, n_threads * n_seqs * sizeof(float)>>>(\n n_tot_frames, n_seqs, n_seg_frames, n_batches, n_edges,\n d_sequence_idxs, d_batch_idxs,\n d_index, d_edge_buffer, d_norm_factors);\n HANDLE_LAST_ERROR();\n\n //dump_to_file_2d(d_norm_factors, n_seg_frames, n_batches, "dump/norm_factors_1.dump");\n\n if (segmentwise_normalization) {\n n_blocks = (n_batches + n_threads - 1) / n_threads;\n merge_framewise_sum<<<n_blocks, n_threads>>>(n_seg_frames, n_batches, d_index, d_norm_factors);\n HANDLE_LAST_ERROR();\n }\n\n //dump_to_file_2d(d_norm_factors, n_seg_frames, n_batches, "dump/norm_factors_2.dump");\n\n n_blocks = (n_tot_frames * n_seqs + n_threads - 1) / n_threads;\n compute_posterior_weights<<<n_blocks, n_threads>>>(n_tot_frames, n_seg_frames, n_seqs, n_batches, d_state_buffer,\n d_start_states, d_batch_idxs, d_index, d_norm_factors,\n d_posterior_weights);\n HANDLE_LAST_ERROR();\n\n n_fill_blocks = (n_batches * n_seg_frames * n_emissions + n_threads - 1u) / n_threads;\n fill_array<<<n_fill_blocks, n_threads>>>(\n d_out, std::numeric_limits<float>::infinity(), n_batches * n_seg_frames * n_emissions);\n HANDLE_LAST_ERROR();\n\n n_blocks = (n_tot_frames * n_seg_frames * n_edges + n_threads - 1) / n_threads;\n compute_targets<<<n_blocks, n_threads>>>(n_tot_frames, n_seg_frames, n_edges, n_batches, n_seqs, n_emissions,\n d_sequence_idxs, d_emission_idxs, d_batch_idxs, d_index, d_edge_buffer,\n d_norm_factors, d_out);\n HANDLE_LAST_ERROR();\n\n //dump_to_file_1d(d_weights, n_edges, "dump/edge_weights.dump");\n //dump_to_file_1d(d_sequence_idxs, n_edges, "dump/sequence_idxs.dump");\n //dump_to_file_2d(d_state_buffer, n_ringbuffer_frames, n_states, "dump/state_buffer.dump");\n //dump_to_file_2d(d_batch_idxs, n_tot_frames, 
n_seqs, "dump/batch_idxs.dump");\n //dump_to_file_2d(d_index, n_seg_frames, n_batches, "dump/index.dump");\n //dump_to_file_3d(d_edge_buffer, n_tot_frames, n_seg_frames, n_edges, "dump/edges.dump");\n //dump_to_file_3d(d_am_scores, n_seg_frames, n_batches, n_emissions, "dump/am_scores.dump");\n //dump_to_file_3d(d_out, n_seg_frames, n_batches, n_emissions, "dump/targets.dump");\n\n if (dump_targets and iter %% dump_targets_interval == 0) {\n std::stringstream ss;\n ss << "dump/targets_" << iter << ".dump";\n dump_to_file_3d(d_out, n_seg_frames, n_batches, n_emissions, ss.str());\n ss.str("");\n ss.clear();\n ss << "dump/norm_factors_" << iter << ".dump";\n dump_to_file_2d(d_norm_factors, n_seg_frames, n_batches, ss.str());\n ss.str("");\n ss.clear();\n ss << "dump/posterior_weights_" << iter << ".dump";\n dump_to_file_2d(d_posterior_weights, n_seg_frames, n_batches, ss.str());\n }\n\n iter += 1;\n\n device_free(d_state_buffer);\n device_free(d_edge_buffer);\n '[source]#
- class returnn.native_op.FastViterbiOp[source]#
- inputs:
- param am_scores:
scores in +log space. 3d (time,batch,dim)
- param am_seq_len:
(batch,)
- param edges:
edges of the graph (from,to,emission_idx,sequence_idx), i.e. (4, n_edges)
- param weights:
weights of the edges (n_edges,)
- param start_end_states:
(2, batch)
- param n_states:
scalar, int32
- outputs:
- param output:
Viterbi (hard) alignment, scores in +log space. 2d (time,batch)
- param scores:
(batch,)
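A minimal pure-NumPy sketch of the recursion this op performs, restricted to a single sequence for clarity (the native op processes the whole batch at once and masks frames beyond am_seq_len; this helper is only illustrative and assumes a best path exists):

    import numpy as np

    def viterbi_single_seq(am_scores, edges, weights, start_state, end_state, n_states):
        # am_scores: (time, dim) in +log space; edges: (4, n_edges) with rows
        # (from, to, emission_idx, sequence_idx); weights: (n_edges,)
        n_time = am_scores.shape[0]
        score = np.full((n_time + 1, n_states), -np.inf)
        backptr = np.full((n_time + 1, n_states), -1, dtype=np.int32)
        score[0, start_state] = 0.0
        for t in range(n_time):
            for e in range(edges.shape[1]):
                frm, to, emission_idx, _ = edges[:, e]
                cand = score[t, frm] + weights[e] + am_scores[t, emission_idx]
                if cand > score[t + 1, to]:       # keep the best-scoring incoming edge
                    score[t + 1, to] = cand
                    backptr[t + 1, to] = e
        # backtrack from the end state to recover the emitted label per frame
        alignment = np.zeros(n_time, dtype=np.int32)
        state = end_state
        for t in range(n_time, 0, -1):
            e = backptr[t, state]
            alignment[t - 1] = edges[2, e]
            state = edges[0, e]
        return alignment, score[n_time, end_state]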
- in_info: Tuple[Dict[str]] = ({'gradient': 'disconnected', 'name': 'am_scores', 'ndim': 3, 'need_contiguous': True, 'shape': (None, None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'am_seq_len', 'ndim': 1, 'need_contiguous': True, 'shape': ((0, 0),)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (4, None)}, {'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': ((3, 1),)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'start_end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (2, (0, 0))}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'n_states', 'ndim': 0, 'need_contiguous': True, 'shape': ()})[source]#
- out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'output', 'ndim': 2, 'need_contiguous': True, 'shape': ((0, 0), (0, 1))}, {'name': 'scores', 'ndim': 1, 'need_contiguous': True, 'shape': ((0, 1),)})[source]#
- c_extra_support_code: Dict[str, str] = {'01_IdxAndVal': '\n struct __attribute__((__packed__)) IdxAndVal {\n int idx;\n float val;\n };\n ', '04_select_max': '\n DEV_FUNC\n void select_max(IdxAndVal* a, IdxAndVal b) {\n // fast path\n if(b.val < a->val)\n return;\n // Maybe we could use double compare-and-swap (https://stackoverflow.com/questions/55941382/).\n // But not sure how.\n // So instead, we use double-wide compare-and-swap.\n union U {\n IdxAndVal s;\n unsigned long long int v64;\n };\n while(true) {\n U prev;\n prev.s = *a;\n if(b.val < prev.s.val)\n return;\n if(b.val == prev.s.val && b.idx >= prev.s.idx)\n return;\n U updated;\n updated.s = b;\n\n U old;\n old.v64 = elem_atomic_cas((unsigned long long int*) a, prev.v64, updated.v64);\n if(old.v64 == prev.v64)\n return;\n // Not the same, so repeat.\n }\n }\n ', '05_init_buffer': '\n DEF_KERNEL\n void init_buffer\n (\n int n_time,\n int n_states, // for the whole batch\n IdxAndVal* buffer // (time+1,n_states), states for the whole batch\n )\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while(idx < (n_time + 1) * n_states) {\n buffer[idx].val = -INF_F;\n buffer[idx].idx = -1;\n idx += gridDim.x * blockDim.x;\n }\n }\n ', '06_init_first_frame': '\n DEF_KERNEL\n void init_first_frame\n (\n int n_batch,\n int n_states, // for the whole batch\n IdxAndVal* frame, // (n_states,), states for the whole batch\n const int32_t* d_start_states // (n_batch,)\n )\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while(idx < n_batch) {\n int state_idx = d_start_states[idx];\n frame[state_idx].val = 0;\n idx += gridDim.x * blockDim.x;\n }\n }\n ', '08_next_frame': '\n DEF_KERNEL\n void next_frame\n (\n int n_time,\n int n_states,\n int n_edges,\n int n_classes,\n int t,\n const float* d_am_scores,\n const int32_t* d_am_seq_len,\n const IdxAndVal* prev_frame,\n IdxAndVal* frame,\n const int32_t* d_edge_from,\n const int32_t* d_edge_to,\n const int32_t* d_edge_emission_idx,\n const int32_t* d_edge_seq_idx,\n const float* d_edge_weights,\n const int32_t* d_end_states // (n_batch,)\n )\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while(idx < n_edges) {\n int from_idx = d_edge_from[idx];\n //assert_cmp(0, <=, from_idx); assert_cmp(from_idx, <, n_states);\n\n int seq_idx = d_edge_seq_idx[idx];\n if(t < d_am_seq_len[seq_idx]) {\n float prev_val = prev_frame[from_idx].val;\n int emission_idx = d_edge_emission_idx[idx];\n //assert_cmp(0, <=, emission_idx); assert_cmp(emission_idx, <, n_classes);\n int to_idx = d_edge_to[idx];\n //assert_cmp(0, <=, to_idx); assert_cmp(to_idx, <, n_states);\n IdxAndVal candidate;\n candidate.val = prev_val + d_edge_weights[idx] + d_am_scores[seq_idx * n_classes + emission_idx];\n candidate.idx = idx;\n select_max(&frame[to_idx], candidate);\n }\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', '11_select_scores': '\n DEF_KERNEL\n void select_scores\n (\n int n_batch,\n int n_states,\n int buffer_stride,\n const IdxAndVal* buffer,\n const int32_t* d_am_seq_len, // (n_batch,)\n const int32_t* d_end_states, // (n_batch,)\n float* d_score // (n_batch,)\n )\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while(idx < n_batch) {\n const IdxAndVal* last_frame = buffer + d_am_seq_len[idx] * buffer_stride;\n int end_state_idx = d_end_states[idx];\n d_score[idx] = last_frame[end_state_idx].val;\n\n idx += gridDim.x * blockDim.x;\n }\n }\n ', '13_select_best_path': '\n DEF_KERNEL\n void select_best_path\n (\n int n_batch,\n int n_states,\n int n_edges,\n int t,\n int32* cur_state, // 
(n_batch,)\n const IdxAndVal* frame,\n const int32_t* d_am_seq_len,\n const int32_t* d_edge_from,\n const int32_t* d_edge_to,\n const int32_t* d_edge_emission_idx,\n int32_t* output\n )\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n while(idx < n_batch) {\n if(t < d_am_seq_len[idx]) {\n int state_idx = cur_state[idx];\n //assert_cmp(0, <=, state_idx); assert_cmp(state_idx, <, n_states);\n int edge_idx = frame[state_idx].idx;\n if(edge_idx >= 0) {\n //assert_cmp(0, <=, edge_idx); assert_cmp(edge_idx, <, n_edges);\n //assert_cmp(state_idx, ==, d_edge_to[edge_idx]);\n cur_state[idx] = d_edge_from[edge_idx];\n output[idx] = d_edge_emission_idx[edge_idx];\n }\n else // no path found\n output[idx] = 0;\n }\n else {\n output[idx] = 0;\n }\n idx += gridDim.x * blockDim.x;\n }\n }\n '}[source]#
- c_fw_code: str = '\n using namespace std;\n // am_scores, am_seq_len, edges, weights, start_end_states, n_states = input_names\n // output, scores = output_names\n assert(n_inputs == 6);\n assert(n_outputs == 2);\n Ndarray* am_scores = inputs[0];\n Ndarray* am_seq_len = inputs[1];\n Ndarray* edges = inputs[2];\n Ndarray* weights = inputs[3];\n Ndarray* start_end_states = inputs[4];\n Ndarray* n_states_ref = inputs[5];\n Ndarray* output = *outputs[0];\n Ndarray* score = *outputs[1];\n\n assert_cmp(Ndarray_NDIM(am_scores), ==, 3);\n assert_cmp(Ndarray_NDIM(am_seq_len), ==, 1);\n assert_cmp(Ndarray_NDIM(edges), ==, 2);\n assert_cmp(Ndarray_NDIM(weights), ==, 1);\n assert_cmp(Ndarray_NDIM(start_end_states), ==, 2);\n assert_cmp(Ndarray_NDIM(n_states_ref), ==, 0);\n assert_cmp(Ndarray_NDIM(output), ==, 2);\n assert_cmp(Ndarray_NDIM(score), ==, 1);\n int n_time = Ndarray_DIMS(am_scores)[0];\n int n_batch = Ndarray_DIMS(am_scores)[1];\n int n_classes = Ndarray_DIMS(am_scores)[2];\n assert_cmp(Ndarray_DIMS(am_scores)[0], ==, n_time);\n assert_cmp(Ndarray_DIMS(am_scores)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(am_scores)[2], ==, n_classes);\n assert_cmp(Ndarray_DIMS(am_seq_len)[0], ==, n_batch);\n int n_edges = Ndarray_DIMS(edges)[1];\n assert_cmp(Ndarray_DIMS(edges)[0], ==, 4);\n assert_cmp(Ndarray_DIMS(edges)[1], ==, n_edges);\n assert_cmp(Ndarray_DIMS(weights)[0], ==, n_edges);\n assert_cmp(Ndarray_DIMS(start_end_states)[0], ==, 2);\n assert_cmp(Ndarray_DIMS(start_end_states)[1], ==, n_batch);\n int n_states = Ndarray_DEV_DATA_int32_scalar(n_states_ref);\n assert_cmp(Ndarray_DIMS(output)[0], ==, n_time);\n assert_cmp(Ndarray_DIMS(output)[1], ==, n_batch);\n assert_cmp(Ndarray_DIMS(score)[0], ==, n_batch);\n\n int32_t* d_edge_from = Ndarray_DEV_DATA_int32(edges) + 0 * Ndarray_STRIDE(edges, 0);\n int32_t* d_edge_to = Ndarray_DEV_DATA_int32(edges) + 1 * Ndarray_STRIDE(edges, 0);\n int32_t* d_edge_emission_idx = Ndarray_DEV_DATA_int32(edges) + 2 * Ndarray_STRIDE(edges, 0);\n int32_t* d_edge_seq_idx = Ndarray_DEV_DATA_int32(edges) + 3 * Ndarray_STRIDE(edges, 0);\n float* d_edge_weights = Ndarray_DEV_DATA(weights);\n float* d_am_scores = Ndarray_DEV_DATA(am_scores);\n int am_scores_stride = Ndarray_STRIDE(am_scores, 0);\n int32_t* d_am_seq_len = Ndarray_DEV_DATA_int32(am_seq_len);\n int32_t* d_start_states = Ndarray_DEV_DATA_int32(start_end_states) + 0 * Ndarray_STRIDE(start_end_states, 0);\n int32_t* d_end_states = Ndarray_DEV_DATA_int32(start_end_states) + 1 * Ndarray_STRIDE(start_end_states, 0);\n int32_t* d_output = Ndarray_DEV_DATA_int32(output);\n int output_stride = Ndarray_STRIDE(output, 0);\n float* d_score = Ndarray_DEV_DATA(score);\n\n IdxAndVal* d_buffer = (IdxAndVal*) device_malloc((n_time + 1) * n_states * sizeof(IdxAndVal));\n int buffer_stride = n_states;\n start_dev_kernel(init_buffer, (n_time, n_states, d_buffer));\n start_dev_kernel(init_first_frame, (n_batch, n_states, d_buffer, d_start_states));\n HANDLE_LAST_ERROR();\n\n for(int t = 0; t < n_time; ++t) {\n start_dev_kernel(next_frame, (\n n_time,\n n_states,\n n_edges,\n n_classes,\n t,\n d_am_scores + t * am_scores_stride,\n d_am_seq_len,\n d_buffer + t * buffer_stride,\n d_buffer + (t + 1) * buffer_stride,\n d_edge_from,\n d_edge_to,\n d_edge_emission_idx,\n d_edge_seq_idx,\n d_edge_weights,\n d_end_states\n ));\n }\n HANDLE_LAST_ERROR();\n\n start_dev_kernel(select_scores, (\n n_batch,\n n_states,\n buffer_stride,\n d_buffer,\n d_am_seq_len,\n d_end_states,\n d_score // out\n ));\n\n int32_t* d_cur_state = (int32_t*) 
device_malloc(n_batch * sizeof(int32_t));\n Ndarray_memcpy(d_cur_state, d_end_states, n_batch * sizeof(int32_t));\n\n for(int t = n_time - 1; t >= 0; --t) {\n start_dev_kernel(select_best_path, (\n n_batch,\n n_states,\n n_edges,\n t,\n d_cur_state,\n d_buffer + (t + 1) * buffer_stride,\n d_am_seq_len,\n d_edge_from,\n d_edge_to,\n d_edge_emission_idx,\n d_output + t * output_stride // out\n ));\n }\n HANDLE_LAST_ERROR();\n\n device_free(d_cur_state);\n device_free(d_buffer);\n '[source]#
- class returnn.native_op.GetCtcFsaFastBwOp[source]#
This implements
Fsa.get_ctc_fsa_fast_bw()
as a native op. This is for constructing an FSA with a CTC topology. The output format is compatible with the FastBaumWelch native op.
- inputs:
- param targets:
shape (batch,time), int32
- param seq_lens:
shape (batch), int32
- param blank_idx:
scalar, int32
- param weights:
shape (num_edges,), float32 (not used, except for target shape)
- param label_loop:
scalar, int32 (cast from bool). True -> normal CTC; False -> RNA-like
- outputs:
- param edges:
(4,num_edges), int32, edges of the graph (from,to,emission_idx,sequence_idx)
- param start_end_states:
(2,batch), int32, (start,end) state idx in FSA
To construct weights (for FastBaumWelch), weights should be just tf.zeros((num_edges,)). num_edges should be n_batch * (5 * (n_time - 1) + 10)
(see the construction in the kernel for why that is the number of edges).
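For example, computing num_edges and the dummy weights tensor as described above (the sizes here are only illustrative):

    import tensorflow as tf

    n_batch, n_time = 3, 5                         # targets has shape (n_batch, n_time)
    num_edges = n_batch * (5 * (n_time - 1) + 10)  # = 90 for this example
    weights = tf.zeros((num_edges,))               # dummy weights for FastBaumWelch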
- in_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'gradient': 'disconnected', 'name': 'targets', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'seq_lens', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'blank_idx', 'ndim': 0, 'need_contiguous': True, 'shape': ()}, {'dtype': 'float32', 'gradient': 'disconnected', 'name': 'weights', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'host_memory': True, 'name': 'label_loop', 'ndim': 0, 'need_contiguous': True, 'shape': ()})[source]#
- out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'edges', 'ndim': 2, 'need_contiguous': True, 'shape': (4, (3, 0))}, {'dtype': 'int32', 'name': 'start_end_states', 'ndim': 2, 'need_contiguous': True, 'shape': (2, (1, 0))})[source]#
- c_extra_support_code: Dict[str, str] = {'01_kernel': '\n template<bool label_loop>\n DEF_KERNEL\n void construct_kernel\n (\n int n_batch, int n_time, int n_edges,\n const int32_t* targets, const int32_t* seq_lens,\n int32_t blank_idx,\n int32_t* edges, int32_t* start_end_states\n )\n {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n // n_edges should be n_batch * (5 * (n_time - 1) + 10).\n assert(n_edges % n_batch == 0);\n while(idx < n_edges) {\n int batch_idx = idx / (n_edges / n_batch);\n int rel_edge_idx = idx % (n_edges / n_batch);\n int32_t seq_len = seq_lens[batch_idx];\n // state_idx: 0 b, 1 l, 2 b, 3 l, ..., (T-1)*2 b, T*2-1 l, T*2 b, T*2+1 dummy, T*2+2 end\n // i.e. T*2+3 states per seq.\n int state_idx_offset = (n_time * 2 + 3) * batch_idx;\n int t = -1; // pos in targets\n int srel_edge_idx = -1; // state relative edge\n // (seq_len * 2) - 1 is last label state idx. seq_len * 2 is last blank state idx.\n int32_t dummy_state_idx = seq_len * 2 + 1;\n int32_t end_state_idx = seq_len * 2 + 2;\n int32_t state_idx = dummy_state_idx;\n int32_t to_state_idx = dummy_state_idx;\n if(rel_edge_idx == 0) {\n start_end_states[0 * n_batch + batch_idx] = state_idx_offset; // start\n start_end_states[1 * n_batch + batch_idx] = state_idx_offset + end_state_idx; // end\n }\n int32_t emission_idx = blank_idx;\n int32_t label_idx = -1, next_label_idx = -1;\n if(seq_len == 0) {\n t = -1;\n emission_idx = blank_idx;\n // 1 single blank loop\n if(rel_edge_idx == 0) {\n state_idx = 0;\n to_state_idx = 0;\n srel_edge_idx = 0;\n }\n else if(rel_edge_idx == 1) {\n state_idx = 0;\n to_state_idx = end_state_idx;\n srel_edge_idx = 1;\n }\n else {\n state_idx = dummy_state_idx;\n srel_edge_idx = -1;\n }\n }\n else if(seq_len == 1) {\n label_idx = targets[batch_idx * n_time + 0];\n // 3 edges for first / prev last blank\n if(rel_edge_idx < 3) {\n t = 0;\n state_idx = 0;\n srel_edge_idx = rel_edge_idx;\n if(srel_edge_idx == 0) {\n to_state_idx = state_idx;\n emission_idx = blank_idx;\n }\n else if(srel_edge_idx == 1) {\n to_state_idx = state_idx + 1;\n emission_idx = label_idx;\n }\n else if(srel_edge_idx == 2) {\n to_state_idx = end_state_idx;\n emission_idx = label_idx;\n }\n }\n // 4 edges for first / last label\n else if(rel_edge_idx < 7) {\n t = 0;\n state_idx = 1;\n srel_edge_idx = rel_edge_idx - 3;\n if(srel_edge_idx == 0) {\n to_state_idx = label_loop ? state_idx : dummy_state_idx;\n emission_idx = label_idx;\n }\n else if(srel_edge_idx == 1) {\n to_state_idx = state_idx + 1;\n emission_idx = blank_idx;\n }\n else if(srel_edge_idx == 2) {\n to_state_idx = label_loop ? end_state_idx : dummy_state_idx;\n emission_idx = label_idx;\n }\n else if(srel_edge_idx == 3) {\n to_state_idx = end_state_idx;\n emission_idx = blank_idx;\n }\n }\n // 2 edges for last blank\n else if(rel_edge_idx < 9) {\n t = -1;\n emission_idx = blank_idx;\n state_idx = 2;\n srel_edge_idx = rel_edge_idx - 7;\n if(srel_edge_idx == 0)\n to_state_idx = state_idx;\n else\n to_state_idx = end_state_idx;\n }\n else {\n t = -1;\n state_idx = dummy_state_idx;\n srel_edge_idx = -1;\n }\n }\n else { // seq_len >= 2\n // 2 edges for each blank, 3 for each label. 
up to prev last.\n if(rel_edge_idx < 5 * (seq_len - 1)) {\n t = rel_edge_idx / 5;\n label_idx = targets[batch_idx * n_time + t];\n next_label_idx = targets[batch_idx * n_time + t + 1];\n state_idx = 2 * (rel_edge_idx / 5);\n srel_edge_idx = rel_edge_idx % 5;\n if(srel_edge_idx >= 2) {\n srel_edge_idx -= 2;\n state_idx += 1;\n }\n if(state_idx % 2 == 0) { // blank loop state\n if(srel_edge_idx == 0) {\n to_state_idx = state_idx;\n emission_idx = blank_idx;\n }\n else if(srel_edge_idx == 1) {\n to_state_idx = state_idx + 1;\n emission_idx = label_idx;\n }\n }\n else { // label loop state\n if(srel_edge_idx == 0) {\n to_state_idx = label_loop ? state_idx : dummy_state_idx;\n emission_idx = label_idx;\n }\n else if(srel_edge_idx == 1) {\n to_state_idx = state_idx + 1;\n emission_idx = blank_idx;\n }\n else if(srel_edge_idx == 2) {\n // skip over blank to next label (if allowed <=> next label is different)\n if(label_idx != next_label_idx || !label_loop) {\n to_state_idx = state_idx + 2;\n emission_idx = next_label_idx;\n }\n }\n }\n }\n // 1 more edge for prev last label\n else if(rel_edge_idx == 5 * (seq_len - 1)) {\n t = seq_len - 2;\n label_idx = targets[batch_idx * n_time + t];\n next_label_idx = targets[batch_idx * n_time + t + 1];\n state_idx = (seq_len - 2) * 2 + 1;\n srel_edge_idx = 3;\n // skip over blank to next label / end state (if allowed <=> next label is different)\n if(label_idx != next_label_idx || !label_loop) {\n to_state_idx = end_state_idx;\n emission_idx = next_label_idx;\n }\n }\n // 3 edges for prev last blank\n else if(rel_edge_idx <= 5 * (seq_len - 1) + 3) {\n t = seq_len - 1;\n label_idx = targets[batch_idx * n_time + t];\n state_idx = (seq_len - 1) * 2;\n srel_edge_idx = rel_edge_idx - (5 * (seq_len - 1) + 1);\n if(srel_edge_idx == 0) {\n to_state_idx = state_idx;\n emission_idx = blank_idx;\n }\n else if(srel_edge_idx == 1) {\n to_state_idx = state_idx + 1;\n emission_idx = label_idx;\n }\n else if(srel_edge_idx == 2) {\n to_state_idx = end_state_idx;\n emission_idx = label_idx;\n }\n }\n // 4 edges for last label\n else if(rel_edge_idx <= 5 * (seq_len - 1) + 7) {\n t = seq_len - 1;\n label_idx = targets[batch_idx * n_time + t];\n state_idx = (seq_len - 1) * 2 + 1;\n srel_edge_idx = rel_edge_idx - (5 * (seq_len - 1) + 4);\n if(srel_edge_idx == 0) {\n to_state_idx = label_loop ? state_idx : dummy_state_idx;\n emission_idx = label_idx;\n }\n else if(srel_edge_idx == 1) {\n to_state_idx = state_idx + 1;\n emission_idx = blank_idx;\n }\n else if(srel_edge_idx == 2) {\n to_state_idx = label_loop ? end_state_idx : dummy_state_idx;\n emission_idx = label_idx;\n }\n else if(srel_edge_idx == 3) {\n to_state_idx = end_state_idx;\n emission_idx = blank_idx;\n }\n }\n // 2 edges for last blank\n else if(rel_edge_idx <= 5 * (seq_len - 1) + 9) {\n t = -1;\n emission_idx = blank_idx;\n state_idx = (seq_len - 1) * 2 + 2;\n srel_edge_idx = rel_edge_idx - (5 * (seq_len - 1) + 8);\n if(srel_edge_idx == 0)\n to_state_idx = state_idx;\n else\n to_state_idx = end_state_idx;\n }\n else {\n t = -1;\n state_idx = dummy_state_idx;\n srel_edge_idx = -1;\n }\n }\n\n edges[0 * n_edges + idx] = state_idx_offset + state_idx; // from\n edges[1 * n_edges + idx] = state_idx_offset + to_state_idx; // to\n edges[2 * n_edges + idx] = emission_idx; // emission\n edges[3 * n_edges + idx] = batch_idx; // batch\n\n idx += gridDim.x * blockDim.x;\n }\n }\n '}[source]#
- c_fw_code: str = '\n assert(n_inputs == 5);\n assert(n_outputs == 2);\n Ndarray* targets = inputs[0];\n Ndarray* seq_lens = inputs[1];\n Ndarray* blank_idx_ref = inputs[2];\n Ndarray* weights = inputs[3];\n bool label_loop = (bool) Ndarray_DEV_DATA_int32_scalar(inputs[4]);\n Ndarray* edges = *outputs[0];\n Ndarray* start_end_states = *outputs[1];\n assert_cmp(Ndarray_NDIM(targets), ==, 2);\n assert_cmp(Ndarray_NDIM(seq_lens), ==, 1);\n assert_cmp(Ndarray_NDIM(blank_idx_ref), ==, 0);\n assert_cmp(Ndarray_NDIM(weights), ==, 1);\n assert_cmp(Ndarray_NDIM(edges), ==, 2);\n assert_cmp(Ndarray_NDIM(start_end_states), ==, 2);\n int n_batch = Ndarray_DIMS(seq_lens)[0];\n assert_cmp(Ndarray_DIMS(targets)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(seq_lens)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(start_end_states)[1], ==, n_batch);\n int n_time = Ndarray_DIMS(targets)[1];\n int n_edges = Ndarray_DIMS(weights)[0];\n assert_cmp(Ndarray_DIMS(start_end_states)[0], ==, 2);\n assert_cmp(Ndarray_DIMS(edges)[0], ==, 4);\n assert_cmp(Ndarray_DIMS(edges)[1], ==, n_edges);\n\n assert_cmp(n_edges, ==, n_batch * (5 * (n_time - 1) + 10));\n\n Ndarray_memset(Ndarray_DEV_DATA_int32(edges), 255, 4 * n_edges * sizeof(int32_t));\n Ndarray_memset(Ndarray_DEV_DATA_int32(start_end_states), 255, 2 * n_batch * sizeof(int32_t));\n int32_t blank_idx = Ndarray_DEV_DATA_int32_scalar(blank_idx_ref);\n\n if(label_loop) {\n start_dev_kernel(construct_kernel<true>, (\n n_batch, n_time, n_edges,\n Ndarray_DEV_DATA_int32(targets), Ndarray_DEV_DATA_int32(seq_lens),\n blank_idx,\n Ndarray_DEV_DATA_int32(edges), Ndarray_DEV_DATA_int32(start_end_states)\n ));\n } else {\n start_dev_kernel(construct_kernel<false>, (\n n_batch, n_time, n_edges,\n Ndarray_DEV_DATA_int32(targets), Ndarray_DEV_DATA_int32(seq_lens),\n blank_idx,\n Ndarray_DEV_DATA_int32(edges), Ndarray_DEV_DATA_int32(start_end_states)\n ));\n }\n HANDLE_LAST_ERROR();\n '[source]#
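The state and edge layout of the constructed FSA is spelled out in the kernel comments above: each sequence gets 2*T + 3 states (alternating blank/label states, plus a dummy state and an end state) and 5*(T - 1) + 10 edges, and the edges array stores from-state, to-state, emission index and batch index as its four rows. As a quick orientation, here is a small, purely illustrative Python sketch (not part of the op; the helper name is made up for this example) that reproduces the sizing logic asserted in c_fw_code:

    def ctc_fsa_output_shapes(n_batch, n_time):
        """Shapes of the 'edges' and 'start_end_states' outputs (illustrative helper only)."""
        # c_fw_code asserts: n_edges == n_batch * (5 * (n_time - 1) + 10),
        # where n_edges is taken from the length of the 'weights' input.
        n_edges = n_batch * (5 * (n_time - 1) + 10)
        # Kernel comment: (n_time * 2 + 3) states per sequence
        # (alternating blank/label states, plus a dummy and an end state).
        n_states = n_batch * (2 * n_time + 3)
        edges_shape = (4, n_edges)        # rows: from-state, to-state, emission idx, batch idx
        start_end_shape = (2, n_batch)    # rows: start state, end state (per batch entry)
        return edges_shape, start_end_shape, n_states

    print(ctc_fsa_output_shapes(n_batch=2, n_time=7))   # ((4, 80), (2, 2), 34)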
- class returnn.native_op.EditDistanceOp[source]#
Similar to tf.edit_distance(). Calculates the edit distance / Levenshtein distance.
The naive implementation iterates over a and then over b, which results in O(|a|*|b|) time complexity. To calculate a new entry of the table (over the lengths of a and b), it depends on the previous symbol in a (left; deletion error), the previous symbol in b (up; insertion error), and the left-up diagonal entry (substitution error, or no error).
To take advantage of the parallelism of the GPU, we follow a diagonal iteration scheme, such that in every iteration, all entries on the diagonal can be computed in parallel, as they do not depend on each other (see the sketch after the parameter list below). After implementing this, we found that this algorithm is described here:
Using GPUs to Speed-Up Levenshtein Edit Distance Computation, Balhaf et al., 2016, https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=7476090&tag=1
- inputs:
- param a:
symbols. 2d (batch,time), int32
- param a_len:
1d (batch,), int32
- param b:
symbols. 2d (batch,time), int32
- param b_len:
1d (batch,), int32
- outputs:
- param output:
1d (batch,), int32, unnormalized edit distance
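To make the diagonal iteration scheme concrete, here is a minimal pure-Python/NumPy sketch (illustrative only, not the native op itself): it fills the usual DP table, but one anti-diagonal at a time, and every entry on a diagonal only reads the two previous diagonals, which is what allows the CUDA kernel to compute all entries of a diagonal in parallel.

    import numpy as np

    def edit_distance_diagonal(a, b):
        """Levenshtein distance, filled anti-diagonal by anti-diagonal (illustrative sketch)."""
        n, m = len(a), len(b)
        D = np.zeros((n + 1, m + 1), dtype=np.int32)
        for diag in range(n + m + 1):              # entries with i + j == diag
            for i in range(max(0, diag - m), min(n, diag) + 1):
                j = diag - i
                if i == 0:
                    D[i, j] = j                    # delete all of b[:j]
                elif j == 0:
                    D[i, j] = i                    # delete all of a[:i]
                else:
                    sub = D[i - 1, j - 1] + int(a[i - 1] != b[j - 1])  # diag - 2
                    dele = D[i - 1, j] + 1                             # diag - 1
                    ins = D[i, j - 1] + 1                              # diag - 1
                    D[i, j] = min(sub, dele, ins)
        return int(D[n, m])

    assert edit_distance_diagonal("kitten", "sitting") == 3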
- in_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'a_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b', 'ndim': 2, 'need_contiguous': True, 'shape': (None, None)}, {'dtype': 'int32', 'gradient': 'disconnected', 'name': 'b_len', 'ndim': 1, 'need_contiguous': True, 'shape': (None,)})[source]#
- out_info: Tuple[Dict[str]] = ({'dtype': 'int32', 'name': 'output', 'ndim': 1, 'need_contiguous': True, 'shape': ((0, 0),)},)[source]#
- c_extra_support_code: Dict[str, str] = {'001_next_step': '\n DEF_KERNEL\n void next_step_kernel(\n int n_batch, int n_a_max_len, int n_b_max_len,\n int diag_idx,\n const int32_t* a, const int32_t* b,\n const int32_t* a_len, const int32_t* b_len,\n const int32_t* last1_dist, const int32_t* last2_dist, int32_t* cur_dist,\n int32_t* result) {\n int idx = threadIdx.x + blockDim.x * blockIdx.x;\n // We are going diagonal!\n int num_entries;\n if(diag_idx <= n_a_max_len) {\n num_entries = diag_idx + 1;\n if(num_entries > n_b_max_len + 1)\n num_entries = n_b_max_len + 1;\n } else {\n num_entries = n_b_max_len + 1 - (diag_idx - n_a_max_len);\n if(num_entries > n_a_max_len + 1)\n num_entries = n_a_max_len + 1;\n }\n int max_num_entries = n_a_max_len + 1;\n if(max_num_entries > n_b_max_len + 1)\n max_num_entries = n_b_max_len + 1;\n while(idx < n_batch * num_entries) {\n int batch_idx = idx / num_entries;\n int entry_idx = idx % num_entries;\n int dist_idx = batch_idx * max_num_entries + entry_idx;\n\n int t_a, t_b;\n if(diag_idx <= n_a_max_len) {\n t_a = diag_idx - entry_idx;\n t_b = entry_idx;\n } else {\n t_a = n_a_max_len - entry_idx;\n t_b = diag_idx - n_a_max_len + entry_idx;\n }\n\n if(t_a == 0)\n cur_dist[dist_idx] = t_b; // distance == how much to delete from b\n else if(t_b == 0)\n cur_dist[dist_idx] = t_a; // distance == how much to delete from a\n else {\n // last1 is with diag_idx - 2. Needed for substitution cost.\n // last2 is with diag_idx - 1. Needed for insertion or deletion cost.\n // last2 refers to the first, for deletion. last2_idx + 1 is for insertion.\n int last1_idx, last2_idx;\n if(diag_idx - 1 < n_a_max_len)\n last1_idx = dist_idx - 1;\n else if(diag_idx - 1 == n_a_max_len)\n last1_idx = dist_idx;\n else\n last1_idx = dist_idx + 1;\n if(diag_idx <= n_a_max_len)\n last2_idx = dist_idx - 1;\n else\n last2_idx = dist_idx;\n\n int del_cost, ins_cost, sub_cost;\n del_cost = last2_dist[last2_idx] + 1;\n ins_cost = last2_dist[last2_idx + 1] + 1;\n sub_cost = last1_dist[last1_idx];\n if(a[batch_idx * n_a_max_len + t_a - 1] != b[batch_idx * n_b_max_len + t_b - 1])\n ++sub_cost;\n //printf("t_a %i, t_b %i, del %i, ins %i, sub %i\\n", t_a, t_b, del_cost, ins_cost, sub_cost);\n int min_cost = del_cost;\n if(min_cost > ins_cost) min_cost = ins_cost;\n if(min_cost > sub_cost) min_cost = sub_cost;\n cur_dist[dist_idx] = min_cost;\n }\n //printf("t_a %i, t_b %i, dist %i\\n", t_a, t_b, cur_dist[dist_idx]);\n\n if(t_a == a_len[batch_idx] && t_b == b_len[batch_idx])\n result[batch_idx] = cur_dist[dist_idx];\n\n idx += gridDim.x * blockDim.x;\n }\n }\n '}[source]#
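The kernel above does not store the full DP table: per batch entry it keeps only three diagonals of at most min(|a|, |b|) + 1 entries (last1_dist for diag - 2, last2_dist for diag - 1, cur_dist for the current diagonal), which are rotated between iterations. The following illustrative Python sketch is a direct transcription of the index arithmetic in next_step_kernel (the function name is made up for this example); it shows how a diagonal index and an entry index map back to positions (t_a, t_b) in the two sequences:

    def diag_to_positions(diag_idx, n_a_max_len, n_b_max_len):
        """Yield (t_a, t_b) for every entry on the anti-diagonal t_a + t_b == diag_idx."""
        if diag_idx <= n_a_max_len:
            num_entries = min(diag_idx + 1, n_b_max_len + 1)
        else:
            num_entries = min(n_b_max_len + 1 - (diag_idx - n_a_max_len), n_a_max_len + 1)
        for entry_idx in range(num_entries):
            if diag_idx <= n_a_max_len:
                yield diag_idx - entry_idx, entry_idx
            else:
                yield n_a_max_len - entry_idx, diag_idx - n_a_max_len + entry_idx

    # For |a| = 3, |b| = 2 the diagonals cover each (t_a, t_b) with t_a <= 3, t_b <= 2 exactly once:
    for d in range(3 + 2 + 1):
        print(d, list(diag_to_positions(d, 3, 2)))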
- c_fw_code: str = '\n assert(n_inputs == 4);\n assert(n_outputs == 1);\n Ndarray* a = inputs[0];\n Ndarray* a_len = inputs[1];\n Ndarray* b = inputs[2];\n Ndarray* b_len = inputs[3];\n Ndarray* out = *outputs[0];\n assert_cmp(Ndarray_NDIM(a), ==, 2);\n assert_cmp(Ndarray_NDIM(a_len), ==, 1);\n assert_cmp(Ndarray_NDIM(b), ==, 2);\n assert_cmp(Ndarray_NDIM(b_len), ==, 1);\n assert_cmp(Ndarray_NDIM(out), ==, 1);\n int n_batch = Ndarray_DIMS(out)[0];\n assert_cmp(Ndarray_DIMS(a)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(a_len)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(b)[0], ==, n_batch);\n assert_cmp(Ndarray_DIMS(b_len)[0], ==, n_batch);\n int n_a_max_len = Ndarray_DIMS(a)[1];\n int n_b_max_len = Ndarray_DIMS(b)[1];\n Ndarray_memset(Ndarray_DEV_DATA_int32(out), 255, n_batch * sizeof(int32_t));\n\n // Working buffer.\n int max_num_entries = std::min(n_a_max_len +