/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "helpers.h"
#include "helpers_asymm.h"

#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#define ARM_DOT(x, y, val) val = arm_dot_acc((x), (y), (val));
#else // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#define ARM_DOT(x, y, val) val += arm_dot((x), (y));
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
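
// Usage sketch for ARM_DOT (illustrative only): both variants compute a 4-way
// 8-bit dot product and accumulate the result into val, e.g.
//
//   uint acc = 0;
//   ARM_DOT((uchar4)(1, 2, 3, 4), (uchar4)(5, 6, 7, 8), acc);
//   // acc == 1*5 + 2*6 + 3*7 + 4*8 == 70
//
// The two definitions differ only in whether the accumulation happens inside
// the hardware instruction (arm_dot_acc) or as a separate add after arm_dot.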

#if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
 *
 * @note The number of matrix B columns needs to be passed at compile time using -DCOLS_B: e.g. -DCOLS_B=1024
 * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: S32
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom padding in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
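// Example build options for this kernel (hypothetical values, shown for
// illustration only; the real values are chosen by the host-side kernel
// configuration):
//
//   -DCOLS_B=1024 -DMULT_INTERLEAVE4X4_HEIGHT=1 -DTRANSPOSE1XW_WIDTH_STEP=4
//
// With these defines each work-item accumulates one 4x4 block of the S32
// output from the reshaped QASYMM8 inputs.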
__kernel void gemmlowp_mm_interleaved_transposed_midgard(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
    const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
    __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src_addr_b += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)
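
    // Illustrative example (hypothetical shapes): if matrix A spans 12 planes
    // along z (e.g. 3 slices x 4 batches) while matrix B only has
    // MATRIX_B_DEPTH == 3 planes, z % 3 selects the B plane, so B's planes are
    // reused cyclically across the batches of A instead of sliding past them.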

    // Compute end row address for matrix B
    __global uchar *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    int4 c00 = 0;
    int4 c10 = 0;
    int4 c20 = 0;
    int4 c30 = 0;

    for(; src_addr_b <= (src_end_addr_b - (int)(8 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * TRANSPOSE1XW_WIDTH_STEP)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        int4 a0 = convert_int4(vload4(0, src_addr_a));
        int4 b0 = convert_int4(vload4(0, src_addr_b));

        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;

        a0 = convert_int4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b0 = convert_int4(vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP));

        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;
    }

    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        int4 a0 = convert_int4(vload4(0, src_addr_a));
        int4 b0 = convert_int4(vload4(0, src_addr_b));

        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);
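
    // Worked example (illustrative): with HEIGHT_GEMM3D == 4 and
    // get_global_id(1) == 1, this tile covers output rows 4..7, so
    // zout = (4, 5, 6, 7) / 4 = (1, 1, 1, 1): every row lies on plane 1 and is
    // therefore shifted by one cross_plane_pad * dst_stride_y worth of bytes.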

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

/** This OpenCL kernel is optimized for Bifrost and computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
 *
 * @attention The number of matrix B columns needs to be passed at compile time using -DCOLS_B
 * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: S32
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom padding in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
    const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
    __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src_addr_b += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Compute end row address for matrix B
    __global uchar *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    uint c00 = 0;
    uint c01 = 0;
    uint c02 = 0;
    uint c03 = 0;
    uint c10 = 0;
    uint c11 = 0;
    uint c12 = 0;
    uint c13 = 0;
    uint c20 = 0;
    uint c21 = 0;
    uint c22 = 0;
    uint c23 = 0;
    uint c30 = 0;
    uint c31 = 0;
    uint c32 = 0;
    uint c33 = 0;

#if MULT_INTERLEAVE4X4_HEIGHT == 1
    for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        uchar16 a0 = vload16(0, src_addr_a);
        uchar4  b0 = vload4(0, src_addr_b);

        c00 += (ushort)a0.s0 * b0.s0;
        c01 += (ushort)a0.s0 * b0.s1;
        c02 += (ushort)a0.s0 * b0.s2;
        c03 += (ushort)a0.s0 * b0.s3;

        c10 += (ushort)a0.s1 * b0.s0;
        c11 += (ushort)a0.s1 * b0.s1;
        c12 += (ushort)a0.s1 * b0.s2;
        c13 += (ushort)a0.s1 * b0.s3;

        c20 += (ushort)a0.s2 * b0.s0;
        c21 += (ushort)a0.s2 * b0.s1;
        c22 += (ushort)a0.s2 * b0.s2;
        c23 += (ushort)a0.s2 * b0.s3;

        c30 += (ushort)a0.s3 * b0.s0;
        c31 += (ushort)a0.s3 * b0.s1;
        c32 += (ushort)a0.s3 * b0.s2;
        c33 += (ushort)a0.s3 * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s4 * b0.s0;
        c01 += (ushort)a0.s4 * b0.s1;
        c02 += (ushort)a0.s4 * b0.s2;
        c03 += (ushort)a0.s4 * b0.s3;

        c10 += (ushort)a0.s5 * b0.s0;
        c11 += (ushort)a0.s5 * b0.s1;
        c12 += (ushort)a0.s5 * b0.s2;
        c13 += (ushort)a0.s5 * b0.s3;

        c20 += (ushort)a0.s6 * b0.s0;
        c21 += (ushort)a0.s6 * b0.s1;
        c22 += (ushort)a0.s6 * b0.s2;
        c23 += (ushort)a0.s6 * b0.s3;

        c30 += (ushort)a0.s7 * b0.s0;
        c31 += (ushort)a0.s7 * b0.s1;
        c32 += (ushort)a0.s7 * b0.s2;
        c33 += (ushort)a0.s7 * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s8 * b0.s0;
        c01 += (ushort)a0.s8 * b0.s1;
        c02 += (ushort)a0.s8 * b0.s2;
        c03 += (ushort)a0.s8 * b0.s3;

        c10 += (ushort)a0.s9 * b0.s0;
        c11 += (ushort)a0.s9 * b0.s1;
        c12 += (ushort)a0.s9 * b0.s2;
        c13 += (ushort)a0.s9 * b0.s3;

        c20 += (ushort)a0.sA * b0.s0;
        c21 += (ushort)a0.sA * b0.s1;
        c22 += (ushort)a0.sA * b0.s2;
        c23 += (ushort)a0.sA * b0.s3;

        c30 += (ushort)a0.sB * b0.s0;
        c31 += (ushort)a0.sB * b0.s1;
        c32 += (ushort)a0.sB * b0.s2;
        c33 += (ushort)a0.sB * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.sC * b0.s0;
        c01 += (ushort)a0.sC * b0.s1;
        c02 += (ushort)a0.sC * b0.s2;
        c03 += (ushort)a0.sC * b0.s3;

        c10 += (ushort)a0.sD * b0.s0;
        c11 += (ushort)a0.sD * b0.s1;
        c12 += (ushort)a0.sD * b0.s2;
        c13 += (ushort)a0.sD * b0.s3;

        c20 += (ushort)a0.sE * b0.s0;
        c21 += (ushort)a0.sE * b0.s1;
        c22 += (ushort)a0.sE * b0.s2;
        c23 += (ushort)a0.sE * b0.s3;

        c30 += (ushort)a0.sF * b0.s0;
        c31 += (ushort)a0.sF * b0.s1;
        c32 += (ushort)a0.sF * b0.s2;
        c33 += (ushort)a0.sF * b0.s3;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload16(0, src_addr_a + 16);
        b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s0 * b0.s0;
        c01 += (ushort)a0.s0 * b0.s1;
        c02 += (ushort)a0.s0 * b0.s2;
        c03 += (ushort)a0.s0 * b0.s3;

        c10 += (ushort)a0.s1 * b0.s0;
        c11 += (ushort)a0.s1 * b0.s1;
        c12 += (ushort)a0.s1 * b0.s2;
        c13 += (ushort)a0.s1 * b0.s3;

        c20 += (ushort)a0.s2 * b0.s0;
        c21 += (ushort)a0.s2 * b0.s1;
        c22 += (ushort)a0.s2 * b0.s2;
        c23 += (ushort)a0.s2 * b0.s3;

        c30 += (ushort)a0.s3 * b0.s0;
        c31 += (ushort)a0.s3 * b0.s1;
        c32 += (ushort)a0.s3 * b0.s2;
        c33 += (ushort)a0.s3 * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s4 * b0.s0;
        c01 += (ushort)a0.s4 * b0.s1;
        c02 += (ushort)a0.s4 * b0.s2;
        c03 += (ushort)a0.s4 * b0.s3;

        c10 += (ushort)a0.s5 * b0.s0;
        c11 += (ushort)a0.s5 * b0.s1;
        c12 += (ushort)a0.s5 * b0.s2;
        c13 += (ushort)a0.s5 * b0.s3;

        c20 += (ushort)a0.s6 * b0.s0;
        c21 += (ushort)a0.s6 * b0.s1;
        c22 += (ushort)a0.s6 * b0.s2;
        c23 += (ushort)a0.s6 * b0.s3;

        c30 += (ushort)a0.s7 * b0.s0;
        c31 += (ushort)a0.s7 * b0.s1;
        c32 += (ushort)a0.s7 * b0.s2;
        c33 += (ushort)a0.s7 * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s8 * b0.s0;
        c01 += (ushort)a0.s8 * b0.s1;
        c02 += (ushort)a0.s8 * b0.s2;
        c03 += (ushort)a0.s8 * b0.s3;

        c10 += (ushort)a0.s9 * b0.s0;
        c11 += (ushort)a0.s9 * b0.s1;
        c12 += (ushort)a0.s9 * b0.s2;
        c13 += (ushort)a0.s9 * b0.s3;

        c20 += (ushort)a0.sA * b0.s0;
        c21 += (ushort)a0.sA * b0.s1;
        c22 += (ushort)a0.sA * b0.s2;
        c23 += (ushort)a0.sA * b0.s3;

        c30 += (ushort)a0.sB * b0.s0;
        c31 += (ushort)a0.sB * b0.s1;
        c32 += (ushort)a0.sB * b0.s2;
        c33 += (ushort)a0.sB * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.sC * b0.s0;
        c01 += (ushort)a0.sC * b0.s1;
        c02 += (ushort)a0.sC * b0.s2;
        c03 += (ushort)a0.sC * b0.s3;

        c10 += (ushort)a0.sD * b0.s0;
        c11 += (ushort)a0.sD * b0.s1;
        c12 += (ushort)a0.sD * b0.s2;
        c13 += (ushort)a0.sD * b0.s3;

        c20 += (ushort)a0.sE * b0.s0;
        c21 += (ushort)a0.sE * b0.s1;
        c22 += (ushort)a0.sE * b0.s2;
        c23 += (ushort)a0.sE * b0.s3;

        c30 += (ushort)a0.sF * b0.s0;
        c31 += (ushort)a0.sF * b0.s1;
        c32 += (ushort)a0.sF * b0.s2;
        c33 += (ushort)a0.sF * b0.s3;
    }
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1

    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        uchar4 a0 = vload4(0, src_addr_a);
        uchar4 b0 = vload4(0, src_addr_b);

        c00 += (ushort)a0.s0 * b0.s0;
        c01 += (ushort)a0.s0 * b0.s1;
        c02 += (ushort)a0.s0 * b0.s2;
        c03 += (ushort)a0.s0 * b0.s3;

        c10 += (ushort)a0.s1 * b0.s0;
        c11 += (ushort)a0.s1 * b0.s1;
        c12 += (ushort)a0.s1 * b0.s2;
        c13 += (ushort)a0.s1 * b0.s3;

        c20 += (ushort)a0.s2 * b0.s0;
        c21 += (ushort)a0.s2 * b0.s1;
        c22 += (ushort)a0.s2 * b0.s2;
        c23 += (ushort)a0.s2 * b0.s3;

        c30 += (ushort)a0.s3 * b0.s0;
        c31 += (ushort)a0.s3 * b0.s1;
        c32 += (ushort)a0.s3 * b0.s2;
        c33 += (ushort)a0.s3 * b0.s3;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += z * dst_stride_z;

    // Store 4x4 block
    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
/** This OpenCL kernel is optimized for Bifrost and computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
 *
 * @attention The number of matrix B columns needs to be passed at compile time using -DCOLS_B
 * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: S32
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom padding in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(src0),
                                                              IMAGE_DECLARATION(src1),
                                                              IMAGE_DECLARATION(dst),
                                                              uint src0_stride_z,
                                                              uint src1_stride_z,
                                                              uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                              ,
                                                              uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                             )
{
    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    __global uchar *src_addr_a = (__global uchar *)(src0_ptr + (get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT) * src0_stride_y + get_global_id(2) * src0_stride_z + src0_offset_first_element_in_bytes);
    __global uchar *src_addr_b = (__global uchar *)(src1_ptr + (get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP) * src1_stride_y + src1_offset_first_element_in_bytes);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr_b += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src_addr_b += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    uint c00 = 0;
    uint c01 = 0;
    uint c02 = 0;
    uint c03 = 0;

    uint c10 = 0;
    uint c11 = 0;
    uint c12 = 0;
    uint c13 = 0;

    uint c20 = 0;
    uint c21 = 0;
    uint c22 = 0;
    uint c23 = 0;

    uint c30 = 0;
    uint c31 = 0;
    uint c32 = 0;
    uint c33 = 0;

#define COLS_MTX_B (COLS_B / (16 * MULT_TRANSPOSE1XW_WIDTH))
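// Example (hypothetical values, illustration only): with -DCOLS_B=512 and
// -DMULT_TRANSPOSE1XW_WIDTH=1, COLS_MTX_B == 512 / 16 == 32, which can be read
// as the number of 4-element column steps each work-item takes through the
// reshaped matrix B (the unrolled loop below consumes 8 such steps per
// iteration).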

    int i = 0;
#if MULT_INTERLEAVE4X4_HEIGHT == 1
    for(; i <= (int)(COLS_MTX_B - 8); i += 8)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        uchar16 a0 = vload16(0, src_addr_a);
        uchar4  b0 = vload4(0, src_addr_b);
        uchar4  b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4  b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4  b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4  b4 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4  b5 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4  b6 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4  b7 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);

        // Accumulate
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);

        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);

        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);

        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);

        // Accumulate
        a0 = vload16(0, src_addr_a + 16);

        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c00);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c01);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c02);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c03);

        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c10);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c11);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c12);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c13);

        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c20);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c21);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c22);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c23);

        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c30);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c31);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c32);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c33);

        src_addr_a += 32;
        src_addr_b += 32 * TRANSPOSE1XW_WIDTH_STEP;
    }
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    int i_left_over = 0;
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed).
        // Within a 4x4 interleaved block of A, rows are 4 bytes apart, so step
        // one column (one k value) at a time and pick the 4 rows from
        // a0.s0/s4/s8/sC.
        uchar16 a0 = vload16(0, src_addr_a + (i_left_over % 4) + ((i_left_over / 4) * 16));
        uchar4  b0 = vload4(0, src_addr_b);

        c00 += a0.s0 * b0.s0;
        c01 += a0.s0 * b0.s1;
        c02 += a0.s0 * b0.s2;
        c03 += a0.s0 * b0.s3;

        c10 += a0.s4 * b0.s0;
        c11 += a0.s4 * b0.s1;
        c12 += a0.s4 * b0.s2;
        c13 += a0.s4 * b0.s3;

        c20 += a0.s8 * b0.s0;
        c21 += a0.s8 * b0.s1;
        c22 += a0.s8 * b0.s2;
        c23 += a0.s8 * b0.s3;

        c30 += a0.sC * b0.s0;
        c31 += a0.sC * b0.s1;
        c32 += a0.sC * b0.s2;
        c33 += a0.sC * b0.s3;

        i_left_over++;
        src_addr_b += 4 * TRANSPOSE1XW_WIDTH_STEP;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += get_global_id(2) * dst_stride_z;

    // Store 4x4 block
    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

#endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)

#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
#define VECTOR_UCHAR VEC_DATA_TYPE(uchar, NUM_ELEMS_PROCESSED_PER_THREAD_X)
#define VECTOR_UINT VEC_DATA_TYPE(uint, NUM_ELEMS_PROCESSED_PER_THREAD_X)
#define VECTOR_INT VEC_DATA_TYPE(int, NUM_ELEMS_PROCESSED_PER_THREAD_X)
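// For example (illustrative): building with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4
// makes these expand to uchar4, uint4 and int4 respectively, so each work-item
// loads, accumulates and stores four output columns at a time.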
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: S32
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom padding in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom padding in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
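// Example build options for this kernel (hypothetical values, illustration only):
//
//   -DCOLS_A=256 -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4 -DNUM_ELEMS_PROCESSED_PER_THREAD_Y=4
//
// Each work-item would then compute a 4x4 tile of the output, so an M x N
// result needs roughly ceil(N / 4) x ceil(M / 4) work-items, plus the z
// dimension for batching.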
__kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0),
                                  IMAGE_DECLARATION(src1),
                                  IMAGE_DECLARATION(dst),
                                  uint src0_stride_z,
                                  uint src1_stride_z,
                                  uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                  ,
                                  uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                  ,
                                  uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                 )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    int end_row_vec_a = src_addr.s0 + COLS_A;

    VECTOR_UINT acc0 = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_UINT acc1 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_UINT acc2 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_UINT acc3 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    VECTOR_UINT acc4 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4

    for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
    {
        // Load values from matrix A
        uchar2 a0 = vload2(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        uchar2 a1 = vload2(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        uchar2 a2 = vload2(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        uchar2 a3 = vload2(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        uchar2 a4 = vload2(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        // Load values from matrix B
        VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
        VECTOR_UCHAR b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1 + src1_stride_y);

        // Accumulate
        acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0.s0;
        acc0 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1.s0;
        acc1 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2.s0;
        acc2 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3.s0;
        acc3 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4.s0;
        acc4 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a4.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    }

    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
    {
        // Load values from matrix A
        uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        // Load values from matrix B
        VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);

        // Accumulate
        acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    }

    const int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the result
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += z * dst_stride_z;

    // Store the result
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
1066
1067/** OpenCL kernel optimized for Bifrost architectures that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
1068 *
1069 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
1070 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001071 * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1072 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1073 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1074 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1075 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1076 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
1077 *
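 * @note Example build options (illustrative values only, chosen here for the sake of the example):
 *       -DCOLS_A=64 -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4 -DNUM_ELEMS_PROCESSED_PER_THREAD_Y=4
 *       With these options each work-item accumulates a tile of 4 columns by 4 rows of 32-bit results.
 *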
Gian Marco7b4d5472018-01-10 15:56:30 +00001078 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
1079 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1080 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1081 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1082 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1083 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1084 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
1085 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1086 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1087 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1088 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1089 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1090 * @param[out] dst_ptr Pointer to the destination matrix. Supported data type: S32
1091 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1092 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1093 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1094 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1095 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001096 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1097 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1098 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1099 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in units of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1100 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in units of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco7b4d5472018-01-10 15:56:30 +00001101 */
1102__kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
1103 IMAGE_DECLARATION(src1),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001104 IMAGE_DECLARATION(dst),
1105 uint src0_stride_z,
1106 uint src1_stride_z,
1107 uint dst_stride_z
1108#if defined(REINTERPRET_INPUT_AS_3D)
1109 ,
1110 uint src_cross_plane_pad
1111#endif // REINTERPRET_INPUT_AS_3D
1112#if defined(REINTERPRET_OUTPUT_AS_3D)
1113 ,
1114 uint dst_cross_plane_pad
1115#endif // REINTERPRET_OUTPUT_AS_3D
1116 )
Gian Marco7b4d5472018-01-10 15:56:30 +00001117{
1118 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1119
1120 // Compute starting address for matrix A and matrix B
1121 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1122
1123 // Update address for the matrix A
1124 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1125
1126 // Update address for the matrix B
1127 src_addr.s1 += idx;
1128
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001129#if defined(REINTERPRET_INPUT_AS_3D)
1130 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1131 // in order to take into account the presence of possible cross plane paddings
1132 //
1133 // | |
1134 // | plane0 |
1135 // | |
1136 // |__________________|
1137 // |******************|
1138 // | cross_plane_pad |
1139 // |******************|
1140 // | |
1141 // | plane1 |
1142 // | |
1143 // |__________________|
1144
1145 // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1146 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1147 zin = min(DEPTH_GEMM3D - 1, zin);
1148
1149 // Add offset due to the cross plane paddings
1150 zin *= (src_cross_plane_pad * src0_stride_y);
1151
1152 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1153 // multiply src0_stride_z by DEPTH_GEMM3D
1154 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1155
1156#else // defined(REINTERPRET_INPUT_AS_3D)
1157
1158 // Add offset for batched GEMM
1159 src_addr.s0 += get_global_id(2) * src0_stride_z;
1160
1161#endif // defined(REINTERPRET_INPUT_AS_3D)
1162
1163#if defined(MATRIX_B_DEPTH)
1164 // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
1165 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1166#else // defined(MATRIX_B_DEPTH)
1167 src_addr.s1 += get_global_id(2) * src1_stride_z;
1168#endif // defined(MATRIX_B_DEPTH)
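    // Illustrative example (values assumed, not from any build): with MATRIX_B_DEPTH=2 and
    // get_global_id(2)=5, matrix B is read from slice 5 % 2 = 1, i.e. the slices of matrix B
    // are reused cyclically across the batches of matrix A.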
1169
Gian Marco7b4d5472018-01-10 15:56:30 +00001170 int end_row_vec_a = src_addr.s0 + COLS_A;
1171
1172 uint acc00 = 0;
1173 uint acc01 = 0;
1174 uint acc02 = 0;
1175 uint acc03 = 0;
1176#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1177 uint acc10 = 0;
1178 uint acc11 = 0;
1179 uint acc12 = 0;
1180 uint acc13 = 0;
1181#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1182#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1183 uint acc20 = 0;
1184 uint acc21 = 0;
1185 uint acc22 = 0;
1186 uint acc23 = 0;
1187#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1188#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1189 uint acc30 = 0;
1190 uint acc31 = 0;
1191 uint acc32 = 0;
1192 uint acc33 = 0;
1193#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1194#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1195 uint acc40 = 0;
1196 uint acc41 = 0;
1197 uint acc42 = 0;
1198 uint acc43 = 0;
1199#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1200
1201 for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
1202 {
1203 // Load values from matrix A
1204 uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
1205#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1206 uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
1207#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1208#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1209 uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
1210#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1211#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1212 uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
1213#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1214#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1215 uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
1216#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1217 // Load values from matrix B
1218 uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
1219 uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
1220 uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
1221 uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
1222
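    // The multiplications below are done in 16-bit arithmetic: the product of two 8-bit
    // values is at most 255 * 255 = 65025, which fits in a ushort, and the four partial
    // products are widened to 32-bit only when added into the accumulators.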
1223 {
1224 // Accumulate
1225 ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
1226 ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
1227 ushort tmp2 = (ushort)b0.s2 * (ushort)a0.s0;
1228 ushort tmp3 = (ushort)b0.s3 * (ushort)a0.s0;
1229
1230 ushort tmp4 = (ushort)b1.s0 * (ushort)a0.s1;
1231 ushort tmp5 = (ushort)b1.s1 * (ushort)a0.s1;
1232 ushort tmp6 = (ushort)b1.s2 * (ushort)a0.s1;
1233 ushort tmp7 = (ushort)b1.s3 * (ushort)a0.s1;
1234
1235 ushort tmp8 = (ushort)b2.s0 * (ushort)a0.s2;
1236 ushort tmp9 = (ushort)b2.s1 * (ushort)a0.s2;
1237 ushort tmpA = (ushort)b2.s2 * (ushort)a0.s2;
1238 ushort tmpB = (ushort)b2.s3 * (ushort)a0.s2;
1239
1240 ushort tmpC = (ushort)b3.s0 * (ushort)a0.s3;
1241 ushort tmpD = (ushort)b3.s1 * (ushort)a0.s3;
1242 ushort tmpE = (ushort)b3.s2 * (ushort)a0.s3;
1243 ushort tmpF = (ushort)b3.s3 * (ushort)a0.s3;
1244
1245 acc00 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1246 acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1247 acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1248 acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1249 }
1250#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1251 {
1252 // Accumulate
1253 ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
1254 ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
1255 ushort tmp2 = (ushort)b0.s2 * (ushort)a1.s0;
1256 ushort tmp3 = (ushort)b0.s3 * (ushort)a1.s0;
1257
1258 ushort tmp4 = (ushort)b1.s0 * (ushort)a1.s1;
1259 ushort tmp5 = (ushort)b1.s1 * (ushort)a1.s1;
1260 ushort tmp6 = (ushort)b1.s2 * (ushort)a1.s1;
1261 ushort tmp7 = (ushort)b1.s3 * (ushort)a1.s1;
1262
1263 ushort tmp8 = (ushort)b2.s0 * (ushort)a1.s2;
1264 ushort tmp9 = (ushort)b2.s1 * (ushort)a1.s2;
1265 ushort tmpA = (ushort)b2.s2 * (ushort)a1.s2;
1266 ushort tmpB = (ushort)b2.s3 * (ushort)a1.s2;
1267
1268 ushort tmpC = (ushort)b3.s0 * (ushort)a1.s3;
1269 ushort tmpD = (ushort)b3.s1 * (ushort)a1.s3;
1270 ushort tmpE = (ushort)b3.s2 * (ushort)a1.s3;
1271 ushort tmpF = (ushort)b3.s3 * (ushort)a1.s3;
1272
1273 acc10 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1274 acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1275 acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1276 acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1277 }
1278#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1279#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1280 {
1281 // Accumulate
1282 ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
1283 ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
1284 ushort tmp2 = (ushort)b0.s2 * (ushort)a2.s0;
1285 ushort tmp3 = (ushort)b0.s3 * (ushort)a2.s0;
1286
1287 ushort tmp4 = (ushort)b1.s0 * (ushort)a2.s1;
1288 ushort tmp5 = (ushort)b1.s1 * (ushort)a2.s1;
1289 ushort tmp6 = (ushort)b1.s2 * (ushort)a2.s1;
1290 ushort tmp7 = (ushort)b1.s3 * (ushort)a2.s1;
1291
1292 ushort tmp8 = (ushort)b2.s0 * (ushort)a2.s2;
1293 ushort tmp9 = (ushort)b2.s1 * (ushort)a2.s2;
1294 ushort tmpA = (ushort)b2.s2 * (ushort)a2.s2;
1295 ushort tmpB = (ushort)b2.s3 * (ushort)a2.s2;
1296
1297 ushort tmpC = (ushort)b3.s0 * (ushort)a2.s3;
1298 ushort tmpD = (ushort)b3.s1 * (ushort)a2.s3;
1299 ushort tmpE = (ushort)b3.s2 * (ushort)a2.s3;
1300 ushort tmpF = (ushort)b3.s3 * (ushort)a2.s3;
1301
1302 acc20 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1303 acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1304 acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1305 acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1306 }
1307#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1308#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1309 {
1310 // Accumulate
1311 ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
1312 ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
1313 ushort tmp2 = (ushort)b0.s2 * (ushort)a3.s0;
1314 ushort tmp3 = (ushort)b0.s3 * (ushort)a3.s0;
1315
1316 ushort tmp4 = (ushort)b1.s0 * (ushort)a3.s1;
1317 ushort tmp5 = (ushort)b1.s1 * (ushort)a3.s1;
1318 ushort tmp6 = (ushort)b1.s2 * (ushort)a3.s1;
1319 ushort tmp7 = (ushort)b1.s3 * (ushort)a3.s1;
1320
1321 ushort tmp8 = (ushort)b2.s0 * (ushort)a3.s2;
1322 ushort tmp9 = (ushort)b2.s1 * (ushort)a3.s2;
1323 ushort tmpA = (ushort)b2.s2 * (ushort)a3.s2;
1324 ushort tmpB = (ushort)b2.s3 * (ushort)a3.s2;
1325
1326 ushort tmpC = (ushort)b3.s0 * (ushort)a3.s3;
1327 ushort tmpD = (ushort)b3.s1 * (ushort)a3.s3;
1328 ushort tmpE = (ushort)b3.s2 * (ushort)a3.s3;
1329 ushort tmpF = (ushort)b3.s3 * (ushort)a3.s3;
1330
1331 acc30 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1332 acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1333 acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1334 acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1335 }
1336#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1337#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1338 {
1339 // Accumulate
1340 ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
1341 ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
1342 ushort tmp2 = (ushort)b0.s2 * (ushort)a4.s0;
1343 ushort tmp3 = (ushort)b0.s3 * (ushort)a4.s0;
1344
1345 ushort tmp4 = (ushort)b1.s0 * (ushort)a4.s1;
1346 ushort tmp5 = (ushort)b1.s1 * (ushort)a4.s1;
1347 ushort tmp6 = (ushort)b1.s2 * (ushort)a4.s1;
1348 ushort tmp7 = (ushort)b1.s3 * (ushort)a4.s1;
1349
1350 ushort tmp8 = (ushort)b2.s0 * (ushort)a4.s2;
1351 ushort tmp9 = (ushort)b2.s1 * (ushort)a4.s2;
1352 ushort tmpA = (ushort)b2.s2 * (ushort)a4.s2;
1353 ushort tmpB = (ushort)b2.s3 * (ushort)a4.s2;
1354
1355 ushort tmpC = (ushort)b3.s0 * (ushort)a4.s3;
1356 ushort tmpD = (ushort)b3.s1 * (ushort)a4.s3;
1357 ushort tmpE = (ushort)b3.s2 * (ushort)a4.s3;
1358 ushort tmpF = (ushort)b3.s3 * (ushort)a4.s3;
1359
1360 acc40 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1361 acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1362 acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1363 acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1364 }
1365#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1366 }
1367
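    // Process the leftover columns of matrix A (when COLS_A is not a multiple of 4) one
    // element at a time.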
1368 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
1369 {
1370 // Load values from matrix A
1371 uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
1372#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1373 uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
1374#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1375#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1376 uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
1377#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1378#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1379 uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
1380#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1381#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1382 uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
1383#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1384 // Load values from matrix B
1385 uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);
1386
1387 // Accumulate
1388 {
1389 // Accumulate
1390 ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
1391 ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
1392 ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
1393 ushort tmp3 = (ushort)b0.s3 * (ushort)a0;
1394
1395 acc00 += ((uint)tmp0);
1396 acc01 += ((uint)tmp1);
1397 acc02 += ((uint)tmp2);
1398 acc03 += ((uint)tmp3);
1399 }
1400#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1401 {
1402 // Accumulate
1403 ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
1404 ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
1405 ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
1406 ushort tmp3 = (ushort)b0.s3 * (ushort)a1;
1407
1408 acc10 += ((uint)tmp0);
1409 acc11 += ((uint)tmp1);
1410 acc12 += ((uint)tmp2);
1411 acc13 += ((uint)tmp3);
1412 }
1413#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1414#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1415 {
1416 // Accumulate
1417 ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
1418 ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
1419 ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
1420 ushort tmp3 = (ushort)b0.s3 * (ushort)a2;
1421
1422 acc20 += ((uint)tmp0);
1423 acc21 += ((uint)tmp1);
1424 acc22 += ((uint)tmp2);
1425 acc23 += ((uint)tmp3);
1426 }
1427#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1428#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1429 {
1430 // Accumulate
1431 ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
1432 ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
1433 ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
1434 ushort tmp3 = (ushort)b0.s3 * (ushort)a3;
1435
1436 acc30 += ((uint)tmp0);
1437 acc31 += ((uint)tmp1);
1438 acc32 += ((uint)tmp2);
1439 acc33 += ((uint)tmp3);
1440 }
1441#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1442#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1443 {
1444 // Accumulate
1445 ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
1446 ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
1447 ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
1448 ushort tmp3 = (ushort)b0.s3 * (ushort)a4;
1449
1450 acc40 += ((uint)tmp0);
1451 acc41 += ((uint)tmp1);
1452 acc42 += ((uint)tmp2);
1453 acc43 += ((uint)tmp3);
1454 }
1455#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1456 }
1457
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001458 const int z = get_global_id(2);
1459
Gian Marco7b4d5472018-01-10 15:56:30 +00001460 // Compute destination address
1461 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1462
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001463#if defined(REINTERPRET_OUTPUT_AS_3D)
1464 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1465 // in order to take into account the presence of possible cross plane paddings
1466 //
1467 // | |
1468 // | plane0 |
1469 // | |
1470 // |__________________|
1471 // |******************|
1472 // | cross_plane_pad |
1473 // |******************|
1474 // | |
1475 // | plane1 |
1476 // | |
1477 // |__________________|
1478
1479 // The plane (zout) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1480 uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
1481 zout = min(DEPTH_GEMM3D - 1, zout);
1482
1483 // Add offset due to the cross plane paddings
1484 zout *= (dst_cross_plane_pad * dst_stride_y);
1485
1486 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1487 // multiply dst_stride_z by DEPTH_GEMM3D
1488 dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
1489
Gian Marco7b4d5472018-01-10 15:56:30 +00001490 // Store the result
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001491 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
Gian Marco7b4d5472018-01-10 15:56:30 +00001492#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001493 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
Gian Marco7b4d5472018-01-10 15:56:30 +00001494#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1495#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001496 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
Gian Marco7b4d5472018-01-10 15:56:30 +00001497#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1498#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001499 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
Gian Marco7b4d5472018-01-10 15:56:30 +00001500#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1501#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001502 vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
Gian Marco7b4d5472018-01-10 15:56:30 +00001503#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001504
1505#else // defined(REINTERPRET_OUTPUT_AS_3D)
1506 // Add offset for batched GEMM
1507 dst.ptr += z * dst_stride_z;
1508
1509 // Store the result
1510 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
1511#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1512 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
1513#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1514#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1515 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
1516#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1517#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1518 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
1519#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1520#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1521 vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
1522#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1523#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco05288a22017-11-21 10:57:50 +00001524}
Giorgio Arena6200fa42018-07-06 17:06:36 +01001525
Georgios Pinitasdaa38552018-08-28 17:43:18 +01001526#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001527/** OpenCL kernel optimized to use the dot product instruction that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1528 *
1529 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
1530 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001531 * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1532 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1533 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1534 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1535 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1536 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
1537 *
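 * @note This kernel requires the cl_arm_integer_dot_product_int8 OpenCL extension: each ARM_DOT(x, y, val)
 *       call accumulates into val the 4-element dot product x.s0 * y.s0 + x.s1 * y.s1 + x.s2 * y.s2 + x.s3 * y.s3.
 *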
Giorgio Arena6200fa42018-07-06 17:06:36 +01001538 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
1539 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1540 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1541 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1542 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1543 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1544 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
1545 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1546 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1547 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1548 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1549 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1550 * @param[out] dst_ptr Pointer to the destination matrix. Supported data type: S32
1551 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1552 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1553 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1554 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1555 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001556 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1557 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1558 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1559 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in units of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1560 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in units of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001561 */
1562__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
1563 IMAGE_DECLARATION(src1),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001564 IMAGE_DECLARATION(dst),
1565 uint src0_stride_z,
1566 uint src1_stride_z,
1567 uint dst_stride_z
1568#if defined(REINTERPRET_INPUT_AS_3D)
1569 ,
1570 uint src_cross_plane_pad
1571#endif // REINTERPRET_INPUT_AS_3D
1572#if defined(REINTERPRET_OUTPUT_AS_3D)
1573 ,
1574 uint dst_cross_plane_pad
1575#endif // REINTERPRET_OUTPUT_AS_3D
1576 )
Giorgio Arena6200fa42018-07-06 17:06:36 +01001577{
1578 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1579
1580 // Compute starting address for matrix A and Matrix B
1581 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1582
1583 // Update address for the matrix A
1584 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1585
1586 // Update address for the matrix B
1587 src_addr.s1 += idx;
1588
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001589#if defined(REINTERPRET_INPUT_AS_3D)
1590 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1591 // in order to take into account the presence of possible cross plane paddings
1592 //
1593 // | |
1594 // | plane0 |
1595 // | |
1596 // |__________________|
1597 // |******************|
1598 // | cross_plane_pad |
1599 // |******************|
1600 // | |
1601 // | plane1 |
1602 // | |
1603 // |__________________|
1604
1605 // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1606 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1607 zin = min(DEPTH_GEMM3D - 1, zin);
1608
1609 // Add offset due to the cross plane paddings
1610 zin *= (src_cross_plane_pad * src0_stride_y);
1611
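    // Add the per-row byte offsets so that zin directly addresses the four rows of the tile
    // of matrix A, each one already shifted by its cross plane padding.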
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001612 zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y;
1613
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001614 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1615 // multiply src0_stride_z by DEPTH_GEMM3D
1616 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1617
1618#else // defined(REINTERPRET_INPUT_AS_3D)
1619
1620 // Add offset for batched GEMM
1621 src_addr.s0 += get_global_id(2) * src0_stride_z;
1622
1623#endif // defined(REINTERPRET_INPUT_AS_3D)
1624
1625#if defined(MATRIX_B_DEPTH)
1626 // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
1627 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1628#else // defined(MATRIX_B_DEPTH)
1629 src_addr.s1 += get_global_id(2) * src1_stride_z;
1630#endif // defined(MATRIX_B_DEPTH)
1631
Giorgio Arena6200fa42018-07-06 17:06:36 +01001632 uint acc00 = 0;
1633 uint acc01 = 0;
1634 uint acc02 = 0;
1635 uint acc03 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001636 uint acc04 = 0;
1637 uint acc05 = 0;
1638 uint acc06 = 0;
1639 uint acc07 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001640#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1641 uint acc10 = 0;
1642 uint acc11 = 0;
1643 uint acc12 = 0;
1644 uint acc13 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001645 uint acc14 = 0;
1646 uint acc15 = 0;
1647 uint acc16 = 0;
1648 uint acc17 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001649#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1650#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1651 uint acc20 = 0;
1652 uint acc21 = 0;
1653 uint acc22 = 0;
1654 uint acc23 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001655 uint acc24 = 0;
1656 uint acc25 = 0;
1657 uint acc26 = 0;
1658 uint acc27 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001659#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1660#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1661 uint acc30 = 0;
1662 uint acc31 = 0;
1663 uint acc32 = 0;
1664 uint acc33 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001665 uint acc34 = 0;
1666 uint acc35 = 0;
1667 uint acc36 = 0;
1668 uint acc37 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001669#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Giorgio Arena6200fa42018-07-06 17:06:36 +01001670
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001671 // A and B src indices get incremented at the same time.
1672 int i = 0;
1673 for(; i <= ((int)COLS_A - 8); i += 8)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001674 {
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001675#if defined(REINTERPRET_INPUT_AS_3D)
1676 // Load values from matrix A and matrix B
1677 uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001678#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001679 uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001680#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1681#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001682 uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001683#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1684#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001685 uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001686#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001687#else // defined(REINTERPRET_INPUT_AS_3D)
1688 // Load values from matrix A and matrix B
1689 uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1690#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1691 uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1692#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1693#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1694 uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1695#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1696#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1697 uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1698#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1699#endif // defined(REINTERPRET_INPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001700
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001701 uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
1702 uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
1703 uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
1704 uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
1705 src_addr.s1 += 4 * src1_stride_y;
1706
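        // Each ARM_DOT below accumulates a 4-deep slice of the K dimension: for instance
        // acc00 += a0.s0 * b0.s0 + a0.s1 * b1.s0 + a0.s2 * b2.s0 + a0.s3 * b3.s0, i.e. row 0
        // of the tile of matrix A against column 0 of the tile of matrix B. The unroll factor
        // of 8 is processed in two halves: the .s0123 components here, .s4567 further down.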
1707 ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
1708 ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
1709 ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
1710 ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
1711 ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
1712 ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
1713 ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
1714 ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
1715
Giorgio Arena6200fa42018-07-06 17:06:36 +01001716#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001717 ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
1718 ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
1719 ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
1720 ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
1721 ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
1722 ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
1723 ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
1724 ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001725#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1726#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001727 ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
1728 ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
1729 ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
1730 ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
1731 ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
1732 ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
1733 ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
1734 ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001735#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1736#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001737 ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
1738 ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
1739 ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
1740 ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
1741 ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
1742 ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
1743 ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
1744 ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001745#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001746
1747 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
1748 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
1749 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
1750 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
1751 src_addr.s1 += 4 * src1_stride_y;
1752
1753 ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
1754 ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
1755 ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
1756 ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
1757 ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
1758 ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
1759 ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
1760 ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
1761
1762#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1763 ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
1764 ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
1765 ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
1766 ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
1767 ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
1768 ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
1769 ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
1770 ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
1771#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1772#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1773 ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
1774 ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
1775 ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
1776 ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
1777 ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
1778 ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
1779 ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
1780 ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
1781#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1782#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1783 ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
1784 ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
1785 ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
1786 ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
1787 ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
1788 ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
1789 ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
1790 ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
1791#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1792
1793 src_addr.s0 += 8;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001794 }
1795
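    // Process the leftover columns of matrix A (when COLS_A is not a multiple of 8) one
    // element at a time.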
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001796 for(; i < (int)COLS_A; ++i)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001797 {
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001798#if defined(REINTERPRET_INPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001799 // Load values from matrix A
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001800 uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001801#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001802 uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001803#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1804#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001805 uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001806#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1807#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001808 uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001809#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001810#else // defined(REINTERPRET_INPUT_AS_3D)
1811 // Load values from matrix A
1812 uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1813#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1814 uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1815#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1816#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1817 uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1818#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1819#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1820 uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1821#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1822#endif // defined(REINTERPRET_INPUT_AS_3D)
1823
Giorgio Arena6200fa42018-07-06 17:06:36 +01001824 // Load values from matrix B
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001825 uchar8 b0 = vload8(0, src1_ptr + src_addr.s1);
1826 src_addr.s1 += src1_stride_y;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001827
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001828 acc00 += (uint)a0 * b0.s0;
1829 acc01 += (uint)a0 * b0.s1;
1830 acc02 += (uint)a0 * b0.s2;
1831 acc03 += (uint)a0 * b0.s3;
1832 acc04 += (uint)a0 * b0.s4;
1833 acc05 += (uint)a0 * b0.s5;
1834 acc06 += (uint)a0 * b0.s6;
1835 acc07 += (uint)a0 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001836
Giorgio Arena6200fa42018-07-06 17:06:36 +01001837#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001838 acc10 += (uint)a1 * b0.s0;
1839 acc11 += (uint)a1 * b0.s1;
1840 acc12 += (uint)a1 * b0.s2;
1841 acc13 += (uint)a1 * b0.s3;
1842 acc14 += (uint)a1 * b0.s4;
1843 acc15 += (uint)a1 * b0.s5;
1844 acc16 += (uint)a1 * b0.s6;
1845 acc17 += (uint)a1 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001846#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1847#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001848 acc20 += (uint)a2 * b0.s0;
1849 acc21 += (uint)a2 * b0.s1;
1850 acc22 += (uint)a2 * b0.s2;
1851 acc23 += (uint)a2 * b0.s3;
1852 acc24 += (uint)a2 * b0.s4;
1853 acc25 += (uint)a2 * b0.s5;
1854 acc26 += (uint)a2 * b0.s6;
1855 acc27 += (uint)a2 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001856#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1857#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001858 acc30 += (uint)a3 * b0.s0;
1859 acc31 += (uint)a3 * b0.s1;
1860 acc32 += (uint)a3 * b0.s2;
1861 acc33 += (uint)a3 * b0.s3;
1862 acc34 += (uint)a3 * b0.s4;
1863 acc35 += (uint)a3 * b0.s5;
1864 acc36 += (uint)a3 * b0.s6;
1865 acc37 += (uint)a3 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001866#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Giorgio Arena6200fa42018-07-06 17:06:36 +01001867
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001868 src_addr.s0 += 1;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001869 }
1870
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001871 int z = get_global_id(2);
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001872
Giorgio Arena6200fa42018-07-06 17:06:36 +01001873 // Compute destination address
1874 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1875
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001876 // Compute dst address
1877 __global uchar *dst_addr = dst.ptr;
1878
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001879#if defined(REINTERPRET_OUTPUT_AS_3D)
1880 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1881 // in order to take into account the presence of possible cross plane paddings
1882 //
1883 // | |
1884 // | plane0 |
1885 // | |
1886 // |__________________|
1887 // |******************|
1888 // | cross_plane_pad |
1889 // |******************|
1890 // | |
1891 // | plane1 |
1892 // | |
1893 // |__________________|
1894
1895 // The plane (zout) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001896 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001897 zout = min(DEPTH_GEMM3D - 1, zout);
1898
1899 // Add offset due to the cross plane paddings
1900 zout *= (dst_cross_plane_pad * dst_stride_y);
1901
1902 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1903 // multiply dst_stride_z by DEPTH_GEMM3D
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001904 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001905
Giorgio Arena6200fa42018-07-06 17:06:36 +01001906 // Store the result
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001907 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
1908 vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001909#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001910 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
1911 vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001912#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1913#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001914 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
1915 vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001916#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1917#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001918 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
1919 vstore4((int4)(acc34, acc35, acc36, acc37), 1, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001920#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001921
1922#else // defined(REINTERPRET_OUTPUT_AS_3D)
1923 // Add offset for batched GEMM
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001924 dst_addr += z * dst_stride_z;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001925
1926 // Store the result
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001927 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y));
1928 vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001929#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001930 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y));
1931 vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001932#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1933#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001934 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y));
1935 vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001936#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1937#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001938 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
1939 vstore4((int4)(acc34, acc35, acc36, acc37), 1, (__global int *)(dst_addr + 3 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001940#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001941#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1942}
1943
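/** OpenCL kernel optimized to use the dot product instruction that computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *
 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
 *
 * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *       (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @note The kernel arguments have the same meaning as in @ref gemmlowp_mm_bifrost_dot8
 */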
1944__kernel void gemmlowp_mm_bifrost_transposed_dot8(IMAGE_DECLARATION(src0),
1945 IMAGE_DECLARATION(src1),
1946 IMAGE_DECLARATION(dst),
1947 uint src0_stride_z,
1948 uint src1_stride_z,
1949 uint dst_stride_z
1950#if defined(REINTERPRET_INPUT_AS_3D)
1951 ,
1952 uint src_cross_plane_pad
1953#endif // REINTERPRET_INPUT_AS_3D
1954#if defined(REINTERPRET_OUTPUT_AS_3D)
1955 ,
1956 uint dst_cross_plane_pad
1957#endif // REINTERPRET_OUTPUT_AS_3D
1958 )
1959{
1960 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1961
1962 // Compute starting address for matrix A and matrix B
1963 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1964
1965 // Update address for the matrix A
1966 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1967
1968 // Update address for the matrix B
1969 src_addr.s1 += idx;
1970
1971#if defined(REINTERPRET_INPUT_AS_3D)
1972 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1973 // in order to take into account the presence of possible cross plane paddings
1974 //
1975 // | |
1976 // | plane0 |
1977 // | |
1978 // |__________________|
1979 // |******************|
1980 // | cross_plane_pad |
1981 // |******************|
1982 // | |
1983 // | plane1 |
1984 // | |
1985 // |__________________|
1986
1987 // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1988 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1989 zin = min(DEPTH_GEMM3D - 1, zin);
1990
1991 // Add offset due to the cross plane paddings
1992 zin *= (src_cross_plane_pad * src0_stride_y);
1993
1994 zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y;
1995
1996 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1997 // multiply src0_stride_z by DEPTH_GEMM3D
1998 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1999
2000#else // defined(REINTERPRET_INPUT_AS_3D)
2001
2002 // Add offset for batched GEMM
2003 src_addr.s0 += get_global_id(2) * src0_stride_z;
2004
2005#endif // defined(REINTERPRET_INPUT_AS_3D)
2006
2007#if defined(MATRIX_B_DEPTH)
2008 // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
2009 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2010#else // defined(MATRIX_B_DEPTH)
2011 src_addr.s1 += get_global_id(2) * src1_stride_z;
2012#endif // defined(MATRIX_B_DEPTH)
2013
2014 uint acc00 = 0;
2015 uint acc01 = 0;
2016 uint acc02 = 0;
2017 uint acc03 = 0;
2018 uint acc04 = 0;
2019 uint acc05 = 0;
2020 uint acc06 = 0;
2021 uint acc07 = 0;
2022#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2023 uint acc10 = 0;
2024 uint acc11 = 0;
2025 uint acc12 = 0;
2026 uint acc13 = 0;
2027 uint acc14 = 0;
2028 uint acc15 = 0;
2029 uint acc16 = 0;
2030 uint acc17 = 0;
2031#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2032#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2033 uint acc20 = 0;
2034 uint acc21 = 0;
2035 uint acc22 = 0;
2036 uint acc23 = 0;
2037 uint acc24 = 0;
2038 uint acc25 = 0;
2039 uint acc26 = 0;
2040 uint acc27 = 0;
2041#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2042#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2043 uint acc30 = 0;
2044 uint acc31 = 0;
2045 uint acc32 = 0;
2046 uint acc33 = 0;
2047 uint acc34 = 0;
2048 uint acc35 = 0;
2049 uint acc36 = 0;
2050 uint acc37 = 0;
2051#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2052
2053 // A and B src indices get incremented at the same time.
2054 int i = 0;
2055 for(; i <= ((int)COLS_A - 8); i += 8)
2056 {
2057#if defined(REINTERPRET_INPUT_AS_3D)
2058 // Load values from matrix A and matrix B
2059 uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
2060#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2061 uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
2062#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2063#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2064 uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
2065#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2066#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2067 uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
2068#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2069#else // defined(REINTERPRET_INPUT_AS_3D)
2070 // Load values from matrix A and matrix B
2071 uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2072#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2073 uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2074#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2075#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2076 uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2077#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2078#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2079 uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2080#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2081#endif // defined(REINTERPRET_INPUT_AS_3D)
2082
2083 uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
2084 uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
2085 uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
2086 uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
2087 src_addr.s1 += 4 * src1_stride_y;
2088
2089 ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
2090 ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
2091 ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
2092 ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
2093 ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
2094 ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
        ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
        ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
        ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
        ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
        ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
        ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
        ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
        ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
        ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
        ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
        ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
        ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
        ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
        ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
        ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
        ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
        ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
        ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
        ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
        ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
        ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
        ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
        ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
        b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
        b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
        b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
        src_addr.s1 += 4 * src1_stride_y;

        ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
        ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
        ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
        ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
        ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
        ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
        ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
        ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
        ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
        ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
        ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
        ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
        ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
        ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
        ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
        ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
        ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
        ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
        ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
        ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
        ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
        ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
        ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
        ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
        ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
        ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
        ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
        ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
        ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 8;
    }
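
    /* Illustrative note (a sketch, not part of the computation): each ARM_DOT call
     * above is the vector form of the scalar update below; it folds four k-steps of
     * one (row, column) pair of the output tile into a single accumulator. Here 'a4'
     * stands for a uchar4 chunk of one row of matrix A and 'b4' gathers one column of
     * the reshaped matrix B across four consecutive rows (both names are hypothetical):
     *
     *     for(int j = 0; j < 4; ++j)
     *     {
     *         acc00 += (uint)a4[j] * (uint)b4[j];
     *     }
     */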

    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        uchar8 b0 = vload8(0, src1_ptr + src_addr.s1);
        src_addr.s1 += src1_stride_y;

        ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s0), acc00);
        ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s1), acc01);
        ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s2), acc02);
        ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s3), acc03);
        ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s4), acc04);
        ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s5), acc05);
        ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s6), acc06);
        ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s7), acc07);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s0), acc10);
        ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s1), acc11);
        ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s2), acc12);
        ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s3), acc13);
        ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s4), acc14);
        ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s5), acc15);
        ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s6), acc16);
        ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s7), acc17);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s0), acc20);
        ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s1), acc21);
        ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s2), acc22);
        ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s3), acc23);
        ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s4), acc24);
        ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s5), acc25);
        ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s6), acc26);
        ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s7), acc27);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s0), acc30);
        ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s1), acc31);
        ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s2), acc32);
        ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s3), acc33);
        ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s4), acc34);
        ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s5), acc35);
        ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s6), acc36);
        ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s7), acc37);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 1;
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = dst.ptr;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated by dividing the output row index (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
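
    // Worked example (illustrative): with HEIGHT_GEMM3D = 4 and this thread writing
    // output rows 6..9 (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y = 6),
    // zout starts as (6, 7, 8, 9) / 4 = (1, 1, 2, 2): rows 6-7 fall on plane 1 and
    // rows 8-9 on plane 2, so after the multiply above each row's address is bumped
    // by its plane index times the cross-plane padding.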

    // Store the result
    vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
    vstore4((int4)(acc34, acc35, acc36, acc37), 1, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the result
    vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y));
    vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y));
    vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y));
    vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
    vstore4((int4)(acc34, acc35, acc36, acc37), 1, (__global int *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)

#if defined(COLS_A)
/** OpenCL kernel used to compute the vector of row sums of Matrix A, i.e. the sum of all the entries in each row.
 *
 * @note This stage is needed to handle the offset of the matrix product
 *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
 *
 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data type: S32
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_matrix_a_reduction(TENSOR3D_DECLARATION(src),
                                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    uint4 sum_row_u32 = (uint4)0;
    uint sum_row = 0;

    __global const uchar *matrix_a = (__global const uchar *)(src.ptr + get_global_id(0) * src_stride_y + get_global_id(1) * src_stride_z);

    int i = 0;

    // This for loop performs 16 accumulations per iteration
    for(; i <= ((int)COLS_A - 16); i += 16)
    {
        const uchar16 a0_u8 = vload16(0, matrix_a + i);

        sum_row_u32 += convert_uint4(a0_u8.s0123) + convert_uint4(a0_u8.s4567) + convert_uint4(a0_u8.s89AB) + convert_uint4(a0_u8.sCDEF);
    }

    // This for loop performs the leftover accumulations
    for(; i < COLS_A; ++i)
    {
        sum_row += matrix_a[i];
    }

    sum_row += sum_row_u32.s0 + sum_row_u32.s1 + sum_row_u32.s2 + sum_row_u32.s3;

    *((__global int *)dst.ptr) = (int)sum_row;
}
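
/* For reference (an illustrative sketch, not used by the pipeline): per work-item,
 * the kernel above computes the plain row sum below, where 'row' points at one row
 * of matrix A and 'cols' stands for COLS_A (both names are hypothetical); the
 * vload16 path only changes how many elements are folded in per iteration, not the
 * result:
 *
 *     uint sum_row = 0;
 *     for(int k = 0; k < cols; ++k)
 *     {
 *         sum_row += (uint)row[k];
 *     }
 */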

#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
/** OpenCL kernel used to compute the vector of row sums of Matrix A, i.e. the sum of all the entries in each row, using the arm_dot instruction
 *
 * @note This stage is needed to handle the offset of the matrix product
 *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
 *
 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data type: S32
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
                                               IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    uint sum_row = 0;

    __global const uchar *matrix_a = (__global const uchar *)(src.ptr + get_global_id(0) * src_stride_y + get_global_id(1) * src_stride_z);

    int i = 0;

    // This for loop performs 32 accumulations per iteration
    for(; i <= ((int)COLS_A - 32); i += 32)
    {
        uchar16 a0_u8 = vload16(0, matrix_a + i);

        sum_row += arm_dot(a0_u8.s0123, (uchar4)(1));
        sum_row += arm_dot(a0_u8.s4567, (uchar4)(1));
        sum_row += arm_dot(a0_u8.s89AB, (uchar4)(1));
        sum_row += arm_dot(a0_u8.sCDEF, (uchar4)(1));

        a0_u8 = vload16(1, matrix_a + i);

        sum_row += arm_dot(a0_u8.s0123, (uchar4)(1));
        sum_row += arm_dot(a0_u8.s4567, (uchar4)(1));
        sum_row += arm_dot(a0_u8.s89AB, (uchar4)(1));
        sum_row += arm_dot(a0_u8.sCDEF, (uchar4)(1));
    }

    // This for loop performs the leftover accumulations
    for(; i < COLS_A; ++i)
    {
        sum_row += matrix_a[i];
    }

    *((__global int *)dst.ptr) = (int)sum_row;
}
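
/* Illustrative note: the (uchar4)(1) operand turns arm_dot into a 4-way horizontal
 * add, since dot((a, b, c, d), (1, 1, 1, 1)) = a + b + c + d. Each iteration of the
 * main loop above therefore folds 32 uchar values into sum_row with eight dot
 * products instead of 32 scalar additions.
 */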
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#endif // defined(COLS_A)

#if defined(COLS_B) && defined(ROWS_B)
/** OpenCL kernel used to compute the vector of column sums of Matrix B, i.e. the sum of all the entries in each column.
 *
 * @note This stage is needed to handle the offset of the matrix product
 *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
 *
 * @attention The number of matrix B columns and rows needs to be passed at compile time using -DCOLS_B and -DROWS_B
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data type: S32
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_matrix_b_reduction(TENSOR3D_DECLARATION(src),
                                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    uint16 sum_col_u32 = (uint16)0;

    __global const uchar *matrix_b = (__global const uchar *)(src.ptr + get_global_id(1) * src_stride_z);

    int i = 0;
    // This for loop accumulates 4 rows per iteration
    for(; i <= ((int)ROWS_B - 4); i += 4)
    {
        const uchar16 b0_u8 = vload16(0, matrix_b + 0 * src_stride_y);
        const uchar16 b1_u8 = vload16(0, matrix_b + 1 * src_stride_y);
        const uchar16 b2_u8 = vload16(0, matrix_b + 2 * src_stride_y);
        const uchar16 b3_u8 = vload16(0, matrix_b + 3 * src_stride_y);

        sum_col_u32 += convert_uint16(b0_u8) + convert_uint16(b1_u8) + convert_uint16(b2_u8) + convert_uint16(b3_u8);

        matrix_b += 4 * src_stride_y;
    }

    // This for loop performs the leftover accumulations
    for(; i < (int)ROWS_B; ++i)
    {
        const uchar16 b0_u8 = vload16(0, matrix_b);

        sum_col_u32 += convert_uint16(b0_u8);

        matrix_b += src_stride_y;
    }

    vstore16(convert_int16(sum_col_u32), 0, (__global int *)dst.ptr);
}
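
/* For reference (an illustrative sketch): each work-item of the kernel above emits
 * 16 column sums at once; per output column x the computation is simply
 *
 *     sum_col[x] = B[0][x] + B[1][x] + ... + B[ROWS_B - 1][x]
 *
 * walking matrix B row by row via src_stride_y.
 */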
#endif // defined(COLS_B) && defined(ROWS_B)

#if defined(K_OFFSET)

/* Helper function used to calculate the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel.
 *
 * This function takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel)
 * and calculates the offset contribution of matrix A and matrix B.
 *
 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200)
 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1)
 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6)
 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually, if gemmlowp is used to accelerate a convolution layer, sum_col will not have batches
 *
 * @param[in] x                                     get_global_id(0) * 4
 * @param[in] y                                     get_global_id(1)
 * @param[in] z                                     get_global_id(2)
 * @param[in] sum_col_ptr                           (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
 * @param[in] sum_col_stride_x                      (Optional) Stride of the source tensor in X dimension (in bytes)
 * @param[in] sum_col_step_x                        (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] sum_col_stride_y                      (Optional) Stride of the source tensor in Y dimension (in bytes)
 * @param[in] sum_col_step_y                        (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
 * @param[in] sum_row_ptr                           (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
 * @param[in] sum_row_stride_x                      (Optional) Stride of the source tensor in X dimension (in bytes)
 * @param[in] sum_row_step_x                        (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] sum_row_stride_y                      (Optional) Stride of the source tensor in Y dimension (in bytes)
 * @param[in] sum_row_step_y                        (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
 * @param[in] biases_ptr                            (Optional) Pointer to the biases tensor. Supported data type: same as @p mm_result_ptr
 * @param[in] biases_stride_x                       (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in] biases_step_x                         (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] biases_offset_first_element_in_bytes  (Optional) The offset of the first element in the biases tensor
 */
inline int4 offset_contribution(
    int x,
    int y,
    int z
#if defined(A_OFFSET)
    ,
    IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
    ,
    IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
    ,
    VECTOR_DECLARATION(biases)
#endif // defined(ADD_BIAS)
)
{
    int4 a_offset_s32 = (int4)0;
    int4 b_offset_s32 = (int4)0;

    int batch_id = z;
#if defined(DEPTH_INPUT3D)
    batch_id /= (int)DEPTH_INPUT3D;
#endif // defined(DEPTH_INPUT3D)

#if defined(A_OFFSET)
    // Compute the offset contribution due to A_OFFSET
    __global uchar *sum_col_addr = sum_col_ptr + sum_col_offset_first_element_in_bytes + x * sizeof(int);

#if defined(SUM_COL_HAS_BATCHES)
    a_offset_s32 = vload4(0, (__global int *)(sum_col_addr + batch_id * sum_col_stride_y));
#else // defined(SUM_COL_HAS_BATCHES)
    a_offset_s32 = vload4(0, (__global int *)sum_col_addr);
#endif // defined(SUM_COL_HAS_BATCHES)

    a_offset_s32 *= (int4)A_OFFSET;
#endif // defined(A_OFFSET)

#if defined(B_OFFSET)
    // Compute the offset contribution due to B_OFFSET
    __global uchar *sum_row_addr = sum_row_ptr + sum_row_offset_first_element_in_bytes + y * sizeof(int);

#if defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 = (int4) * (((__global int *)(sum_row_addr + batch_id * sum_row_stride_y)) + (z % (int)DEPTH_INPUT3D) * (int)HEIGHT_INPUT3D);
#else // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 = (int4) * (((__global int *)(sum_row_addr + batch_id * sum_row_stride_y)));
#endif // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 *= (int4)B_OFFSET;
#endif // defined(B_OFFSET)

#if defined(ADD_BIAS)
    // Add bias
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);

    int4 biases_values = vload4(0, (__global int *)bias_addr);
    b_offset_s32 += (int4)biases_values;
#endif // defined(ADD_BIAS)

    return (int4)K_OFFSET + a_offset_s32 + b_offset_s32;
}
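
/* Why these three correction terms are sufficient (sketch): for quantized inputs,
 * the per-element product expands as
 *
 *     sum_k (a[y][k] + a_offset) * (b[k][x] + b_offset)
 *   = sum_k a[y][k] * b[k][x]          // the raw mm_result
 *   + a_offset * sum_k b[k][x]         // A_OFFSET * sum_col[x]
 *   + b_offset * sum_k a[y][k]         // B_OFFSET * sum_row[y]
 *   + k * a_offset * b_offset          // K_OFFSET
 *
 * which is exactly the correction returned above (modulo the sign conventions
 * chosen by the caller); see the gemmlowp low-precision document linked earlier
 * for the derivation.
 */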

/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel. The computation is performed in-place
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel)
 * and adds to it, in-place, the offset contribution of matrix A and matrix B.
 *
 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200)
 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1)
 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6)
 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually, if gemmlowp is used to accelerate a convolution layer, sum_col will not have batches
 *
 * The final result is:
 *
 * mm_result[i][k] = mm_result[i][k] +
 *                   (sum_col[k] * A_OFFSET) +
 *                   (sum_row[i] * B_OFFSET) +
 *                   (K_OFFSET)
 *
 * @param[in] mm_result_ptr                           Pointer to the source tensor. Supported data type: S32
 * @param[in] mm_result_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] mm_result_step_x                        mm_result_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] mm_result_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] mm_result_step_y                        mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] mm_result_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] mm_result_step_z                        mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] sum_col_ptr                             (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
 * @param[in] sum_col_stride_x                        (Optional) Stride of the source tensor in X dimension (in bytes)
 * @param[in] sum_col_step_x                          (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] sum_col_stride_y                        (Optional) Stride of the source tensor in Y dimension (in bytes)
 * @param[in] sum_col_step_y                          (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] sum_col_offset_first_element_in_bytes   (Optional) The offset of the first element in the source tensor
 * @param[in] sum_row_ptr                             (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
 * @param[in] sum_row_stride_x                        (Optional) Stride of the source tensor in X dimension (in bytes)
 * @param[in] sum_row_step_x                          (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] sum_row_stride_y                        (Optional) Stride of the source tensor in Y dimension (in bytes)
 * @param[in] sum_row_step_y                          (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] sum_row_offset_first_element_in_bytes   (Optional) The offset of the first element in the source tensor
 * @param[in] biases_ptr                              (Optional) Pointer to the biases tensor. Supported data type: same as @p mm_result_ptr
 * @param[in] biases_stride_x                         (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in] biases_step_x                           (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] biases_offset_first_element_in_bytes    (Optional) The offset of the first element in the biases tensor
 */
__kernel void gemmlowp_offset_contribution(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                           ,
                                           IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                           ,
                                           IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                                           ,
                                           VECTOR_DECLARATION(biases)
#endif // defined(ADD_BIAS)
                                           )
{
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // Compute offset contribution
    int4 offset_term_s32 = offset_contribution(
                               x, y, z
#if defined(A_OFFSET)
                               ,
                               sum_col_ptr,
                               sum_col_stride_x,
                               sum_col_step_x,
                               sum_col_stride_y,
                               sum_col_step_y,
                               sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                               ,
                               sum_row_ptr,
                               sum_row_stride_x,
                               sum_row_step_x,
                               sum_row_stride_y,
                               sum_row_step_y,
                               sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                               ,
                               biases_ptr,
                               biases_stride_x,
                               biases_step_x,
                               biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                           );

    __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    int4 in_s32 = vload4(0, (__global int *)mm_result_addr);

    // Add the offset terms to GEMM's result
    in_s32 += offset_term_s32;

    // Store the result with the offset contribution
    vstore4(in_s32, 0, (__global int *)mm_result_addr);
}
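
/* Usage sketch (hypothetical values): for a GEMM with K = 300 columns of matrix A,
 * a_offset = 1 and b_offset = 6, this kernel would be built with
 *
 *     -DK_OFFSET=1800 -DA_OFFSET=1 -DB_OFFSET=6
 *
 * since K_OFFSET = a_offset * b_offset * K = 1 * 6 * 300, plus -DADD_BIAS when an
 * int32 bias vector is fused into this pass.
 */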

#if defined(RESULT_OFFSET) && defined(RESULT_MULTIPLIER) && defined(RESULT_SHIFT)
/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel and quantize down to uint8.
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), adds to it the offset contribution of matrix A and matrix B and quantizes to uint8 through the output stage.
 *
 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200)
 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1)
 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6)
 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually, if gemmlowp is used to accelerate a convolution layer, sum_col will not have batches
 *
 * The result before the output stage is:
 *
 * mm_result[i][k] = mm_result[i][k] +
 *                   (sum_col[k] * A_OFFSET) +
 *                   (sum_row[i] * B_OFFSET) +
 *                   (K_OFFSET)
 *
 * This result is quantized down to uint8 using the output stage. The output stage computes the following operations:
 *
 *  -# Add offset terms to final result
 *  -# Multiply each entry of result by result_multiplier
 *  -# Add bias to final result (if -DADD_BIAS is passed at compile time)
 *  -# Shift the int32 accumulator by result_shift
 *  -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
 *  -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
 *
 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -DRESULT_MULTIPLIER and -DRESULT_SHIFT
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  mm_result_ptr                           Pointer to the source tensor. Supported data type: S32
 * @param[in]  mm_result_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  mm_result_step_x                        mm_result_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  mm_result_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  mm_result_step_y                        mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  mm_result_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  mm_result_step_z                        mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  sum_col_ptr                             (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  sum_col_stride_x                        (Optional) Stride of the source tensor in X dimension (in bytes)
 * @param[in]  sum_col_step_x                          (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_col_stride_y                        (Optional) Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  sum_col_step_y                          (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_col_offset_first_element_in_bytes   (Optional) The offset of the first element in the source tensor
 * @param[in]  sum_row_ptr                             (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  sum_row_stride_x                        (Optional) Stride of the source tensor in X dimension (in bytes)
 * @param[in]  sum_row_step_x                          (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_row_stride_y                        (Optional) Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  sum_row_step_y                          (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_row_offset_first_element_in_bytes   (Optional) The offset of the first element in the source tensor
 * @param[in]  biases_ptr                              (Optional) Pointer to the biases tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  biases_stride_x                         (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                           (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes    (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                                 Pointer to the destination tensor. Supported data type: QASYMM8
 * @param[in]  dst_stride_x                            Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                              dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                            Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                              dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                            Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                              dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes       The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_offset_contribution_quantize_down(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                                         ,
                                                         IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                                         ,
                                                         IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
                                                         ,
#if defined(ADD_BIAS)
                                                         VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                         TENSOR3D_DECLARATION(dst))
{
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    // Compute offset contribution
    int4 offset_term_s32 = offset_contribution(
                               x, y, z
#if defined(A_OFFSET)
                               ,
                               sum_col_ptr,
                               sum_col_stride_x,
                               sum_col_step_x,
                               sum_col_stride_y,
                               sum_col_step_y,
                               sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                               ,
                               sum_row_ptr,
                               sum_row_stride_x,
                               sum_row_step_x,
                               sum_row_stride_y,
                               sum_row_step_y,
                               sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                               ,
                               biases_ptr,
                               biases_stride_x,
                               biases_step_x,
                               biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                           );

    __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    int4 in_s32 = vload4(0, (__global int *)mm_result_addr);

    // Add the offset terms to GEMM's result
    in_s32 += offset_term_s32;

    // -------------- OUTPUT STAGE

    // Add the output offset
    in_s32 += (int4)RESULT_OFFSET;

    // Multiply by result_multiplier and shift
    in_s32 *= RESULT_MULTIPLIER;

    in_s32 >>= RESULT_SHIFT;

    uchar4 res = convert_uchar4_sat(in_s32);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
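
/* Worked example (illustrative values): with RESULT_OFFSET = 2, RESULT_MULTIPLIER = 3
 * and RESULT_SHIFT = 8, an accumulator value of 1000 becomes
 *
 *     ((1000 + 2) * 3) >> 8 = 3006 >> 8 = 11
 *
 * before the saturating cast and the optional MIN_BOUND/MAX_BOUND clamp. Note the
 * offset is added before the multiply/shift in this variant.
 */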

/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel and quantize down to uint8 using fixed-point arithmetic.
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), adds to it the offset contribution of matrix A and matrix B and quantizes to uint8 through the output stage.
 *
 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200)
 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1)
 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6)
 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually, if gemmlowp is used to accelerate a convolution layer, sum_col will not have batches
 *
 * The result before the output stage is:
 *
 * mm_result[i][k] = mm_result[i][k] +
 *                   (sum_col[k] * A_OFFSET) +
 *                   (sum_row[i] * B_OFFSET) +
 *                   (K_OFFSET)
 *
 * This result is quantized down to uint8 using the output stage. The output stage computes the following operations:
 *
 *  -# Compute fixed point multiplication between each entry of input by result_fixedpoint_multiplier
 *  -# Add bias to final result if bias tensor is not a nullptr
 *  -# Round to nearest division by a power-of-two using result_shift
 *  -# Add offset to each result
 *  -# Clamp the value between the specified min and max bounds
 *  -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
 *
 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -DRESULT_MULTIPLIER and -DRESULT_SHIFT
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  mm_result_ptr                           Pointer to the source tensor. Supported data type: S32
 * @param[in]  mm_result_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  mm_result_step_x                        mm_result_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  mm_result_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  mm_result_step_y                        mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  mm_result_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  mm_result_step_z                        mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  sum_col_ptr                             (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  sum_col_stride_x                        (Optional) Stride of the source tensor in X dimension (in bytes)
 * @param[in]  sum_col_step_x                          (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_col_stride_y                        (Optional) Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  sum_col_step_y                          (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_col_offset_first_element_in_bytes   (Optional) The offset of the first element in the source tensor
 * @param[in]  sum_row_ptr                             (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  sum_row_stride_x                        (Optional) Stride of the source tensor in X dimension (in bytes)
 * @param[in]  sum_row_step_x                          (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_row_stride_y                        (Optional) Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  sum_row_step_y                          (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_row_offset_first_element_in_bytes   (Optional) The offset of the first element in the source tensor
 * @param[in]  biases_ptr                              (Optional) Pointer to the biases tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  biases_stride_x                         (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                           (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes    (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                                 Pointer to the destination tensor. Supported data type: QASYMM8
 * @param[in]  dst_stride_x                            Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                              dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                            Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                              dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                            Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                              dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes       The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_offset_contribution_quantize_down_fixedpoint(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                                                    ,
                                                                    IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                                                    ,
                                                                    IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
                                                                    ,
#if defined(ADD_BIAS)
                                                                    VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                                    TENSOR3D_DECLARATION(dst))
{
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // Compute offset contribution
    int4 offset_term_s32 = offset_contribution(
                               x, y, z
#if defined(A_OFFSET)
                               ,
                               sum_col_ptr,
                               sum_col_stride_x,
                               sum_col_step_x,
                               sum_col_stride_y,
                               sum_col_step_y,
                               sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                               ,
                               sum_row_ptr,
                               sum_row_stride_x,
                               sum_row_step_x,
                               sum_row_stride_y,
                               sum_row_step_y,
                               sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                               ,
                               biases_ptr,
                               biases_stride_x,
                               biases_step_x,
                               biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                           );

    __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    int4 in_s32 = vload4(0, (__global int *)mm_result_addr);

    // Add the offset terms to GEMM's result
    in_s32 += offset_term_s32;

    // -------------- OUTPUT STAGE

    // Multiply by result_multiplier and shift using fixed-point arithmetic
    in_s32 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(in_s32, RESULT_MULTIPLIER, RESULT_SHIFT, 4);

    // Add the output offset
    in_s32 += (int4)RESULT_OFFSET;

    uchar4 res = convert_uchar4_sat(in_s32);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
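
/* Conceptual scalar equivalent of the fixed-point downscale above (a sketch that
 * assumes the gemmlowp-style semantics of ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE
 * from helpers_asymm.h: a rounding doubling high multiply followed by a rounding
 * right shift; saturation and the negative-value nudge are omitted for brevity):
 *
 *     long ab   = (long)v * (long)RESULT_MULTIPLIER;
 *     int  high = (int)((ab + (1L << 30)) >> 31);                     // high half of 2*v*multiplier, rounded
 *     int  out  = (high + (1 << (RESULT_SHIFT - 1))) >> RESULT_SHIFT; // round-to-nearest divide by 2^RESULT_SHIFT
 */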
#endif // defined(RESULT_OFFSET) && defined(RESULT_MULTIPLIER) && defined(RESULT_SHIFT)
#endif // defined(K_OFFSET)

#if defined(RESULT_OFFSET) && defined(RESULT_MULT_INT) && defined(RESULT_SHIFT)
/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
 *
 * This kernel takes a final int32 accumulator value and processes it to obtain the final QASYMM8 value.
 * The following computations will be performed by the kernel:
 *
 *  -# Add offset terms to final result
 *  -# Multiply each entry of result by result_mult_int
 *  -# Add bias to final result (if -DADD_BIAS is passed at compile time)
 *  -# Shift the int32 accumulator by result_shift
 *  -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
 *  -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
 *
 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -DRESULT_MULT_INT and -DRESULT_SHIFT
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  src_ptr                              Pointer to the source tensor. Supported data type: S32
 * @param[in]  src_stride_x                         Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                           src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                         Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                           src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                         Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                           src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes    The offset of the first element in the source tensor
 * @param[in]  biases_ptr                           (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
 * @param[in]  biases_stride_x                      (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                        (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                              Pointer to the destination tensor. Supported data type: QASYMM8
 * @param[in]  dst_stride_x                         Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                           dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                         Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                           dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                         Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                           dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes    The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_output_stage_quantize_down(TENSOR3D_DECLARATION(src),
#if defined(ADD_BIAS)
                                                  VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                  TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    int x = get_global_id(0) * 4;
    int y = get_global_id(1);
    int z = get_global_id(2);

    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z;

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    int4 input_values = vload4(0, (__global int *)src_addr);

#if defined(ADD_BIAS)
    // Add bias
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);

    int4 biases_values = vload4(0, (__global int *)bias_addr);
    input_values += (int4)biases_values;
#endif // defined(ADD_BIAS)

    // Add the output offset
    input_values += (int4)RESULT_OFFSET;

    // Multiply by result_mult_int and shift
    input_values *= RESULT_MULT_INT;

    input_values >>= RESULT_SHIFT;

    uchar4 res = convert_uchar4_sat(input_values);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
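
/* Usage sketch (hypothetical values): a typical build of this output stage is
 *
 *     -DRESULT_OFFSET=-3 -DRESULT_MULT_INT=2 -DRESULT_SHIFT=1
 *
 * optionally with -DADD_BIAS, and with -DMIN_BOUND/-DMAX_BOUND set to the quantized
 * activation range (e.g. -DMIN_BOUND=128 for a ReLU whose quantized zero-point is 128).
 */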
Gian Marco58c57942017-11-28 09:10:03 +00003088#endif // defined(RESULT_OFFSET) && defined(RESULT_MULT_INT) && defined(RESULT_SHIFT)
3089
#if defined(RESULT_OFFSET_AFTER_SHIFT) && defined(RESULT_FIXEDPOINT_MULTIPLIER) && defined(RESULT_SHIFT)
/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), and processes it to obtain the final QASYMM8 value.
 * The following computations will be performed by the kernel (a scalar reference sketch is provided after this kernel):
 *
 *  -# Add bias to the final result if the bias tensor is not a nullptr
 *  -# Compute fixed point multiplication between each entry of the input and result_fixedpoint_multiplier
 *  -# Round to nearest division by a power-of-two using result_shift
 *  -# Add offset to each result
 *  -# Saturate the resulting int32 values to the [0..255] range and cast to QASYMM8
 *  -# Clamp the value between the specified min and max bounds
 *
 * @attention The offset, the fixed-point multiplier and the number of bits to shift right must be passed at compile time using -DRESULT_OFFSET_AFTER_SHIFT, -DRESULT_FIXEDPOINT_MULTIPLIER and -DRESULT_SHIFT
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  src_ptr                              Pointer to the source tensor. Supported data type: S32
 * @param[in]  src_stride_x                         Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                           src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                         Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                           src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                         Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                           src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes    The offset of the first element in the source tensor
 * @param[in]  biases_ptr                           (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
 * @param[in]  biases_stride_x                      (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                        (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                              Pointer to the destination tensor. Supported data type: QASYMM8
 * @param[in]  dst_stride_x                         Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                           dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                         Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                           dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                         Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                           dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes    The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_output_stage_quantize_down_fixedpoint(TENSOR3D_DECLARATION(src),
#if defined(ADD_BIAS)
                                                             VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                             TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    int x = get_global_id(0) * 4;
    int y = get_global_id(1);
    int z = get_global_id(2);

    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z;

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    int4 input_values = vload4(0, (__global int *)src_addr);

#if defined(ADD_BIAS)
    // Add bias
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);

    int4 biases_values = vload4(0, (__global int *)bias_addr);
    input_values += (int4)biases_values;
#endif // defined(ADD_BIAS)

    // Multiply by RESULT_FIXEDPOINT_MULTIPLIER and round to nearest division by 2^RESULT_SHIFT
    input_values = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(input_values, RESULT_FIXEDPOINT_MULTIPLIER, RESULT_SHIFT, 4);

    // Add the offset terms to GEMM's result
    input_values += (int4)RESULT_OFFSET_AFTER_SHIFT;

    // Saturate to the [0..255] range and cast to QASYMM8
    uchar4 res = convert_uchar4_sat(input_values);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
#endif // defined(RESULT_OFFSET_AFTER_SHIFT) && defined(RESULT_FIXEDPOINT_MULTIPLIER) && defined(RESULT_SHIFT)

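/* Scalar reference sketch of the fixed-point requantization performed by
 * gemmlowp_output_stage_quantize_down_fixedpoint above. Illustrative only: the
 * helper name below is hypothetical and not part of this file, and it assumes
 * ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE follows gemmlowp's reference
 * semantics (SaturatingRoundingDoublingHighMul followed by RoundingDivideByPOT). */
inline uchar quantize_down_fixedpoint_scalar_sketch(int input, int multiplier, int shift, int offset_after_shift)
{
    // Saturating rounding doubling high multiply: returns round(input * multiplier / 2^31), saturated
    long ab    = (long)input * (long)multiplier;
    long nudge = (ab >= 0) ? (1L << 30) : (1L - (1L << 30));
    int  mul   = (input == INT_MIN && multiplier == INT_MIN) ? INT_MAX : (int)((ab + nudge) >> 31);

    // Rounding division by 2^shift (round to nearest, ties away from zero)
    int mask      = (1 << shift) - 1;
    int remainder = mul & mask;
    int threshold = (mask >> 1) + ((mul < 0) ? 1 : 0);
    int result    = (mul >> shift) + ((remainder > threshold) ? 1 : 0);

    // Add the offset applied after the shift and saturate to the [0..255] QASYMM8 range
    return convert_uchar_sat(result + offset_after_shift);
}
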
#if defined(REAL_MULTIPLIER) && defined(OUTPUT_OFFSET)
/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), and processes it to obtain the final QASYMM8 value.
 * The following computations will be performed by the kernel (a scalar reference sketch is provided after this kernel):
 *
 *  -# Add bias to the final result if the bias tensor is not a nullptr
 *  -# Requantize: multiply each entry of the input by the real scale factor REAL_MULTIPLIER and add the offset OUTPUT_OFFSET
 *  -# Round to the nearest integer
 *  -# Saturate the resulting values to the [0..255] range and cast to QASYMM8
 *  -# Clamp the value between the specified min and max bounds
 *
 * @attention The real scale factor and the offset must be passed at compile time using -DREAL_MULTIPLIER and -DOUTPUT_OFFSET
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  src_ptr                              Pointer to the source tensor. Supported data type: S32
 * @param[in]  src_stride_x                         Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                           src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                         Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                           src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                         Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                           src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes    The offset of the first element in the source tensor
 * @param[in]  biases_ptr                           (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
 * @param[in]  biases_stride_x                      (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                        (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                              Pointer to the destination tensor. Supported data type: QASYMM8
 * @param[in]  dst_stride_x                         Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                           dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                         Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                           dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                         Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                           dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_stride_w                         Stride of the destination tensor in W dimension (in bytes)
 * @param[in]  dst_step_w                           dst_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes    The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_output_stage_quantize_down_float(TENSOR3D_DECLARATION(src),
#if defined(ADD_BIAS)
                                                        VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
#if defined(DST_HEIGHT)
                                                        TENSOR4D_DECLARATION(dst))
#else  // defined(DST_HEIGHT)
                                                        TENSOR3D_DECLARATION(dst))
#endif // defined(DST_HEIGHT)
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
#if defined(DST_HEIGHT)
    Tensor4D dst = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(dst, 1);
    dst.ptr += get_global_id(0) * dst_step_x + (get_global_id(1) % DST_HEIGHT) * dst_step_y + (get_global_id(1) / DST_HEIGHT) * dst_step_z + get_global_id(2) * dst_step_w;
#else  // defined(DST_HEIGHT)
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
#endif // defined(DST_HEIGHT)

#if defined(ADD_BIAS)
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
#endif // defined(ADD_BIAS)

    int16 input_values = vload16(0, (__global int *)src.ptr);

#if defined(ADD_BIAS)
    // Add bias
    const int16 biases_values = vload16(0, (__global int *)biases.ptr);
    input_values += (int16)biases_values;
#endif // defined(ADD_BIAS)

    // Convert to float, requantize with the real multiplier and offset, and round to the nearest integer
    float16 input_values_f = convert_float16(input_values);
    input_values_f         = round(input_values_f * (float)REAL_MULTIPLIER + (float)OUTPUT_OFFSET);

    // Saturate to the [0..255] range and cast to QASYMM8
    uchar16 res = convert_uchar16_sat(input_values_f);

#if defined(MIN_BOUND)
    res = max(res, (uchar16)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar16)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore16(res, 0, dst.ptr);
}
#endif // defined(REAL_MULTIPLIER) && defined(OUTPUT_OFFSET)
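
/* Scalar reference sketch of the float requantization performed by
 * gemmlowp_output_stage_quantize_down_float above. Illustrative only: the
 * helper name below is hypothetical and not part of this file. */
inline uchar quantize_down_float_scalar_sketch(int input, float real_multiplier, float output_offset)
{
    // Scale by the real multiplier, add the output offset and round to the
    // nearest integer (OpenCL round() sends halfway cases away from zero)
    float scaled = round((float)input * real_multiplier + output_offset);

    // Saturating cast clamps the result to the [0..255] QASYMM8 range
    return convert_uchar_sat(scaled);
}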