blob: 033b4b49422dcaa7f3a7027d032d3092e3c4bee9 [file] [log] [blame]
Gian Marco05288a22017-11-21 10:57:50 +00001/*
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00002 * Copyright (c) 2017-2019 ARM Limited.
Gian Marco05288a22017-11-21 10:57:50 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
Georgios Pinitas45bcc3a2017-11-29 11:06:49 +000025#include "helpers_asymm.h"
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +000026#include "repeat.h"
Gian Marco05288a22017-11-21 10:57:50 +000027
#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
// ARM_DOT(x, y, val): 4x 8-bit dot product accumulated into val.
// Preferred form: single fused dot-product-accumulate instruction (cl_arm_integer_dot_product_accumulate_int8).
#define ARM_DOT(x, y, val) val = arm_dot_acc((x), (y), (val));
#else // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
// Fallback form: plain dot product followed by a separate add (cl_arm_integer_dot_product_int8 only).
#define ARM_DOT(x, y, val) val += arm_dot((x), (y));
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Giorgio Arenac50da382018-07-26 15:50:09 +010035
Gian Marco19835e52018-01-30 13:35:54 +000036#if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
 *
 * @note The number of matrix B columns needs to be passed at compile time using -DCOLS_B: e.g. -DCOLS_B=1024
 * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: S32
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemmlowp_mm_interleaved_transposed_midgard(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // x/y select the reshaped block of B/A this work-item works on; z is the batch index
    const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
    const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int z = get_global_id(2);

    // Offset (in bytes) inside the interleaved/transposed blocks
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
    __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr_b += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Compute end row address for matrix B (COLS_B bytes of QASYMM8 data per reshaped row)
    __global uchar *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: one int4 per output row of the 4x4 tile
    int4 c00 = 0;
    int4 c10 = 0;
    int4 c20 = 0;
    int4 c30 = 0;

    // Main loop, unrolled by 2: each iteration consumes two 4x4 A/B sub-blocks (8 bytes per matrix)
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * TRANSPOSE1XW_WIDTH_STEP)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        int4 a0 = convert_int4(vload4(0, src_addr_a));
        int4 b0 = convert_int4(vload4(0, src_addr_b));

        // Rank-1 update: each A element scales the whole B row
        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;

        a0 = convert_int4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b0 = convert_int4(vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP));

        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;
    }

    // Leftover loop: handles the remaining 4-byte sub-blocks when COLS_B is not a multiple of 8
    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        int4 a0 = convert_int4(vload4(0, src_addr_a));
        int4 b0 = convert_int4(vload4(0, src_addr_b));

        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Gian Marco19835e52018-01-30 13:35:54 +0000196
/** This OpenCL kernel is optimized for Bifrost and computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
 *
 * @attention The number of matrix B columns needs to be passed at compile time using -DCOLS_B
 * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: S32
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // x/y select the reshaped block of B/A this work-item works on; z is the batch index
    const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
    const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int z = get_global_id(2);

    // Offset (in bytes) inside the interleaved/transposed blocks
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
    __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr_b += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Compute end row address for matrix B
    __global uchar *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: 16 scalar uints, one per element of the 4x4 output tile.
    // Scalar accumulators (rather than vectors) map well onto Bifrost's scalar ALUs.
    uint c00 = 0;
    uint c01 = 0;
    uint c02 = 0;
    uint c03 = 0;
    uint c10 = 0;
    uint c11 = 0;
    uint c12 = 0;
    uint c13 = 0;
    uint c20 = 0;
    uint c21 = 0;
    uint c22 = 0;
    uint c23 = 0;
    uint c30 = 0;
    uint c31 = 0;
    uint c32 = 0;
    uint c33 = 0;

#if MULT_INTERLEAVE4X4_HEIGHT == 1
    // Main loop, unrolled by 8: each iteration consumes 32 bytes of A (two uchar16 loads)
    // and eight 4-byte rows of B. Only valid when the A interleave factor is 1, because the
    // uchar16 loads assume the 4x4 A sub-blocks are contiguous.
    for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        uchar16 a0 = vload16(0, src_addr_a);
        uchar4  b0 = vload4(0, src_addr_b);

        // Products are widened to ushort before accumulating into the uint accumulators
        c00 += (ushort)a0.s0 * b0.s0;
        c01 += (ushort)a0.s0 * b0.s1;
        c02 += (ushort)a0.s0 * b0.s2;
        c03 += (ushort)a0.s0 * b0.s3;

        c10 += (ushort)a0.s1 * b0.s0;
        c11 += (ushort)a0.s1 * b0.s1;
        c12 += (ushort)a0.s1 * b0.s2;
        c13 += (ushort)a0.s1 * b0.s3;

        c20 += (ushort)a0.s2 * b0.s0;
        c21 += (ushort)a0.s2 * b0.s1;
        c22 += (ushort)a0.s2 * b0.s2;
        c23 += (ushort)a0.s2 * b0.s3;

        c30 += (ushort)a0.s3 * b0.s0;
        c31 += (ushort)a0.s3 * b0.s1;
        c32 += (ushort)a0.s3 * b0.s2;
        c33 += (ushort)a0.s3 * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s4 * b0.s0;
        c01 += (ushort)a0.s4 * b0.s1;
        c02 += (ushort)a0.s4 * b0.s2;
        c03 += (ushort)a0.s4 * b0.s3;

        c10 += (ushort)a0.s5 * b0.s0;
        c11 += (ushort)a0.s5 * b0.s1;
        c12 += (ushort)a0.s5 * b0.s2;
        c13 += (ushort)a0.s5 * b0.s3;

        c20 += (ushort)a0.s6 * b0.s0;
        c21 += (ushort)a0.s6 * b0.s1;
        c22 += (ushort)a0.s6 * b0.s2;
        c23 += (ushort)a0.s6 * b0.s3;

        c30 += (ushort)a0.s7 * b0.s0;
        c31 += (ushort)a0.s7 * b0.s1;
        c32 += (ushort)a0.s7 * b0.s2;
        c33 += (ushort)a0.s7 * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s8 * b0.s0;
        c01 += (ushort)a0.s8 * b0.s1;
        c02 += (ushort)a0.s8 * b0.s2;
        c03 += (ushort)a0.s8 * b0.s3;

        c10 += (ushort)a0.s9 * b0.s0;
        c11 += (ushort)a0.s9 * b0.s1;
        c12 += (ushort)a0.s9 * b0.s2;
        c13 += (ushort)a0.s9 * b0.s3;

        c20 += (ushort)a0.sA * b0.s0;
        c21 += (ushort)a0.sA * b0.s1;
        c22 += (ushort)a0.sA * b0.s2;
        c23 += (ushort)a0.sA * b0.s3;

        c30 += (ushort)a0.sB * b0.s0;
        c31 += (ushort)a0.sB * b0.s1;
        c32 += (ushort)a0.sB * b0.s2;
        c33 += (ushort)a0.sB * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.sC * b0.s0;
        c01 += (ushort)a0.sC * b0.s1;
        c02 += (ushort)a0.sC * b0.s2;
        c03 += (ushort)a0.sC * b0.s3;

        c10 += (ushort)a0.sD * b0.s0;
        c11 += (ushort)a0.sD * b0.s1;
        c12 += (ushort)a0.sD * b0.s2;
        c13 += (ushort)a0.sD * b0.s3;

        c20 += (ushort)a0.sE * b0.s0;
        c21 += (ushort)a0.sE * b0.s1;
        c22 += (ushort)a0.sE * b0.s2;
        c23 += (ushort)a0.sE * b0.s3;

        c30 += (ushort)a0.sF * b0.s0;
        c31 += (ushort)a0.sF * b0.s1;
        c32 += (ushort)a0.sF * b0.s2;
        c33 += (ushort)a0.sF * b0.s3;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload16(0, src_addr_a + 16);
        b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s0 * b0.s0;
        c01 += (ushort)a0.s0 * b0.s1;
        c02 += (ushort)a0.s0 * b0.s2;
        c03 += (ushort)a0.s0 * b0.s3;

        c10 += (ushort)a0.s1 * b0.s0;
        c11 += (ushort)a0.s1 * b0.s1;
        c12 += (ushort)a0.s1 * b0.s2;
        c13 += (ushort)a0.s1 * b0.s3;

        c20 += (ushort)a0.s2 * b0.s0;
        c21 += (ushort)a0.s2 * b0.s1;
        c22 += (ushort)a0.s2 * b0.s2;
        c23 += (ushort)a0.s2 * b0.s3;

        c30 += (ushort)a0.s3 * b0.s0;
        c31 += (ushort)a0.s3 * b0.s1;
        c32 += (ushort)a0.s3 * b0.s2;
        c33 += (ushort)a0.s3 * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s4 * b0.s0;
        c01 += (ushort)a0.s4 * b0.s1;
        c02 += (ushort)a0.s4 * b0.s2;
        c03 += (ushort)a0.s4 * b0.s3;

        c10 += (ushort)a0.s5 * b0.s0;
        c11 += (ushort)a0.s5 * b0.s1;
        c12 += (ushort)a0.s5 * b0.s2;
        c13 += (ushort)a0.s5 * b0.s3;

        c20 += (ushort)a0.s6 * b0.s0;
        c21 += (ushort)a0.s6 * b0.s1;
        c22 += (ushort)a0.s6 * b0.s2;
        c23 += (ushort)a0.s6 * b0.s3;

        c30 += (ushort)a0.s7 * b0.s0;
        c31 += (ushort)a0.s7 * b0.s1;
        c32 += (ushort)a0.s7 * b0.s2;
        c33 += (ushort)a0.s7 * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.s8 * b0.s0;
        c01 += (ushort)a0.s8 * b0.s1;
        c02 += (ushort)a0.s8 * b0.s2;
        c03 += (ushort)a0.s8 * b0.s3;

        c10 += (ushort)a0.s9 * b0.s0;
        c11 += (ushort)a0.s9 * b0.s1;
        c12 += (ushort)a0.s9 * b0.s2;
        c13 += (ushort)a0.s9 * b0.s3;

        c20 += (ushort)a0.sA * b0.s0;
        c21 += (ushort)a0.sA * b0.s1;
        c22 += (ushort)a0.sA * b0.s2;
        c23 += (ushort)a0.sA * b0.s3;

        c30 += (ushort)a0.sB * b0.s0;
        c31 += (ushort)a0.sB * b0.s1;
        c32 += (ushort)a0.sB * b0.s2;
        c33 += (ushort)a0.sB * b0.s3;

        // Load values from matrix B (transposed)
        b0 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);

        c00 += (ushort)a0.sC * b0.s0;
        c01 += (ushort)a0.sC * b0.s1;
        c02 += (ushort)a0.sC * b0.s2;
        c03 += (ushort)a0.sC * b0.s3;

        c10 += (ushort)a0.sD * b0.s0;
        c11 += (ushort)a0.sD * b0.s1;
        c12 += (ushort)a0.sD * b0.s2;
        c13 += (ushort)a0.sD * b0.s3;

        c20 += (ushort)a0.sE * b0.s0;
        c21 += (ushort)a0.sE * b0.s1;
        c22 += (ushort)a0.sE * b0.s2;
        c23 += (ushort)a0.sE * b0.s3;

        c30 += (ushort)a0.sF * b0.s0;
        c31 += (ushort)a0.sF * b0.s1;
        c32 += (ushort)a0.sF * b0.s2;
        c33 += (ushort)a0.sF * b0.s3;
    }
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1

    // Leftover loop: one 4x4 sub-block per iteration (also the only loop when MULT_INTERLEAVE4X4_HEIGHT != 1)
    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        uchar4 a0 = vload4(0, src_addr_a);
        uchar4 b0 = vload4(0, src_addr_b);

        c00 += (ushort)a0.s0 * b0.s0;
        c01 += (ushort)a0.s0 * b0.s1;
        c02 += (ushort)a0.s0 * b0.s2;
        c03 += (ushort)a0.s0 * b0.s3;

        c10 += (ushort)a0.s1 * b0.s0;
        c11 += (ushort)a0.s1 * b0.s1;
        c12 += (ushort)a0.s1 * b0.s2;
        c13 += (ushort)a0.s1 * b0.s3;

        c20 += (ushort)a0.s2 * b0.s0;
        c21 += (ushort)a0.s2 * b0.s1;
        c22 += (ushort)a0.s2 * b0.s2;
        c23 += (ushort)a0.s2 * b0.s3;

        c30 += (ushort)a0.s3 * b0.s0;
        c31 += (ushort)a0.s3 * b0.s1;
        c32 += (ushort)a0.s3 * b0.s2;
        c33 += (ushort)a0.s3 * b0.s3;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block (accumulators are reinterpreted as signed S32 on store)
    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += z * dst_stride_z;

    // Store 4x4 block (accumulators are reinterpreted as signed S32 on store)
    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Giorgio Arena6200fa42018-07-06 17:06:36 +0100554
Georgios Pinitasdaa38552018-08-28 17:43:18 +0100555#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Giorgio Arena6200fa42018-07-06 17:06:36 +0100556/** This OpenCL kernel is optimized for Bifrost and computes the matrix multiplication between matrix A (src0) and matrix B (src1)
557 * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
558 *
559 * @attention The number of matrix B columns needs to be passed at compile time using -DCOLS_B
560 * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2)
561 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
562 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100563 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
564 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
565 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
566 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
567 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
568 *
Giorgio Arena6200fa42018-07-06 17:06:36 +0100569 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
570 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
571 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
572 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
573 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
574 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
575 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
576 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
577 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
578 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
579 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
580 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
581 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
582 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
583 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
584 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
585 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
586 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100587 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
588 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
589 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
590 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +0100591 */
__kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(src0),
                                                              IMAGE_DECLARATION(src1),
                                                              IMAGE_DECLARATION(dst),
                                                              uint src0_stride_z,
                                                              uint src1_stride_z,
                                                              uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                              ,
                                                              uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                             )
{
    // Each work-item computes one 4x4 int32 tile of the output:
    // get_global_id(0) selects the tile column, get_global_id(1) the tile row,
    // get_global_id(2) the batch (Z) slice.

    // Offset inside the reshaped inputs: position of this work-item's 4-row group
    // within the interleaved A row and 4-column group within the transposed B row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;

    // src_addr_a = address of matrix A (interleaved 4x4)
    // src_addr_b = address of matrix B (transposed 1xW)
    __global uchar *src_addr_a = (__global uchar *)(src0_ptr + (get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT) * src0_stride_y + get_global_id(2) * src0_stride_z + src0_offset_first_element_in_bytes);
    __global uchar *src_addr_b = (__global uchar *)(src1_ptr + (get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP) * src1_stride_y + src1_offset_first_element_in_bytes);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr_b += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src_addr_b += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cRC accumulates element (row R, column C) of the 4x4 output tile
    uint c00 = 0;
    uint c01 = 0;
    uint c02 = 0;
    uint c03 = 0;

    uint c10 = 0;
    uint c11 = 0;
    uint c12 = 0;
    uint c13 = 0;

    uint c20 = 0;
    uint c21 = 0;
    uint c22 = 0;
    uint c23 = 0;

    uint c30 = 0;
    uint c31 = 0;
    uint c32 = 0;
    uint c33 = 0;

// Number of loop iterations along the K dimension of the reshaped matrix B
#define COLS_MTX_B (COLS_B / (16 * MULT_TRANSPOSE1XW_WIDTH))

#if MULT_INTERLEAVE4X4_HEIGHT == 1
    int i = 0;
    // Main unrolled loop: consumes 8 K-steps (2 x uchar16 of A, 8 x uchar4 of B) per iteration,
    // feeding the 8-bit dot-product-accumulate instruction via ARM_DOT
    for(; i <= (int)(COLS_MTX_B - 8); i += 8)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        uchar16 a0 = vload16(0, src_addr_a);
        uchar4 b0 = vload4(0, src_addr_b);
        uchar4 b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4 b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4 b4 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4 b5 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4 b6 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
        uchar4 b7 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);

        // Accumulate the first 4 K-steps: a0.s0123/s4567/s89AB/sCDEF hold rows 0..3 of A,
        // the gathered (b0..b3).sC vectors hold column C of B
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);

        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);

        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);

        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);

        // Accumulate the second 4 K-steps (next 16 bytes of A against b4..b7)
        a0 = vload16(0, src_addr_a + 16);

        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c00);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c01);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c02);
        ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c03);

        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c10);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c11);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c12);
        ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c13);

        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c20);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c21);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c22);
        ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c23);

        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c30);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c31);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c32);
        ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c33);

        // 8 K-steps consumed: 32 bytes of A (4 rows x 8 values), 32 column-groups of B
        src_addr_a += 32;
        src_addr_b += 32 * TRANSPOSE1XW_WIDTH_STEP;
    }
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    // NOTE(review): 'i' is declared only inside the MULT_INTERLEAVE4X4_HEIGHT == 1 branch
    // above, so this leftover loop does not compile for any other value of
    // MULT_INTERLEAVE4X4_HEIGHT. Presumably the kernel is only ever built with
    // MULT_INTERLEAVE4X4_HEIGHT == 1 -- confirm at the kernel-configuration call site.
    int i_left_over = 0;
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed).
        // The A offset walks one K-step at a time inside the interleaved layout:
        // byte (i_left_over % 4) within the current 16-byte 4x4 block, advancing
        // 16 bytes per full block.
        uchar16 a0 = vload16(0, src_addr_a + (i_left_over % 4) + ((i_left_over / 4) * 16));
        uchar4 b0 = vload4(0, src_addr_b);

        // Scalar accumulation: a0.s0/s4/s8/sC are rows 0..3 of A for this K-step
        c00 += a0.s0 * b0.s0;
        c01 += a0.s0 * b0.s1;
        c02 += a0.s0 * b0.s2;
        c03 += a0.s0 * b0.s3;

        c10 += a0.s4 * b0.s0;
        c11 += a0.s4 * b0.s1;
        c12 += a0.s4 * b0.s2;
        c13 += a0.s4 * b0.s3;

        c20 += a0.s8 * b0.s0;
        c21 += a0.s8 * b0.s1;
        c22 += a0.s8 * b0.s2;
        c23 += a0.s8 * b0.s3;

        c30 += a0.sC * b0.s0;
        c31 += a0.sC * b0.s1;
        c32 += a0.sC * b0.s2;
        c33 += a0.sC * b0.s3;

        i_left_over++;
        src_addr_b += 4 * TRANSPOSE1XW_WIDTH_STEP;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings (zout becomes a per-row byte offset)
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += get_global_id(2) * dst_stride_z;

    // Store 4x4 block
    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Georgios Pinitasdaa38552018-08-28 17:43:18 +0100787#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Giorgio Arena6200fa42018-07-06 17:06:36 +0100788
Gian Marco19835e52018-01-30 13:35:54 +0000789#endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
Gian Marco05288a22017-11-21 10:57:50 +0000790
791#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
792#define VECTOR_UCHAR VEC_DATA_TYPE(uchar, NUM_ELEMS_PROCESSED_PER_THREAD_X)
793#define VECTOR_UINT VEC_DATA_TYPE(uint, NUM_ELEMS_PROCESSED_PER_THREAD_X)
794#define VECTOR_INT VEC_DATA_TYPE(int, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
796 *
797 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
798 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100799 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
800 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
801 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
802 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
803 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
804 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
805 *
Gian Marco05288a22017-11-21 10:57:50 +0000806 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
807 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
808 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
809 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
810 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
811 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
812 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
813 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
814 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
815 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
816 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
817 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
818 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
819 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
820 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
821 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
822 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
823 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100824 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
825 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
826 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
827 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
828 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco05288a22017-11-21 10:57:50 +0000829 */
Gian Marco7b4d5472018-01-10 15:56:30 +0000830__kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0),
831 IMAGE_DECLARATION(src1),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100832 IMAGE_DECLARATION(dst),
833 uint src0_stride_z,
834 uint src1_stride_z,
835 uint dst_stride_z
836#if defined(REINTERPRET_INPUT_AS_3D)
837 ,
838 uint src_cross_plane_pad
839#endif // REINTERPRET_INPUT_AS_3D
840#if defined(REINTERPRET_OUTPUT_AS_3D)
841 ,
842 uint dst_cross_plane_pad
843#endif // REINTERPRET_OUTPUT_AS_3D
844 )
Gian Marco05288a22017-11-21 10:57:50 +0000845{
846 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
847
848 // Compute starting address for matrix A and Matrix B
849 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
850
851 // Update address for the matrix A
852 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
853
854 // Update address for the matrix B
855 src_addr.s1 += idx;
856
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100857#if defined(REINTERPRET_INPUT_AS_3D)
858 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
859 // in order to take into account the presence of possible cross plane paddings
860 //
861 // | |
862 // | plane0 |
863 // | |
864 // |__________________|
865 // |******************|
866 // | cross_plane_pad |
867 // |******************|
868 // | |
869 // | plane1 |
870 // | |
871 // |__________________|
872
873 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
874 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
875 zin = min(DEPTH_GEMM3D - 1, zin);
876
877 // Add offset due to the cross plane paddings
878 zin *= (src_cross_plane_pad * src0_stride_y);
879
880 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
881 // multiply src0_stride_z by DEPTH_GEMM3D
882 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
883
884#else // defined(REINTERPRET_INPUT_AS_3D)
885
886 // Add offset for batched GEMM
887 src_addr.s0 += get_global_id(2) * src0_stride_z;
888
889#endif // defined(REINTERPRET_INPUT_AS_3D)
890
891#if defined(MATRIX_B_DEPTH)
892 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
893 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
894#else // defined(MATRIX_B_DEPTH)
895 src_addr.s1 += get_global_id(2) * src1_stride_z;
896#endif // defined(MATRIX_B_DEPTH)
897
Gian Marco05288a22017-11-21 10:57:50 +0000898 int end_row_vec_a = src_addr.s0 + COLS_A;
899
900 VECTOR_UINT acc0 = 0;
901#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
902 VECTOR_UINT acc1 = 0;
903#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
904#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
905 VECTOR_UINT acc2 = 0;
906#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
907#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
908 VECTOR_UINT acc3 = 0;
909#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000910#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
911 VECTOR_UINT acc4 = 0;
912#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000913
914 for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
915 {
916 // Load values from matrix A
917 uchar2 a0 = vload2(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
918#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
919 uchar2 a1 = vload2(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
920#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
921#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
922 uchar2 a2 = vload2(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
923#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
924#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
925 uchar2 a3 = vload2(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
926#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000927#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
928 uchar2 a4 = vload2(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
929#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000930 // Load values from matrix B
931 VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
932 VECTOR_UCHAR b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1 + src1_stride_y);
933
934 // Accumulate
935 acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0.s0;
936 acc0 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a0.s1;
937#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
938 acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1.s0;
939 acc1 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a1.s1;
940#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
941#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
942 acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2.s0;
943 acc2 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a2.s1;
944#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
945#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
946 acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3.s0;
947 acc3 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a3.s1;
948#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000949#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
950 acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4.s0;
951 acc4 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a4.s1;
952#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000953 }
954
955 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
956 {
957 // Load values from matrix A
958 uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
959#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
960 uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
961#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
962#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
963 uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
964#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
965#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
966 uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
967#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000968#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
969 uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
970#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000971 // Load values from matrix B
972 VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
973
974 // Accumulate
975 acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0;
976#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
977 acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1;
978#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
979#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
980 acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2;
981#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
982#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
983 acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3;
984#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000985#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
986 acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4;
987#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000988 }
989
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100990 const int z = get_global_id(2);
991
Gian Marco05288a22017-11-21 10:57:50 +0000992 // Compute destination address
993 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
994
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100995#if defined(REINTERPRET_OUTPUT_AS_3D)
996 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
997 // in order to take into account the presence of possible cross plane paddings
998 //
999 // | |
1000 // | plane0 |
1001 // | |
1002 // |__________________|
1003 // |******************|
1004 // | cross_plane_pad |
1005 // |******************|
1006 // | |
1007 // | plane1 |
1008 // | |
1009 // |__________________|
1010
1011 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1012 uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
1013 zout = min(DEPTH_GEMM3D - 1, zout);
1014
1015 // Add offset due to the cross plane paddings
1016 zout *= (dst_cross_plane_pad * dst_stride_y);
1017
1018 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1019 // multiply dst_stride_z by DEPTH_GEMM3D
1020 dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
1021
Gian Marco05288a22017-11-21 10:57:50 +00001022 // Store the result
1023 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001024 (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
Gian Marco05288a22017-11-21 10:57:50 +00001025#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1026 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001027 (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
Gian Marco05288a22017-11-21 10:57:50 +00001028#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1029#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1030 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001031 (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
Gian Marco05288a22017-11-21 10:57:50 +00001032#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1033#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1034 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001035 (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
Gian Marco05288a22017-11-21 10:57:50 +00001036#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +00001037#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1038 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001039 (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
Gian Marco7b4d5472018-01-10 15:56:30 +00001040#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001041
1042#else // defined(REINTERPRET_OUTPUT_AS_3D)
1043 // Add offset for batched GEMM
1044 dst.ptr += z * dst_stride_z;
1045
1046 // Store the result
1047 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1048 (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
1049#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1050 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1051 (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
1052#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1053#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1054 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1055 (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
1056#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1057#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1058 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1059 (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
1060#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1061#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1062 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1063 (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
1064#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1065#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco7b4d5472018-01-10 15:56:30 +00001066}
1067
/** OpenCL kernel optimized for Bifrost architectures that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1069 *
1070 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
1071 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001072 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1073 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1074 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1075 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1076 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1077 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1078 *
Gian Marco7b4d5472018-01-10 15:56:30 +00001079 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
1080 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1081 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1082 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1083 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1084 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1085 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
1086 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1087 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1088 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1089 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1090 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1091 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
1092 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1093 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1094 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1095 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1096 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001097 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1098 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1099 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1100 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1101 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco7b4d5472018-01-10 15:56:30 +00001102 */
1103__kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
1104 IMAGE_DECLARATION(src1),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001105 IMAGE_DECLARATION(dst),
1106 uint src0_stride_z,
1107 uint src1_stride_z,
1108 uint dst_stride_z
1109#if defined(REINTERPRET_INPUT_AS_3D)
1110 ,
1111 uint src_cross_plane_pad
1112#endif // REINTERPRET_INPUT_AS_3D
1113#if defined(REINTERPRET_OUTPUT_AS_3D)
1114 ,
1115 uint dst_cross_plane_pad
1116#endif // REINTERPRET_OUTPUT_AS_3D
1117 )
Gian Marco7b4d5472018-01-10 15:56:30 +00001118{
1119 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1120
1121 // Compute starting address for matrix A and Matrix B
1122 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1123
1124 // Update address for the matrix A
1125 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1126
1127 // Update address for the matrix B
1128 src_addr.s1 += idx;
1129
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001130#if defined(REINTERPRET_INPUT_AS_3D)
1131 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1132 // in order to take into account the presence of possible cross plane paddings
1133 //
1134 // | |
1135 // | plane0 |
1136 // | |
1137 // |__________________|
1138 // |******************|
1139 // | cross_plane_pad |
1140 // |******************|
1141 // | |
1142 // | plane1 |
1143 // | |
1144 // |__________________|
1145
1146 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1147 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1148 zin = min(DEPTH_GEMM3D - 1, zin);
1149
1150 // Add offset due to the cross plane paddings
1151 zin *= (src_cross_plane_pad * src0_stride_y);
1152
1153 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1154 // multiply src0_stride_z by DEPTH_GEMM3D
1155 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1156
1157#else // defined(REINTERPRET_INPUT_AS_3D)
1158
1159 // Add offset for batched GEMM
1160 src_addr.s0 += get_global_id(2) * src0_stride_z;
1161
1162#endif // defined(REINTERPRET_INPUT_AS_3D)
1163
1164#if defined(MATRIX_B_DEPTH)
1165 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1166 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1167#else // defined(MATRIX_B_DEPTH)
1168 src_addr.s1 += get_global_id(2) * src1_stride_z;
1169#endif // defined(MATRIX_B_DEPTH)
1170
Gian Marco7b4d5472018-01-10 15:56:30 +00001171 int end_row_vec_a = src_addr.s0 + COLS_A;
1172
1173 uint acc00 = 0;
1174 uint acc01 = 0;
1175 uint acc02 = 0;
1176 uint acc03 = 0;
1177#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1178 uint acc10 = 0;
1179 uint acc11 = 0;
1180 uint acc12 = 0;
1181 uint acc13 = 0;
1182#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1183#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1184 uint acc20 = 0;
1185 uint acc21 = 0;
1186 uint acc22 = 0;
1187 uint acc23 = 0;
1188#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1189#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1190 uint acc30 = 0;
1191 uint acc31 = 0;
1192 uint acc32 = 0;
1193 uint acc33 = 0;
1194#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1195#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1196 uint acc40 = 0;
1197 uint acc41 = 0;
1198 uint acc42 = 0;
1199 uint acc43 = 0;
1200#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1201
1202 for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
1203 {
1204 // Load values from matrix A
1205 uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
1206#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1207 uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
1208#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1209#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1210 uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
1211#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1212#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1213 uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
1214#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1215#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1216 uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
1217#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1218 // Load values from matrix B
1219 uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
1220 uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
1221 uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
1222 uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
1223
1224 {
1225 // Accumulate
1226 ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
1227 ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
1228 ushort tmp2 = (ushort)b0.s2 * (ushort)a0.s0;
1229 ushort tmp3 = (ushort)b0.s3 * (ushort)a0.s0;
1230
1231 ushort tmp4 = (ushort)b1.s0 * (ushort)a0.s1;
1232 ushort tmp5 = (ushort)b1.s1 * (ushort)a0.s1;
1233 ushort tmp6 = (ushort)b1.s2 * (ushort)a0.s1;
1234 ushort tmp7 = (ushort)b1.s3 * (ushort)a0.s1;
1235
1236 ushort tmp8 = (ushort)b2.s0 * (ushort)a0.s2;
1237 ushort tmp9 = (ushort)b2.s1 * (ushort)a0.s2;
1238 ushort tmpA = (ushort)b2.s2 * (ushort)a0.s2;
1239 ushort tmpB = (ushort)b2.s3 * (ushort)a0.s2;
1240
1241 ushort tmpC = (ushort)b3.s0 * (ushort)a0.s3;
1242 ushort tmpD = (ushort)b3.s1 * (ushort)a0.s3;
1243 ushort tmpE = (ushort)b3.s2 * (ushort)a0.s3;
1244 ushort tmpF = (ushort)b3.s3 * (ushort)a0.s3;
1245
1246 acc00 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1247 acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1248 acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1249 acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1250 }
1251#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1252 {
1253 // Accumulate
1254 ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
1255 ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
1256 ushort tmp2 = (ushort)b0.s2 * (ushort)a1.s0;
1257 ushort tmp3 = (ushort)b0.s3 * (ushort)a1.s0;
1258
1259 ushort tmp4 = (ushort)b1.s0 * (ushort)a1.s1;
1260 ushort tmp5 = (ushort)b1.s1 * (ushort)a1.s1;
1261 ushort tmp6 = (ushort)b1.s2 * (ushort)a1.s1;
1262 ushort tmp7 = (ushort)b1.s3 * (ushort)a1.s1;
1263
1264 ushort tmp8 = (ushort)b2.s0 * (ushort)a1.s2;
1265 ushort tmp9 = (ushort)b2.s1 * (ushort)a1.s2;
1266 ushort tmpA = (ushort)b2.s2 * (ushort)a1.s2;
1267 ushort tmpB = (ushort)b2.s3 * (ushort)a1.s2;
1268
1269 ushort tmpC = (ushort)b3.s0 * (ushort)a1.s3;
1270 ushort tmpD = (ushort)b3.s1 * (ushort)a1.s3;
1271 ushort tmpE = (ushort)b3.s2 * (ushort)a1.s3;
1272 ushort tmpF = (ushort)b3.s3 * (ushort)a1.s3;
1273
1274 acc10 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1275 acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1276 acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1277 acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1278 }
1279#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1280#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1281 {
1282 // Accumulate
1283 ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
1284 ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
1285 ushort tmp2 = (ushort)b0.s2 * (ushort)a2.s0;
1286 ushort tmp3 = (ushort)b0.s3 * (ushort)a2.s0;
1287
1288 ushort tmp4 = (ushort)b1.s0 * (ushort)a2.s1;
1289 ushort tmp5 = (ushort)b1.s1 * (ushort)a2.s1;
1290 ushort tmp6 = (ushort)b1.s2 * (ushort)a2.s1;
1291 ushort tmp7 = (ushort)b1.s3 * (ushort)a2.s1;
1292
1293 ushort tmp8 = (ushort)b2.s0 * (ushort)a2.s2;
1294 ushort tmp9 = (ushort)b2.s1 * (ushort)a2.s2;
1295 ushort tmpA = (ushort)b2.s2 * (ushort)a2.s2;
1296 ushort tmpB = (ushort)b2.s3 * (ushort)a2.s2;
1297
1298 ushort tmpC = (ushort)b3.s0 * (ushort)a2.s3;
1299 ushort tmpD = (ushort)b3.s1 * (ushort)a2.s3;
1300 ushort tmpE = (ushort)b3.s2 * (ushort)a2.s3;
1301 ushort tmpF = (ushort)b3.s3 * (ushort)a2.s3;
1302
1303 acc20 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1304 acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1305 acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1306 acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1307 }
1308#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1309#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1310 {
1311 // Accumulate
1312 ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
1313 ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
1314 ushort tmp2 = (ushort)b0.s2 * (ushort)a3.s0;
1315 ushort tmp3 = (ushort)b0.s3 * (ushort)a3.s0;
1316
1317 ushort tmp4 = (ushort)b1.s0 * (ushort)a3.s1;
1318 ushort tmp5 = (ushort)b1.s1 * (ushort)a3.s1;
1319 ushort tmp6 = (ushort)b1.s2 * (ushort)a3.s1;
1320 ushort tmp7 = (ushort)b1.s3 * (ushort)a3.s1;
1321
1322 ushort tmp8 = (ushort)b2.s0 * (ushort)a3.s2;
1323 ushort tmp9 = (ushort)b2.s1 * (ushort)a3.s2;
1324 ushort tmpA = (ushort)b2.s2 * (ushort)a3.s2;
1325 ushort tmpB = (ushort)b2.s3 * (ushort)a3.s2;
1326
1327 ushort tmpC = (ushort)b3.s0 * (ushort)a3.s3;
1328 ushort tmpD = (ushort)b3.s1 * (ushort)a3.s3;
1329 ushort tmpE = (ushort)b3.s2 * (ushort)a3.s3;
1330 ushort tmpF = (ushort)b3.s3 * (ushort)a3.s3;
1331
1332 acc30 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1333 acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1334 acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1335 acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1336 }
1337#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1338#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1339 {
1340 // Accumulate
1341 ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
1342 ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
1343 ushort tmp2 = (ushort)b0.s2 * (ushort)a4.s0;
1344 ushort tmp3 = (ushort)b0.s3 * (ushort)a4.s0;
1345
1346 ushort tmp4 = (ushort)b1.s0 * (ushort)a4.s1;
1347 ushort tmp5 = (ushort)b1.s1 * (ushort)a4.s1;
1348 ushort tmp6 = (ushort)b1.s2 * (ushort)a4.s1;
1349 ushort tmp7 = (ushort)b1.s3 * (ushort)a4.s1;
1350
1351 ushort tmp8 = (ushort)b2.s0 * (ushort)a4.s2;
1352 ushort tmp9 = (ushort)b2.s1 * (ushort)a4.s2;
1353 ushort tmpA = (ushort)b2.s2 * (ushort)a4.s2;
1354 ushort tmpB = (ushort)b2.s3 * (ushort)a4.s2;
1355
1356 ushort tmpC = (ushort)b3.s0 * (ushort)a4.s3;
1357 ushort tmpD = (ushort)b3.s1 * (ushort)a4.s3;
1358 ushort tmpE = (ushort)b3.s2 * (ushort)a4.s3;
1359 ushort tmpF = (ushort)b3.s3 * (ushort)a4.s3;
1360
1361 acc40 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
1362 acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
1363 acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
1364 acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
1365 }
1366#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1367 }
1368
1369 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
1370 {
1371 // Load values from matrix A
1372 uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
1373#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1374 uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
1375#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1376#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1377 uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
1378#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1379#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1380 uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
1381#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1382#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1383 uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
1384#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1385 // Load values from matrix B
1386 uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);
1387
1388 // Accumulate
1389 {
1390 // Accumulate
1391 ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
1392 ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
1393 ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
1394 ushort tmp3 = (ushort)b0.s3 * (ushort)a0;
1395
1396 acc00 += ((uint)tmp0);
1397 acc01 += ((uint)tmp1);
1398 acc02 += ((uint)tmp2);
1399 acc03 += ((uint)tmp3);
1400 }
1401#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1402 {
1403 // Accumulate
1404 ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
1405 ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
1406 ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
1407 ushort tmp3 = (ushort)b0.s3 * (ushort)a1;
1408
1409 acc10 += ((uint)tmp0);
1410 acc11 += ((uint)tmp1);
1411 acc12 += ((uint)tmp2);
1412 acc13 += ((uint)tmp3);
1413 }
1414#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1415#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1416 {
1417 // Accumulate
1418 ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
1419 ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
1420 ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
1421 ushort tmp3 = (ushort)b0.s3 * (ushort)a2;
1422
1423 acc20 += ((uint)tmp0);
1424 acc21 += ((uint)tmp1);
1425 acc22 += ((uint)tmp2);
1426 acc23 += ((uint)tmp3);
1427 }
1428#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1429#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1430 {
1431 // Accumulate
1432 ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
1433 ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
1434 ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
1435 ushort tmp3 = (ushort)b0.s3 * (ushort)a3;
1436
1437 acc30 += ((uint)tmp0);
1438 acc31 += ((uint)tmp1);
1439 acc32 += ((uint)tmp2);
1440 acc33 += ((uint)tmp3);
1441 }
1442#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1443#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1444 {
1445 // Accumulate
1446 ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
1447 ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
1448 ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
1449 ushort tmp3 = (ushort)b0.s3 * (ushort)a4;
1450
1451 acc40 += ((uint)tmp0);
1452 acc41 += ((uint)tmp1);
1453 acc42 += ((uint)tmp2);
1454 acc43 += ((uint)tmp3);
1455 }
1456#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1457 }
1458
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001459 const int z = get_global_id(2);
1460
Gian Marco7b4d5472018-01-10 15:56:30 +00001461 // Compute destination address
1462 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1463
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001464#if defined(REINTERPRET_OUTPUT_AS_3D)
1465 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1466 // in order to take into account the presence of possible cross plane paddings
1467 //
1468 // | |
1469 // | plane0 |
1470 // | |
1471 // |__________________|
1472 // |******************|
1473 // | cross_plane_pad |
1474 // |******************|
1475 // | |
1476 // | plane1 |
1477 // | |
1478 // |__________________|
1479
1480 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1481 uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
1482 zout = min(DEPTH_GEMM3D - 1, zout);
1483
1484 // Add offset due to the cross plane paddings
1485 zout *= (dst_cross_plane_pad * dst_stride_y);
1486
1487 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1488 // multiply dst_stride_z by DEPTH_GEMM3D
1489 dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
1490
Gian Marco7b4d5472018-01-10 15:56:30 +00001491 // Store the result
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001492 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
Gian Marco7b4d5472018-01-10 15:56:30 +00001493#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001494 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
Gian Marco7b4d5472018-01-10 15:56:30 +00001495#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1496#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001497 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
Gian Marco7b4d5472018-01-10 15:56:30 +00001498#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1499#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001500 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
Gian Marco7b4d5472018-01-10 15:56:30 +00001501#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1502#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001503 vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
Gian Marco7b4d5472018-01-10 15:56:30 +00001504#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001505
1506#else // defined(REINTERPRET_OUTPUT_AS_3D)
1507 // Add offset for batched GEMM
1508 dst.ptr += z * dst_stride_z;
1509
1510 // Store the result
1511 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
1512#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1513 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
1514#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1515#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1516 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
1517#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1518#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1519 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
1520#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1521#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1522 vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
1523#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
1524#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco05288a22017-11-21 10:57:50 +00001525}
Giorgio Arena6200fa42018-07-06 17:06:36 +01001526
Georgios Pinitasdaa38552018-08-28 17:43:18 +01001527#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001528/** OpenCL kernel optimized to use dot product that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1529 *
1530 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
1531 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001532 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1533 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1534 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1535 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1536 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1537 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1538 *
Giorgio Arena6200fa42018-07-06 17:06:36 +01001539 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
1540 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1541 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1542 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1543 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1544 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1545 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
1546 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1547 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1548 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1549 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1550 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1551 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
1552 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1553 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1554 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1555 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1556 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001557 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1558 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1559 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1560 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1561 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001562 */
1563__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
1564 IMAGE_DECLARATION(src1),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001565 IMAGE_DECLARATION(dst),
1566 uint src0_stride_z,
1567 uint src1_stride_z,
1568 uint dst_stride_z
1569#if defined(REINTERPRET_INPUT_AS_3D)
1570 ,
1571 uint src_cross_plane_pad
1572#endif // REINTERPRET_INPUT_AS_3D
1573#if defined(REINTERPRET_OUTPUT_AS_3D)
1574 ,
1575 uint dst_cross_plane_pad
1576#endif // REINTERPRET_OUTPUT_AS_3D
1577 )
Giorgio Arena6200fa42018-07-06 17:06:36 +01001578{
1579 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1580
1581 // Compute starting address for matrix A and Matrix B
1582 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1583
1584 // Update address for the matrix A
1585 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1586
1587 // Update address for the matrix B
1588 src_addr.s1 += idx;
1589
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001590#if defined(REINTERPRET_INPUT_AS_3D)
1591 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1592 // in order to take into account the presence of possible cross plane paddings
1593 //
1594 // | |
1595 // | plane0 |
1596 // | |
1597 // |__________________|
1598 // |******************|
1599 // | cross_plane_pad |
1600 // |******************|
1601 // | |
1602 // | plane1 |
1603 // | |
1604 // |__________________|
1605
1606 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1607 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1608 zin = min(DEPTH_GEMM3D - 1, zin);
1609
1610 // Add offset due to the cross plane paddings
1611 zin *= (src_cross_plane_pad * src0_stride_y);
1612
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001613 zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y;
1614
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001615 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1616 // multiply src0_stride_z by DEPTH_GEMM3D
1617 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1618
1619#else // defined(REINTERPRET_INPUT_AS_3D)
1620
1621 // Add offset for batched GEMM
1622 src_addr.s0 += get_global_id(2) * src0_stride_z;
1623
1624#endif // defined(REINTERPRET_INPUT_AS_3D)
1625
1626#if defined(MATRIX_B_DEPTH)
1627 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1628 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1629#else // defined(MATRIX_B_DEPTH)
1630 src_addr.s1 += get_global_id(2) * src1_stride_z;
1631#endif // defined(MATRIX_B_DEPTH)
1632
Giorgio Arena6200fa42018-07-06 17:06:36 +01001633 uint acc00 = 0;
1634 uint acc01 = 0;
1635 uint acc02 = 0;
1636 uint acc03 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001637 uint acc04 = 0;
1638 uint acc05 = 0;
1639 uint acc06 = 0;
1640 uint acc07 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001641#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1642 uint acc10 = 0;
1643 uint acc11 = 0;
1644 uint acc12 = 0;
1645 uint acc13 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001646 uint acc14 = 0;
1647 uint acc15 = 0;
1648 uint acc16 = 0;
1649 uint acc17 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001650#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1651#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1652 uint acc20 = 0;
1653 uint acc21 = 0;
1654 uint acc22 = 0;
1655 uint acc23 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001656 uint acc24 = 0;
1657 uint acc25 = 0;
1658 uint acc26 = 0;
1659 uint acc27 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001660#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1661#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1662 uint acc30 = 0;
1663 uint acc31 = 0;
1664 uint acc32 = 0;
1665 uint acc33 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001666 uint acc34 = 0;
1667 uint acc35 = 0;
1668 uint acc36 = 0;
1669 uint acc37 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001670#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Giorgio Arena6200fa42018-07-06 17:06:36 +01001671
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001672 // A and B src indices get incremented at the same time.
1673 int i = 0;
1674 for(; i <= ((int)COLS_A - 8); i += 8)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001675 {
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001676#if defined(REINTERPRET_INPUT_AS_3D)
1677 // Load values from matrix A and matrix B
1678 uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001679#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001680 uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001681#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1682#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001683 uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001684#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1685#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001686 uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001687#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001688#else // defined(REINTERPRET_INPUT_AS_3D)
1689 // Load values from matrix A and matrix B
1690 uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1691#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1692 uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1693#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1694#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1695 uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1696#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1697#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1698 uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1699#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1700#endif // defined(REINTERPRET_INPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001701
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001702 uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
1703 uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
1704 uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
1705 uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
1706 src_addr.s1 += 4 * src1_stride_y;
1707
1708 ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
1709 ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
1710 ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
1711 ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
1712 ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
1713 ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
1714 ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
1715 ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
1716
Giorgio Arena6200fa42018-07-06 17:06:36 +01001717#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001718 ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
1719 ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
1720 ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
1721 ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
1722 ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
1723 ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
1724 ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
1725 ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001726#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1727#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001728 ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
1729 ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
1730 ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
1731 ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
1732 ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
1733 ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
1734 ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
1735 ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001736#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1737#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001738 ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
1739 ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
1740 ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
1741 ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
1742 ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
1743 ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
1744 ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
1745 ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001746#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001747
1748 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
1749 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
1750 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
1751 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
1752 src_addr.s1 += 4 * src1_stride_y;
1753
1754 ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
1755 ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
1756 ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
1757 ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
1758 ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
1759 ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
1760 ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
1761 ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
1762
1763#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1764 ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
1765 ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
1766 ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
1767 ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
1768 ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
1769 ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
1770 ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
1771 ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
1772#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1773#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1774 ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
1775 ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
1776 ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
1777 ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
1778 ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
1779 ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
1780 ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
1781 ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
1782#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1783#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1784 ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
1785 ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
1786 ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
1787 ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
1788 ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
1789 ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
1790 ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
1791 ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
1792#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1793
1794 src_addr.s0 += 8;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001795 }
1796
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001797 for(; i < (int)COLS_A; ++i)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001798 {
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001799#if defined(REINTERPRET_INPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001800 // Load values from matrix A
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001801 uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001802#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001803 uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001804#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1805#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001806 uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001807#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1808#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001809 uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001810#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001811#else // defined(REINTERPRET_INPUT_AS_3D)
1812 // Load values from matrix A
1813 uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1814#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1815 uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1816#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1817#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1818 uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1819#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1820#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1821 uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1822#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1823#endif // defined(REINTERPRET_INPUT_AS_3D)
1824
Giorgio Arena6200fa42018-07-06 17:06:36 +01001825 // Load values from matrix B
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001826 uchar8 b0 = vload8(0, src1_ptr + src_addr.s1);
1827 src_addr.s1 += src1_stride_y;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001828
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001829 acc00 += (uint)a0 * b0.s0;
1830 acc01 += (uint)a0 * b0.s1;
1831 acc02 += (uint)a0 * b0.s2;
1832 acc03 += (uint)a0 * b0.s3;
1833 acc04 += (uint)a0 * b0.s4;
1834 acc05 += (uint)a0 * b0.s5;
1835 acc06 += (uint)a0 * b0.s6;
1836 acc07 += (uint)a0 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001837
Giorgio Arena6200fa42018-07-06 17:06:36 +01001838#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001839 acc10 += (uint)a1 * b0.s0;
1840 acc11 += (uint)a1 * b0.s1;
1841 acc12 += (uint)a1 * b0.s2;
1842 acc13 += (uint)a1 * b0.s3;
1843 acc14 += (uint)a1 * b0.s4;
1844 acc15 += (uint)a1 * b0.s5;
1845 acc16 += (uint)a1 * b0.s6;
1846 acc17 += (uint)a1 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001847#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1848#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001849 acc20 += (uint)a2 * b0.s0;
1850 acc21 += (uint)a2 * b0.s1;
1851 acc22 += (uint)a2 * b0.s2;
1852 acc23 += (uint)a2 * b0.s3;
1853 acc24 += (uint)a2 * b0.s4;
1854 acc25 += (uint)a2 * b0.s5;
1855 acc26 += (uint)a2 * b0.s6;
1856 acc27 += (uint)a2 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001857#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1858#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001859 acc30 += (uint)a3 * b0.s0;
1860 acc31 += (uint)a3 * b0.s1;
1861 acc32 += (uint)a3 * b0.s2;
1862 acc33 += (uint)a3 * b0.s3;
1863 acc34 += (uint)a3 * b0.s4;
1864 acc35 += (uint)a3 * b0.s5;
1865 acc36 += (uint)a3 * b0.s6;
1866 acc37 += (uint)a3 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001867#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Giorgio Arena6200fa42018-07-06 17:06:36 +01001868
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001869 src_addr.s0 += 1;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001870 }
1871
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001872 int z = get_global_id(2);
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001873
Giorgio Arena6200fa42018-07-06 17:06:36 +01001874 // Compute destination address
1875 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1876
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001877 // Compute dst address
1878 __global uchar *dst_addr = dst.ptr;
1879
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001880#if defined(REINTERPRET_OUTPUT_AS_3D)
1881 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1882 // in order to take into account the presence of possible cross plane paddings
1883 //
1884 // | |
1885 // | plane0 |
1886 // | |
1887 // |__________________|
1888 // |******************|
1889 // | cross_plane_pad |
1890 // |******************|
1891 // | |
1892 // | plane1 |
1893 // | |
1894 // |__________________|
1895
1896 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001897 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001898 zout = min(DEPTH_GEMM3D - 1, zout);
1899
1900 // Add offset due to the cross plane paddings
1901 zout *= (dst_cross_plane_pad * dst_stride_y);
1902
1903 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1904 // multiply dst_stride_z by DEPTH_GEMM3D
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001905 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001906
Giorgio Arena6200fa42018-07-06 17:06:36 +01001907 // Store the result
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001908 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
1909 vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001910#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001911 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
1912 vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001913#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1914#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001915 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
1916 vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001917#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1918#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001919 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
1920 vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001921#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001922
1923#else // defined(REINTERPRET_OUTPUT_AS_3D)
1924 // Add offset for batched GEMM
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001925 dst_addr += z * dst_stride_z;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001926
1927 // Store the result
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001928 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y));
1929 vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001930#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001931 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y));
1932 vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001933#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1934#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001935 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y));
1936 vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001937#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1938#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001939 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
1940 vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001941#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001942#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1943}
Georgios Pinitasdaa38552018-08-28 17:43:18 +01001944#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Gian Marco05288a22017-11-21 10:57:50 +00001945#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
1946
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001947#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(M) && defined(N)
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00001948
#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

// ARM_DOT_K0(a, b, c): accumulates into the uint accumulator c the dot product of the
// two uchar vectors a and b of length K0, built on top of the 4-way dot8 instruction
// wrapped by ARM_DOT. For K0 < 4 the operands are zero-padded up to uchar4 (padding
// contributes 0 to the sum); for K0 > 4 the vectors are consumed 4 lanes per ARM_DOT call.
#if K0 == 2
#define ARM_DOT_K0(a, b, c) \
    ({ \
        ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \
    })
#elif K0 == 3 // K0 == 3
#define ARM_DOT_K0(a, b, c) \
    ({ \
        ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \
    })
#elif K0 == 4 // K0 == 4
#define ARM_DOT_K0(a, b, c) \
    ({ \
        ARM_DOT(a, b, c); \
    })
#elif K0 == 8 // K0 == 8
#define ARM_DOT_K0(a, b, c) \
    ({ \
        ARM_DOT(a.s0123, b.s0123, c); \
        ARM_DOT(a.s4567, b.s4567, c); \
    })
#elif K0 == 16 // K0 == 16
#define ARM_DOT_K0(a, b, c) \
    ({ \
        ARM_DOT(a.s0123, b.s0123, c); \
        ARM_DOT(a.s4567, b.s4567, c); \
        ARM_DOT(a.s89AB, b.s89AB, c); \
        ARM_DOT(a.sCDEF, b.sCDEF, c); \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0

#else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

// Fallback when the dot8 extension is unavailable: the same K0-wide dot product is
// computed with plain widening multiply-accumulates (each uchar product widened to uint).
#if K0 == 2
#define ARM_DOT_K0(a, b, c) \
    ({ \
        c += (uint)a.s0 * b.s0; \
        c += (uint)a.s1 * b.s1; \
    })
#elif K0 == 3 // K0 == 3
#define ARM_DOT_K0(a, b, c) \
    ({ \
        c += (uint)a.s0 * b.s0; \
        c += (uint)a.s1 * b.s1; \
        c += (uint)a.s2 * b.s2; \
    })
#elif K0 == 4 // K0 == 4
#define ARM_DOT_K0(a, b, c) \
    ({ \
        c += (uint)a.s0 * b.s0; \
        c += (uint)a.s1 * b.s1; \
        c += (uint)a.s2 * b.s2; \
        c += (uint)a.s3 * b.s3; \
    })
#elif K0 == 8 // K0 == 8
#define ARM_DOT_K0(a, b, c) \
    ({ \
        c += (uint)a.s0 * b.s0; \
        c += (uint)a.s1 * b.s1; \
        c += (uint)a.s2 * b.s2; \
        c += (uint)a.s3 * b.s3; \
        c += (uint)a.s4 * b.s4; \
        c += (uint)a.s5 * b.s5; \
        c += (uint)a.s6 * b.s6; \
        c += (uint)a.s7 * b.s7; \
    })
#elif K0 == 16 // K0 == 16
#define ARM_DOT_K0(a, b, c) \
    ({ \
        c += (uint)a.s0 * b.s0; \
        c += (uint)a.s1 * b.s1; \
        c += (uint)a.s2 * b.s2; \
        c += (uint)a.s3 * b.s3; \
        c += (uint)a.s4 * b.s4; \
        c += (uint)a.s5 * b.s5; \
        c += (uint)a.s6 * b.s6; \
        c += (uint)a.s7 * b.s7; \
        c += (uint)a.s8 * b.s8; \
        c += (uint)a.s9 * b.s9; \
        c += (uint)a.sA * b.sA; \
        c += (uint)a.sB * b.sB; \
        c += (uint)a.sC * b.sC; \
        c += (uint)a.sD * b.sD; \
        c += (uint)a.sE * b.sE; \
        c += (uint)a.sF * b.sF; \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0

#endif //defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
2044
// ARM_DOT_K0XN0(a, b, c): multiplies one K0-wide LHS row a against N0 K0-wide RHS
// rows and accumulates each dot product into the corresponding lane of the uint
// accumulator vector c. The b argument is a variable-name prefix: token pasting
// (b##0, b##1, ...) selects the N0 RHS row vectors b0..b(N0-1) declared by the caller.
#if N0 == 2
#define ARM_DOT_K0XN0(a, b, c) \
    ({ \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(a, b, c) \
    ({ \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(a, b, c) \
    ({ \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(a, b, c) \
    ({ \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(a, b, c) \
    ({ \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
        ARM_DOT_K0((a), (b##8), (c.s8)); \
        ARM_DOT_K0((a), (b##9), (c.s9)); \
        ARM_DOT_K0((a), (b##A), (c.sA)); \
        ARM_DOT_K0((a), (b##B), (c.sB)); \
        ARM_DOT_K0((a), (b##C), (c.sC)); \
        ARM_DOT_K0((a), (b##D), (c.sD)); \
        ARM_DOT_K0((a), (b##E), (c.sE)); \
        ARM_DOT_K0((a), (b##F), (c.sF)); \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions
2101
/** This OpenCL kernel computes the matrix multiplication between 2 matrices with QASYMM8 data type.
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00002103 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
2104 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
2105 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00002106 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2107 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00002108 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
2109 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
2110 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2113 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2114 * - M0 = 2, 3, 4, 5, 6, 7, 8
2115 * - N0 = 2, 3, 4, 8, 16
2116 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00002117 * - V0 >= 1
2118 * - H0 >= 1
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00002119 *
2120 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2121 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2122 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2123 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2124 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2125 *
2126 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: QASYMM8
2127 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2128 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2129 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2130 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2131 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2132 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2133 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2134 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2135 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2136 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2137 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2138 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2139 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2140 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2141 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2142 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2143 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2144 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
2145 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2146 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2147 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2148 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2149 */
2150__kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
2151 IMAGE_DECLARATION(rhs),
2152 IMAGE_DECLARATION(dst),
2153 uint k,
2154 uint lhs_stride_z,
2155 uint rhs_stride_z,
2156 uint dst_stride_z
2157#if defined(REINTERPRET_OUTPUT_AS_3D)
2158 ,
2159 uint dst_cross_plane_pad
2160#endif // REINTERPRET_OUTPUT_AS_3D
2161 )
2162{
2163 // Block size
2164#define LHS_BLOCK_SIZE ((K0) * (M0))
2165
2166#if defined(LHS_INTERLEAVE)
2167#define LHS_OFFSET_X (K0)
2168#define LHS_STEP_X ((K0) * (V0))
2169#define LHS_STEP_LOOP (1)
2170#else // defined(INTERLEAVE)
2171#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2172#define LHS_STEP_X (K0)
2173#define LHS_STEP_LOOP (V0)
2174#endif // defined(INTERLEAVE)
2175
2176 // Block size
2177#define RHS_BLOCK_SIZE ((K0) * (N0))
2178
2179 // RHS offset and step X
2180#if defined(RHS_INTERLEAVE)
2181#define RHS_OFFSET_X (K0)
2182#define RHS_STEP_X ((K0) * (H0))
2183#define RHS_STEP_LOOP (1)
2184#else // defined(RHS_INTERLEAVE)
2185#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2186#define RHS_STEP_X (K0)
2187#define RHS_STEP_LOOP (H0)
2188#endif // defined(RHS_INTERLEAVE)
2189
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00002190#if defined(DUMMY_WORK_ITEMS)
2191 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
2192 {
2193 return;
2194 }
2195#endif // defined(DUMMY_WORK_ITEMS)
2196
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00002197 // Compute LHS matrix address
2198 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X + (get_global_id(1) / V0) * (uint)lhs_stride_y + (get_global_id(
2199 2)
2200 * lhs_stride_z);
2201
2202 // Compute RHS matrix address
2203 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X + (get_global_id(0) / (uint)H0) * rhs_stride_y;
2204
2205#if defined(MATRIX_B_DEPTH)
2206 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2207 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
2208#else // defined(MATRIX_B_DEPTH)
2209 rhs_addr += get_global_id(2) * rhs_stride_z;
2210#endif // defined(MATRIX_B_DEPTH)
2211
2212 // Initialize the accumulators
2213 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
2214
2215 for(int i = 0; i < k; i += K0)
2216 {
2217 // Supported cases (M0, K0):
2218 // 2,4 - 2,8 - 2,16
2219 // 3,4 - 3,8 - 3,16
2220 // 4,4 - 4,8 - 4,16
2221 // 5,4 - 5,8 - 5,16
2222 // 6,4 - 6,8 - 6,16
2223 // Load values from LHS matrix
2224 VEC_DATA_TYPE(uchar, K0)
2225 a0 = VLOAD(K0)(0, lhs_addr + 0 * LHS_STEP_X);
2226#if M0 > 1
2227 VEC_DATA_TYPE(uchar, K0)
2228 a1 = VLOAD(K0)(0, lhs_addr + 1 * LHS_STEP_X);
2229#endif // M0 > 1
2230#if M0 > 2
2231 VEC_DATA_TYPE(uchar, K0)
2232 a2 = VLOAD(K0)(0, lhs_addr + 2 * LHS_STEP_X);
2233#endif // M0 > 2
2234#if M0 > 3
2235 VEC_DATA_TYPE(uchar, K0)
2236 a3 = VLOAD(K0)(0, lhs_addr + 3 * LHS_STEP_X);
2237#endif // M0 > 3
2238#if M0 > 4
2239 VEC_DATA_TYPE(uchar, K0)
2240 a4 = VLOAD(K0)(0, lhs_addr + 4 * LHS_STEP_X);
2241#endif // M0 > 4
2242#if M0 > 5
2243 VEC_DATA_TYPE(uchar, K0)
2244 a5 = VLOAD(K0)(0, lhs_addr + 5 * LHS_STEP_X);
2245#endif // M0 > 5
2246#if M0 > 6
2247 VEC_DATA_TYPE(uchar, K0)
2248 a6 = VLOAD(K0)(0, lhs_addr + 6 * LHS_STEP_X);
2249#endif // M0 > 6
2250#if M0 > 7
2251 VEC_DATA_TYPE(uchar, K0)
2252 a7 = VLOAD(K0)(0, lhs_addr + 7 * LHS_STEP_X);
2253#endif // M0 > 7
2254
2255 // Load values from RHS matrix
2256 VEC_DATA_TYPE(uchar, K0)
2257 b0 = VLOAD(K0)(0, rhs_addr + 0 * RHS_STEP_X);
2258 VEC_DATA_TYPE(uchar, K0)
2259 b1 = VLOAD(K0)(0, rhs_addr + 1 * RHS_STEP_X);
2260#if N0 > 2
2261 VEC_DATA_TYPE(uchar, K0)
2262 b2 = VLOAD(K0)(0, rhs_addr + 2 * RHS_STEP_X);
2263#endif // N0 > 2
2264#if N0 > 3
2265 VEC_DATA_TYPE(uchar, K0)
2266 b3 = VLOAD(K0)(0, rhs_addr + 3 * RHS_STEP_X);
2267#endif // N0 > 3
2268#if N0 > 4
2269 VEC_DATA_TYPE(uchar, K0)
2270 b4 = VLOAD(K0)(0, rhs_addr + 4 * RHS_STEP_X);
2271 VEC_DATA_TYPE(uchar, K0)
2272 b5 = VLOAD(K0)(0, rhs_addr + 5 * RHS_STEP_X);
2273 VEC_DATA_TYPE(uchar, K0)
2274 b6 = VLOAD(K0)(0, rhs_addr + 6 * RHS_STEP_X);
2275 VEC_DATA_TYPE(uchar, K0)
2276 b7 = VLOAD(K0)(0, rhs_addr + 7 * RHS_STEP_X);
2277#endif // N0 > 4
2278#if N0 > 8
2279 VEC_DATA_TYPE(uchar, K0)
2280 b8 = VLOAD(K0)(0, rhs_addr + 8 * RHS_STEP_X);
2281 VEC_DATA_TYPE(uchar, K0)
2282 b9 = VLOAD(K0)(0, rhs_addr + 9 * RHS_STEP_X);
2283 VEC_DATA_TYPE(uchar, K0)
2284 bA = VLOAD(K0)(0, rhs_addr + 10 * RHS_STEP_X);
2285 VEC_DATA_TYPE(uchar, K0)
2286 bB = VLOAD(K0)(0, rhs_addr + 11 * RHS_STEP_X);
2287 VEC_DATA_TYPE(uchar, K0)
2288 bC = VLOAD(K0)(0, rhs_addr + 12 * RHS_STEP_X);
2289 VEC_DATA_TYPE(uchar, K0)
2290 bD = VLOAD(K0)(0, rhs_addr + 13 * RHS_STEP_X);
2291 VEC_DATA_TYPE(uchar, K0)
2292 bE = VLOAD(K0)(0, rhs_addr + 14 * RHS_STEP_X);
2293 VEC_DATA_TYPE(uchar, K0)
2294 bF = VLOAD(K0)(0, rhs_addr + 15 * RHS_STEP_X);
2295#endif // N0 > 8
2296
2297 // Accumulate
2298 ARM_DOT_K0XN0(a0, b, c0);
2299#if M0 > 1
2300 ARM_DOT_K0XN0(a1, b, c1);
2301#endif // M0 > 1
2302#if M0 > 2
2303 ARM_DOT_K0XN0(a2, b, c2);
2304#endif // M0 > 2
2305#if M0 > 3
2306 ARM_DOT_K0XN0(a3, b, c3);
2307#endif // M0 > 3
2308#if M0 > 4
2309 ARM_DOT_K0XN0(a4, b, c4);
2310#endif // M0 > 4
2311#if M0 > 5
2312 ARM_DOT_K0XN0(a5, b, c5);
2313#endif // M0 > 5
2314#if M0 > 6
2315 ARM_DOT_K0XN0(a6, b, c6);
2316#endif // M0 > 6
2317#if M0 > 7
2318 ARM_DOT_K0XN0(a7, b, c7);
2319#endif // M0 > 7
2320
2321 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP);
2322 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP);
2323 }
2324
2325 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(int)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2326
2327 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
2328
2329#if defined(REINTERPRET_OUTPUT_AS_3D)
2330 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2331 // in order to take into account the presence of possible cross plane paddings
2332 //
2333 // | |
2334 // | plane0 |
2335 // | |
2336 // |__________________|
2337 // |******************|
2338 // | cross_plane_pad |
2339 // |******************|
2340 // | |
2341 // | plane1 |
2342 // | |
2343 // |__________________|
2344
2345 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2346 zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2347 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
2348 zout0 *= (dst_cross_plane_pad * dst_stride_y);
2349#if M0 > 1
2350 zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2351 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
2352 zout1 *= (dst_cross_plane_pad * dst_stride_y);
2353#endif // M0 > 1
2354#if M0 > 2
2355 zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2356 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
2357 zout2 *= (dst_cross_plane_pad * dst_stride_y);
2358#endif // M0 > 2
2359#if M0 > 3
2360 zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2361 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
2362 zout3 *= (dst_cross_plane_pad * dst_stride_y);
2363#endif // M0 > 3
2364#if M0 > 4
2365 zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2366 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
2367 zout4 *= (dst_cross_plane_pad * dst_stride_y);
2368#endif // M0 > 4
2369#if M0 > 5
2370 zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2371 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
2372 zout5 *= (dst_cross_plane_pad * dst_stride_y);
2373#endif // M0 > 5
2374#if M0 > 6
2375 zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2376 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
2377 zout6 *= (dst_cross_plane_pad * dst_stride_y);
2378#endif // M0 > 6
2379#if M0 > 7
2380 zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2381 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
2382 zout7 *= (dst_cross_plane_pad * dst_stride_y);
2383#endif // M0 > 7
2384
2385 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2386 // multiply dst_stride_z by DEPTH_GEMM3D
2387 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2388
2389#else // defined(REINTERPRET_OUTPUT_AS_3D)
2390
2391 // Add offset for batched GEMM
2392 dst_addr += get_global_id(2) * dst_stride_z;
2393
2394#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2395
2396 // Store output block
2397 VSTORE(N0)
2398 (CONVERT_SAT(c0, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout0));
2399#if M0 > 1
2400 VSTORE(N0)
2401 (CONVERT_SAT(c1, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout1));
2402#endif // M0 > 1
2403#if M0 > 2
2404 VSTORE(N0)
2405 (CONVERT_SAT(c2, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout2));
2406#endif // M0 > 2
2407#if M0 > 3
2408 VSTORE(N0)
2409 (CONVERT_SAT(c3, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout3));
2410#endif // M0 > 3
2411#if M0 > 4
2412 VSTORE(N0)
2413 (CONVERT_SAT(c4, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 4 * dst_stride_y + zout4));
2414#endif // M0 > 4
2415#if M0 > 5
2416 VSTORE(N0)
2417 (CONVERT_SAT(c5, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 5 * dst_stride_y + zout5));
2418#endif // M0 > 5
2419#if M0 > 6
2420 VSTORE(N0)
2421 (CONVERT_SAT(c6, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 6 * dst_stride_y + zout6));
2422#endif // M0 > 6
2423#if M0 > 7
2424 VSTORE(N0)
2425 (CONVERT_SAT(c7, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 7 * dst_stride_y + zout7));
2426#endif // M0 > 7
2427
2428#undef LHS_BLOCK_SIZE
2429#undef LHS_OFFSET_X
2430#undef LHS_STEP_X
2431#undef RHS_BLOCK_SIZE
2432#undef RHS_OFFSET_X
2433#undef RHS_STEP_X
2434}
2435
2436#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Gian Marco Iodice62251f72019-03-11 16:07:12 +00002437/** This OpenCL kernel computes the matrix multiplication between 2 matrices with QASYMM8 data type using the dot8 instruction.
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00002438 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
2439 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
2440 *
2441 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
2442 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
2443 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2446 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2447 * - M0 = 2, 3, 4, 5, 6, 7, 8
2448 * - N0 = 2, 3, 4, 8, 16
2449 * - K0 = 2, 3, 4, 8, 16
2450 *
2451 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2452 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2453 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2454 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2455 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2456 *
2457 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: QASYMM8
2458 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2459 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2460 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2461 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2462 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2463 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2464 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2465 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2466 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2467 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2468 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2469 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2470 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2471 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2472 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2473 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2474 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2475 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
2476 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2477 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2478 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2479 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2480 */
__kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t_dot8(IMAGE_DECLARATION(lhs),
                                                     IMAGE_DECLARATION(rhs),
                                                     IMAGE_DECLARATION(dst),
                                                     uint k,
                                                     uint lhs_stride_z,
                                                     uint rhs_stride_z,
                                                     uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                     ,
                                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                    )
{
    // Thin dot8 entry point: simply forwards every argument to the generic
    // reshaped LHS(nt) x RHS(t) kernel. In this translation unit ARM_DOT_K0XN0
    // expands to the arm dot product instruction, so the forwarded call uses it.
    gemmlowp_mm_reshaped_lhs_nt_rhs_t(lhs_ptr, lhs_stride_x, lhs_step_x, lhs_stride_y, lhs_step_y, lhs_offset_first_element_in_bytes,
                                      rhs_ptr, rhs_stride_x, rhs_step_x, rhs_stride_y, rhs_step_y, rhs_offset_first_element_in_bytes,
                                      dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_offset_first_element_in_bytes,
                                      k,
                                      lhs_stride_z, rhs_stride_z, dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                      ,
                                      dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                     );
}
2523#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
2524#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K)
2525
Gian Marco Iodice62251f72019-03-11 16:07:12 +00002526#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(K)
2527
2528#define CONCAT(a, b) a##b
2529
2530#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
2531
2532#define ARM_DOT1(a, b, c) \
2533 ({ \
2534 ARM_DOT((uchar4)(a, (uchar3)0), (uchar4)(b, (uchar3)0), c); \
2535 })
2536#define ARM_DOT2(a, b, c) \
2537 ({ \
2538 ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \
2539 })
2540#define ARM_DOT3(a, b, c) \
2541 ({ \
2542 ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \
2543 })
2544#define ARM_DOT4(a, b, c) \
2545 ({ \
2546 ARM_DOT(a, b, c); \
2547 })
2548#define ARM_DOT8(a, b, c) \
2549 ({ \
2550 ARM_DOT4((a.lo), (b.lo), c); \
2551 ARM_DOT4((a.hi), (b.hi), c); \
2552 })
2553#define ARM_DOT16(a, b, c) \
2554 ({ \
2555 ARM_DOT8((a.lo), (b.lo), c); \
2556 ARM_DOT8((a.hi), (b.hi), c); \
2557 })
2558
2559#else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
2560
2561#define ARM_DOT1(a, b, c) \
2562 ({ \
2563 c += (uint)a.s0 * b.s0; \
2564 })
2565#define ARM_DOT2(a, b, c) \
2566 ({ \
2567 ARM_DOT1(a, b, c); \
2568 c += (uint)a.s1 * b.s1; \
2569 })
2570#define ARM_DOT3(a, b, c) \
2571 ({ \
2572 ARM_DOT2(a, b, c); \
2573 c += (uint)a.s2 * b.s2; \
2574 })
2575#define ARM_DOT4(a, b, c) \
2576 ({ \
2577 ARM_DOT3(a, b, c); \
2578 c += (uint)a.s3 * b.s3; \
2579 })
2580#define ARM_DOT8(a, b, c) \
2581 ({ \
2582 ARM_DOT4((a.lo), (b.lo), c); \
2583 ARM_DOT4((a.hi), (b.hi), c); \
2584 })
2585#define ARM_DOT16(a, b, c) \
2586 ({ \
2587 ARM_DOT8((a.lo), (b.lo), c); \
2588 ARM_DOT8((a.hi), (b.hi), c); \
2589 })
2590#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
2591
2592#if N0 == 2
2593#define ARM_DOT_K0XN0(k0, a, b, c) \
2594 ({ \
2595 CONCAT(ARM_DOT, k0) \
2596 ((a), (b##0), (c.s0)); \
2597 CONCAT(ARM_DOT, k0) \
2598 ((a), (b##1), (c.s1)); \
2599 })
2600#elif N0 == 3 // N0 == 3
2601#define ARM_DOT_K0XN0(k0, a, b, c) \
2602 ({ \
2603 CONCAT(ARM_DOT, k0) \
2604 ((a), (b##0), (c.s0)); \
2605 CONCAT(ARM_DOT, k0) \
2606 ((a), (b##1), (c.s1)); \
2607 CONCAT(ARM_DOT, k0) \
2608 ((a), (b##2), (c.s2)); \
2609 })
2610#elif N0 == 4 // N0 == 4
2611#define ARM_DOT_K0XN0(k0, a, b, c) \
2612 ({ \
2613 CONCAT(ARM_DOT, k0) \
2614 ((a), (b##0), (c.s0)); \
2615 CONCAT(ARM_DOT, k0) \
2616 ((a), (b##1), (c.s1)); \
2617 CONCAT(ARM_DOT, k0) \
2618 ((a), (b##2), (c.s2)); \
2619 CONCAT(ARM_DOT, k0) \
2620 ((a), (b##3), (c.s3)); \
2621 })
2622#elif N0 == 8 // N0 == 8
2623#define ARM_DOT_K0XN0(k0, a, b, c) \
2624 ({ \
2625 CONCAT(ARM_DOT, k0) \
2626 ((a), (b##0), (c.s0)); \
2627 CONCAT(ARM_DOT, k0) \
2628 ((a), (b##1), (c.s1)); \
2629 CONCAT(ARM_DOT, k0) \
2630 ((a), (b##2), (c.s2)); \
2631 CONCAT(ARM_DOT, k0) \
2632 ((a), (b##3), (c.s3)); \
2633 CONCAT(ARM_DOT, k0) \
2634 ((a), (b##4), (c.s4)); \
2635 CONCAT(ARM_DOT, k0) \
2636 ((a), (b##5), (c.s5)); \
2637 CONCAT(ARM_DOT, k0) \
2638 ((a), (b##6), (c.s6)); \
2639 CONCAT(ARM_DOT, k0) \
2640 ((a), (b##7), (c.s7)); \
2641 })
2642#elif N0 == 16 // N0 == 16
2643#define ARM_DOT_K0XN0(k0, a, b, c) \
2644 ({ \
2645 CONCAT(ARM_DOT, k0) \
2646 ((a), (b##0), (c.s0)); \
2647 CONCAT(ARM_DOT, k0) \
2648 ((a), (b##1), (c.s1)); \
2649 CONCAT(ARM_DOT, k0) \
2650 ((a), (b##2), (c.s2)); \
2651 CONCAT(ARM_DOT, k0) \
2652 ((a), (b##3), (c.s3)); \
2653 CONCAT(ARM_DOT, k0) \
2654 ((a), (b##4), (c.s4)); \
2655 CONCAT(ARM_DOT, k0) \
2656 ((a), (b##5), (c.s5)); \
2657 CONCAT(ARM_DOT, k0) \
2658 ((a), (b##6), (c.s6)); \
2659 CONCAT(ARM_DOT, k0) \
2660 ((a), (b##7), (c.s7)); \
2661 CONCAT(ARM_DOT, k0) \
2662 ((a), (b##8), (c.s8)); \
2663 CONCAT(ARM_DOT, k0) \
2664 ((a), (b##9), (c.s9)); \
2665 CONCAT(ARM_DOT, k0) \
2666 ((a), (b##A), (c.sA)); \
2667 CONCAT(ARM_DOT, k0) \
2668 ((a), (b##B), (c.sB)); \
2669 CONCAT(ARM_DOT, k0) \
2670 ((a), (b##C), (c.sC)); \
2671 CONCAT(ARM_DOT, k0) \
2672 ((a), (b##D), (c.sD)); \
2673 CONCAT(ARM_DOT, k0) \
2674 ((a), (b##E), (c.sE)); \
2675 CONCAT(ARM_DOT, k0) \
2676 ((a), (b##F), (c.sF)); \
2677 })
2678#else // N0 not supported
2679#error "N0 value not supported"
2680#endif // N0 conditions
2681
2682/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2683 * The LHS matrix is NOT reshaped
2684 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
2685 *
2686 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
2687 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
2688 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
2689 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
2690 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
2691 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2692 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
2693 * - N0 = 2, 3, 4, 8, 16
2694 * - K0 = 2, 3, 4, 8, 16
2695 * - H0 >= 1
2696 *
2697 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2698 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2699 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2700 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2701 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2702 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
2703 *
 * @param[in]  lhs_ptr                           Pointer to the LHS matrix. Supported data type: QASYMM8
2705 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2706 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2707 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2708 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2709 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2710 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2711 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2712 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2713 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2714 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2715 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2716 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2717 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2718 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2719 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2720 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2721 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2722 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2723 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2724 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2725 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
2726 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2727 */
__kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
                                              IMAGE_DECLARATION(rhs),
                                              IMAGE_DECLARATION(dst),
                                              uint lhs_stride_z,
                                              uint rhs_stride_z,
                                              uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                              ,
                                              uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                              ,
                                              uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                             )
{
    // Block size
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (K0)
#define RHS_STEP_X ((K0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (K0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS matrix address
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (lhs_cross_plane_pad * lhs_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 6
#if M0 > 7
    zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7); // Fix: clamp zin7 (was mistakenly clamped against zout7, which is not declared at this point)
    zin7 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 7

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;

    for(int i = 0; i < K; i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        VEC_DATA_TYPE(uchar, K0)
        a0 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0);
#if M0 > 1
        VEC_DATA_TYPE(uchar, K0)
        a1 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1);
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(uchar, K0)
        a2 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2);
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(uchar, K0)
        a3 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3);
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(uchar, K0)
        a4 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4);
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(uchar, K0)
        a5 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5);
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(uchar, K0)
        a6 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6);
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(uchar, K0)
        a7 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7);
#endif // M0 > 7

        // Load values from RHS matrix
        VEC_DATA_TYPE(uchar, K0)
        b0 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 0 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b1 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 1 * RHS_STEP_X);
#if N0 > 2
        VEC_DATA_TYPE(uchar, K0)
        b2 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 2 * RHS_STEP_X);
#endif // N0 > 2
#if N0 > 3
        VEC_DATA_TYPE(uchar, K0)
        b3 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 3 * RHS_STEP_X);
#endif // N0 > 3
#if N0 > 4
        VEC_DATA_TYPE(uchar, K0)
        b4 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 4 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b5 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 5 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b6 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 6 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b7 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 7 * RHS_STEP_X);
#endif // N0 > 4
#if N0 > 8
        VEC_DATA_TYPE(uchar, K0)
        b8 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 8 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b9 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 9 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bA = VLOAD(K0)(0, rhs_ptr + rhs_offset + 10 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bB = VLOAD(K0)(0, rhs_ptr + rhs_offset + 11 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bC = VLOAD(K0)(0, rhs_ptr + rhs_offset + 12 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bD = VLOAD(K0)(0, rhs_ptr + rhs_offset + 13 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bE = VLOAD(K0)(0, rhs_ptr + rhs_offset + 14 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bF = VLOAD(K0)(0, rhs_ptr + rhs_offset + 15 * RHS_STEP_X);
#endif // N0 > 8

        // Accumulate
        ARM_DOT_K0XN0(K0, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(K0, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(K0, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(K0, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(K0, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(K0, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(K0, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(K0, a7, b, c7);
#endif // M0 > 7

        lhs_offset += K0;
        rhs_offset += N0 * RHS_STEP_X * RHS_STEP_LOOP;
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0) * sizeof(int) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
    zout0 *= (dst_cross_plane_pad * dst_stride_y);
#if M0 > 1
    zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
    zout1 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 1
#if M0 > 2
    zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
    zout2 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 2
#if M0 > 3
    zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
    zout3 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 3
#if M0 > 4
    zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
    zout4 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 4
#if M0 > 5
    zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
    zout5 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 5
#if M0 > 6
    zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
    zout6 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 6
#if M0 > 7
    zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
    zout7 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 7

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Store output block
    VSTORE(N0)
    (CONVERT_SAT(c0, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout0));
#if M0 > 1
    VSTORE(N0)
    (CONVERT_SAT(c1, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout1));
#endif // M0 > 1
#if M0 > 2
    VSTORE(N0)
    (CONVERT_SAT(c2, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout2));
#endif // M0 > 2
#if M0 > 3
    VSTORE(N0)
    (CONVERT_SAT(c3, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout3));
#endif // M0 > 3
#if M0 > 4
    VSTORE(N0)
    (CONVERT_SAT(c4, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 4 * dst_stride_y + zout4));
#endif // M0 > 4
#if M0 > 5
    VSTORE(N0)
    (CONVERT_SAT(c5, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 5 * dst_stride_y + zout5));
#endif // M0 > 5
#if M0 > 6
    VSTORE(N0)
    (CONVERT_SAT(c6, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 6 * dst_stride_y + zout6));
#endif // M0 > 6
#if M0 > 7
    VSTORE(N0)
    (CONVERT_SAT(c7, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 7 * dst_stride_y + zout7));
#endif // M0 > 7

#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(K)
3075
Gian Marco05288a22017-11-21 10:57:50 +00003076#if defined(COLS_A)
3077/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A.
3078 *
3079 * @note This stage is needed to handle the offset of matrix product
3080 * https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
3081 *
3082 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
3083 *
3084 * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8
3085 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
3086 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3087 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
3088 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3089 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
3090 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3091 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
3092 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: S32
3093 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3094 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3095 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3096 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3097 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
3098 */
__kernel void gemmlowp_matrix_a_reduction(TENSOR3D_DECLARATION(src),
                                          IMAGE_DECLARATION(dst))
{
    // Map the raw kernel arguments onto tensor/image accessors
    Tensor3D tens_src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Image    img_dst  = CONVERT_TO_IMAGE_STRUCT(dst);

    // Each work-item reduces one row of matrix A
    __global const uchar *row_ptr = (__global const uchar *)(tens_src.ptr + get_global_id(0) * src_stride_y + get_global_id(1) * src_stride_z);

    uint4 acc_v = (uint4)0; // vectorized partial sums (4 lanes)
    uint  acc_s = 0;        // scalar accumulator for the leftover tail

    int col = 0;

    // Main loop: consume 16 elements per iteration, widening to uint before adding
    for(; col <= ((int)COLS_A - 16); col += 16)
    {
        const uchar16 vals = vload16(0, row_ptr + col);

        acc_v += convert_uint4(vals.s0123) + convert_uint4(vals.s4567) + convert_uint4(vals.s89AB) + convert_uint4(vals.sCDEF);
    }

    // Tail loop: remaining elements one at a time
    for(; col < COLS_A; ++col)
    {
        acc_s += row_ptr[col];
    }

    // Horizontal reduction of the vector lanes into the scalar total
    acc_s += acc_v.s0 + acc_v.s1 + acc_v.s2 + acc_v.s3;

    *((__global int *)img_dst.ptr) = (int)acc_s;
}
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003131
3132#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
3133/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A using the arm dot product instruction
3134 *
3135 * @note This stage is needed to handle the offset of matrix product
3136 * https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
3137 *
3138 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
3139 *
3140 * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8
3141 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
3142 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3143 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
3144 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3145 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
3146 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3147 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
3148 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: S32
3149 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3150 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3151 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3152 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3153 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
3154 */
__kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
                                               IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Image    dst = CONVERT_TO_IMAGE_STRUCT(dst);

    uint sum_row = 0;

    // Each work-item reduces one row of matrix A
    __global const uchar *matrix_a = (__global const uchar *)(src.ptr + get_global_id(0) * src_stride_y + get_global_id(1) * src_stride_z);

    int i = 0;

    // This for loop performs 32 accumulations per iteration:
    // the row sum is obtained as dot products against an all-ones vector, 4 elements at a time
    for(; i <= ((int)COLS_A - 32); i += 32)
    {
        uchar16 a0_u8 = vload16(0, matrix_a + i);

        sum_row += arm_dot(a0_u8.s0123, (uchar4)(1));
        sum_row += arm_dot(a0_u8.s4567, (uchar4)(1));
        sum_row += arm_dot(a0_u8.s89AB, (uchar4)(1));
        sum_row += arm_dot(a0_u8.sCDEF, (uchar4)(1));

        // Second 16-byte chunk of the same 32-element stride
        a0_u8 = vload16(1, matrix_a + i);

        sum_row += arm_dot(a0_u8.s0123, (uchar4)(1));
        sum_row += arm_dot(a0_u8.s4567, (uchar4)(1));
        sum_row += arm_dot(a0_u8.s89AB, (uchar4)(1));
        sum_row += arm_dot(a0_u8.sCDEF, (uchar4)(1));
    }

    // This for loop performs the leftover accumulations
    for(; i < COLS_A; ++i)
    {
        sum_row += matrix_a[i];
    }

    *((__global int *)dst.ptr) = (int)sum_row;
}
3194#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Gian Marco05288a22017-11-21 10:57:50 +00003195#endif // defined(COLS_A)
3196
3197#if defined(COLS_B) && defined(ROWS_B)
3198/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each column of Matrix B.
3199 *
3200 * @note This stage is needed to handle the offset of matrix product
3201 * https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
3202 *
3203 * @attention The number of matrix B columns and rows needs to be passed at compile time using -DCOLS_B and -DROWS_B
3204 *
3205 * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8
3206 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
3207 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3208 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
3209 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3210 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
3211 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3212 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
3213 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: S32
3214 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3215 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3216 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3217 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3218 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
3219 */
__kernel void gemmlowp_matrix_b_reduction(TENSOR3D_DECLARATION(src),
                                          IMAGE_DECLARATION(dst))
{
    // Map the raw kernel arguments onto tensor/image accessors
    Tensor3D tens_src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Image    img_dst  = CONVERT_TO_IMAGE_STRUCT(dst);

    // 16 per-column accumulators, one per uchar lane handled by this work-item
    uint16 col_sums = (uint16)0;

    __global const uchar *row_ptr = (__global const uchar *)(tens_src.ptr + get_global_id(1) * src_stride_z);

    int row = 0;

    // Main loop: accumulate 4 rows per iteration
    for(; row <= ((int)ROWS_B - 4); row += 4)
    {
        const uchar16 r0 = vload16(0, row_ptr + 0 * src_stride_y);
        const uchar16 r1 = vload16(0, row_ptr + 1 * src_stride_y);
        const uchar16 r2 = vload16(0, row_ptr + 2 * src_stride_y);
        const uchar16 r3 = vload16(0, row_ptr + 3 * src_stride_y);

        col_sums += convert_uint16(r0) + convert_uint16(r1) + convert_uint16(r2) + convert_uint16(r3);

        row_ptr += 4 * src_stride_y;
    }

    // Tail loop: remaining rows one at a time
    for(; row < (int)ROWS_B; ++row)
    {
        const uchar16 r0 = vload16(0, row_ptr);

        col_sums += convert_uint16(r0);

        row_ptr += src_stride_y;
    }

    vstore16(convert_int16(col_sums), 0, (__global int *)img_dst.ptr);
}
3257#endif // defined(COLS_B) && defined(ROWS_B)
3258
3259#if defined(K_OFFSET)
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003260
3261/* Helper function used to calculate the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel.
3262 *
3263 * This kernel takes a final int32 accumulator value (the output of @CLGEMMLowpMatrixMultiplyKernel),
3264 * and calculates the offset contribution of matrix A and matrix B.
3265 *
3266 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200)
3267 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1)
3268 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6)
3269 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches
3270 *
3271 * @param[in] x get_global_id(0) * 4
3272 * @param[in] y get_global_id(1)
3273 * @param[in] z get_global_id(2)
3274 * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
3275 * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
3276 * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
3277 * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
3278 * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
3279 * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
3280 * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
3281 * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
3282 * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
3283 * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
3284 * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
3285 * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
3286 * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
3287 * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes)
3288 * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
3289 * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
3290 */
inline int4 offset_contribution(
    int x,
    int y,
    int z
#if defined(A_OFFSET)
    ,
    IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
    ,
    IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
    ,
    VECTOR_DECLARATION(biases)
#endif // defined(ADD_BIAS)
)
{
    // Per-lane contributions coming from the reductions of matrix B (a_offset term)
    // and matrix A (b_offset term); both stay zero when the corresponding macro is undefined
    int4 a_offset_s32 = (int4)0;
    int4 b_offset_s32 = (int4)0;

    // When the input is 3D, z spans depth * batches; divide by the depth to recover the batch index
    int batch_id = z;
#if defined(DEPTH_INPUT3D)
    batch_id /= (int)DEPTH_INPUT3D;
#endif // defined(DEPTH_INPUT3D)

#if defined(A_OFFSET)
    // Address of the column sums (reduction of matrix B) for the 4 columns processed by this work-item
    __global uchar *sum_col_addr = sum_col_ptr + sum_col_offset_first_element_in_bytes + x * sizeof(int);

    // Compute the offset contribution due to A_OFFSET
#if defined(SUM_COL_HAS_BATCHES)
    a_offset_s32 = vload4(0, (__global int *)(sum_col_addr + batch_id * sum_col_stride_y));
#else // defined(SUM_COL_HAS_BATCHES)
    a_offset_s32 = vload4(0, (__global int *)sum_col_addr);
#endif // defined(SUM_COL_HAS_BATCHES)

    a_offset_s32 *= (int4)A_OFFSET;
#endif // defined(A_OFFSET)

#if defined(B_OFFSET)
    // Address of the row sums (reduction of matrix A) for the row processed by this work-item
    __global uchar *sum_row_addr = sum_row_ptr + sum_row_offset_first_element_in_bytes + y * sizeof(int);

    // Compute the offset contribution due to B_OFFSET: the scalar row sum is broadcast to all 4 lanes
#if defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 = (int4) * (((__global int *)(sum_row_addr + batch_id * sum_row_stride_y)) + (z % (int)DEPTH_INPUT3D) * (int)HEIGHT_INPUT3D);
#else // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 = (int4) * (((__global int *)(sum_row_addr + batch_id * sum_row_stride_y)));
#endif // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 *= (int4)B_OFFSET;
#endif // defined(B_OFFSET)

#if defined(ADD_BIAS)
    // Add bias: load 4 int32 bias values and fold them into the contribution
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);

    int4 biases_values = vload4(0, (__global int *)bias_addr);
    b_offset_s32 += (int4)biases_values;
#endif // defined(ADD_BIAS)

    // Total contribution: K_OFFSET (= a_offset * b_offset * k) plus the per-matrix terms
    return (int4)K_OFFSET + a_offset_s32 + b_offset_s32;
}
3354
Gian Marco05288a22017-11-21 10:57:50 +00003355/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel. The computation is performed in-place
3356 *
3357 * This kernel takes a final int32 accumulator value (the output of @CLGEMMLowpMatrixMultiplyKernel),
3358 * and adds to it the offset contribution of matrix A and matrix B in-place.
3359 *
3360 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200)
3361 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1)
3362 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6)
Chunosov5124be52017-11-22 20:42:13 +07003363 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches
Gian Marco05288a22017-11-21 10:57:50 +00003364 *
3365 * The final result is:
3366 *
3367 * mm_result[i][k] = mm_result[i][k] +
3368 * (sum_col[k] * A_OFFSET) +
3369 * (sum_row[i] * B_OFFSET) +
3370 * (K_OFFSET)
3371 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01003372 * @param[in] mm_result_ptr Pointer to the source tensor. Supported data type: S32
3373 * @param[in] mm_result_stride_x Stride of the source tensor in X dimension (in bytes)
3374 * @param[in] mm_result_step_x mm_result_stride_x * number of elements along X processed per workitem(in bytes)
3375 * @param[in] mm_result_stride_y Stride of the source tensor in Y dimension (in bytes)
3376 * @param[in] mm_result_step_y mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
3377 * @param[in] mm_result_stride_z Stride of the source tensor in Z dimension (in bytes)
3378 * @param[in] mm_result_step_z mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
3379 * @param[in] mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003380 * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
3381 * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
3382 * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
3383 * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
3384 * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
3385 * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
3386 * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
3387 * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
3388 * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
3389 * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
3390 * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
3391 * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
3392 * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
3393 * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes)
3394 * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
3395 * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
Gian Marco05288a22017-11-21 10:57:50 +00003396 */
__kernel void gemmlowp_offset_contribution(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                           ,
                                           IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                           ,
                                           IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                                           ,
                                           VECTOR_DECLARATION(biases)
#endif // defined(ADD_BIAS)
                                           )
{
    // Each work-item updates 4 consecutive int32 accumulators along x, in-place
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // Address of the accumulators this work-item owns
    __global uchar *result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    // Offset contribution of matrix A, matrix B, the constant K_OFFSET and (optionally) the bias
    int4 contribution_s32 = offset_contribution(
                                x, y, z
#if defined(A_OFFSET)
                                ,
                                sum_col_ptr,
                                sum_col_stride_x,
                                sum_col_step_x,
                                sum_col_stride_y,
                                sum_col_step_y,
                                sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                ,
                                sum_row_ptr,
                                sum_row_stride_x,
                                sum_row_step_x,
                                sum_row_stride_y,
                                sum_row_step_y,
                                sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                                ,
                                biases_ptr,
                                biases_stride_x,
                                biases_step_x,
                                biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                            );

    // Load the accumulators, fold in the contribution, and write them back in-place
    int4 acc_s32 = vload4(0, (__global int *)result_addr);
    acc_s32 += contribution_s32;
    vstore4(acc_s32, 0, (__global int *)result_addr);
}
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003456
3457#if defined(RESULT_OFFSET) && defined(RESULT_MULTIPLIER) && defined(RESULT_SHIFT)
3458/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel and it quantizes down to uint8.
3459 *
3460 * This kernel takes a final int32 accumulator value (the output of @CLGEMMLowpMatrixMultiplyKernel), adds to it the offset contribution of matrix A and matrix B and quantizes to uint8 through the output stage.
3461 *
3462 *
3463 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200)
3464 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1)
3465 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6)
3466 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches
3467 *
3468 * The result before the output stage is:
3469 *
3470 * mm_result[i][k] = mm_result[i][k] +
3471 * (sum_col[k] * A_OFFSET) +
3472 * (sum_row[i] * B_OFFSET) +
3473 * (K_OFFSET)
3474 *
3475 * This result is quantized down to uint8 using the output stage. The output stage computes the following operations:
3476 *
3477 * -# Add offset terms to final result
3478 * -# Multiply each entry of result by result_mult_int
3479 * -# Add bias to final result (if -DADD_BIAS is passed at compile time)
3480 * -# Shift the int32 accumulator by result_shift
3481 * -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
3482 * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
3483 *
3484 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -RESULT_MULT_INT and -DRESULT_SHIFT
3485 *
3486 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
3487 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
3488 * These values can be used to implement "rectified linear unit" activation functions
3489 *
3490 * @param[in] mm_result_ptr Pointer to the source tensor. Supported data type: S32
3491 * @param[in] mm_result_stride_x Stride of the source tensor in X dimension (in bytes)
3492 * @param[in] mm_result_step_x mm_result_stride_x * number of elements along X processed per workitem(in bytes)
3493 * @param[in] mm_result_stride_y Stride of the source tensor in Y dimension (in bytes)
3494 * @param[in] mm_result_step_y mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
3495 * @param[in] mm_result_stride_z Stride of the source tensor in Z dimension (in bytes)
3496 * @param[in] mm_result_step_z mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
3497 * @param[in] mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor
3498 * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
3499 * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
3500 * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
3501 * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
3502 * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
3503 * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
3504 * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
3505 * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
3506 * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
3507 * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
3508 * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
3509 * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
3510 * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
3511 * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes)
3512 * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
3513 * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
3514 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8
3515 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3516 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3517 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3518 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3519 * @param[in] dst_stride_z Stride of the source tensor in Z dimension (in bytes)
3520 * @param[in] dst_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3521 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
3522 */
__kernel void gemmlowp_offset_contribution_quantize_down(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                                         ,
                                                         IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                                         ,
                                                         IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
                                                         ,
#if defined(ADD_BIAS)
                                                         VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                         TENSOR3D_DECLARATION(dst))
{
    // Each work-item processes 4 consecutive elements along x
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // Destination address (QASYMM8 output: 1 byte per element, hence no sizeof scaling on x)
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    // Compute offset contribution (A_OFFSET, B_OFFSET, K_OFFSET and optional bias terms)
    int4 offset_term_s32 = offset_contribution(
                               x, y, z
#if defined(A_OFFSET)
                               ,
                               sum_col_ptr,
                               sum_col_stride_x,
                               sum_col_step_x,
                               sum_col_stride_y,
                               sum_col_step_y,
                               sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                               ,
                               sum_row_ptr,
                               sum_row_stride_x,
                               sum_row_step_x,
                               sum_row_stride_y,
                               sum_row_step_y,
                               sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                               ,
                               biases_ptr,
                               biases_stride_x,
                               biases_step_x,
                               biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                           );

    // Address of the 4 int32 accumulators produced by the matrix multiply kernel
    __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    int4 in_s32 = vload4(0, (__global int *)mm_result_addr);

    // Add the offset terms to GEMM's result
    in_s32 += offset_term_s32;

    // -------------- OUTPUT STAGE

    // Add the output offset BEFORE the integer rescale (contrast with the fixedpoint variant)
    in_s32 += (int4)RESULT_OFFSET;

    // Multiply by result_mult_int and shift
    in_s32 *= RESULT_MULTIPLIER;

    in_s32 >>= RESULT_SHIFT;

    // Saturating cast clamps to the [0..255] QASYMM8 range
    uchar4 res = convert_uchar4_sat(in_s32);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
3603
3604/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel and it quantizes down to uint8.
3605 *
3606 * This kernel takes a final int32 accumulator value (the output of @CLGEMMLowpMatrixMultiplyKernel), adds to it the offset contribution of matrix A and matrix B and quantizes to uint8 through the output stage.
3607 *
3608 *
3609 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200)
3610 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1)
3611 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6)
3612 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches
3613 *
3614 * The result before the output stage is:
3615 *
3616 * mm_result[i][k] = mm_result[i][k] +
3617 * (sum_col[k] * A_OFFSET) +
3618 * (sum_row[i] * B_OFFSET) +
3619 * (K_OFFSET)
3620 *
3621 * This result is quantized down to uint8 using the output stage. The output stage computes the following operations:
3622 *
3623 * -# Compute fixed point multiplication between each entry of input by result_fixedpoint_multiplier
3624 * -# Add bias to final result if bias tensor is not a nullptr
3625 * -# Round to nearest division by a power-of-two using result_shift
3626 * -# Add offset to each result
3627 * -# Clamp the value between the specified min and max bounds
3628 * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
3629 *
3630 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -RESULT_MULT_INT and -DRESULT_SHIFT
3631 *
3632 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
3633 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
3634 * These values can be used to implement "rectified linear unit" activation functions
3635 *
3636 * @param[in] mm_result_ptr Pointer to the source tensor. Supported data type: S32
3637 * @param[in] mm_result_stride_x Stride of the source tensor in X dimension (in bytes)
3638 * @param[in] mm_result_step_x mm_result_stride_x * number of elements along X processed per workitem(in bytes)
3639 * @param[in] mm_result_stride_y Stride of the source tensor in Y dimension (in bytes)
3640 * @param[in] mm_result_step_y mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
3641 * @param[in] mm_result_stride_z Stride of the source tensor in Z dimension (in bytes)
3642 * @param[in] mm_result_step_z mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
3643 * @param[in] mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor
3644 * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
3645 * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
3646 * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
3647 * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
3648 * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
3649 * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
3650 * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr
3651 * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
3652 * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
3653 * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
3654 * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
3655 * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
3656 * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
3657 * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes)
3658 * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
3659 * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
3660 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8
3661 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3662 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3663 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3664 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3665 * @param[in] dst_stride_z Stride of the source tensor in Z dimension (in bytes)
3666 * @param[in] dst_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3667 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
3668 */
__kernel void gemmlowp_offset_contribution_quantize_down_fixedpoint(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                                                    ,
                                                                    IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                                                    ,
                                                                    IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
                                                                    ,
#if defined(ADD_BIAS)
                                                                    VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                                    TENSOR3D_DECLARATION(dst))
{
    // Each work-item processes 4 consecutive elements along x
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // Compute offset contribution (A_OFFSET, B_OFFSET, K_OFFSET and optional bias terms)
    int4 offset_term_s32 = offset_contribution(
                               x, y, z
#if defined(A_OFFSET)
                               ,
                               sum_col_ptr,
                               sum_col_stride_x,
                               sum_col_step_x,
                               sum_col_stride_y,
                               sum_col_step_y,
                               sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                               ,
                               sum_row_ptr,
                               sum_row_stride_x,
                               sum_row_step_x,
                               sum_row_stride_y,
                               sum_row_step_y,
                               sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                               ,
                               biases_ptr,
                               biases_stride_x,
                               biases_step_x,
                               biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                           );

    // Address of the 4 int32 accumulators produced by the matrix multiply kernel
    __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    // Destination address (QASYMM8 output: 1 byte per element, hence no sizeof scaling on x)
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    int4 in_s32 = vload4(0, (__global int *)mm_result_addr);

    // Add the offset terms to GEMM's result
    in_s32 += offset_term_s32;

    // -------------- OUTPUT STAGE

    // Fixed-point multiply by RESULT_MULTIPLIER followed by a rounding shift right by RESULT_SHIFT (4 lanes)
    in_s32 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(in_s32, RESULT_MULTIPLIER, RESULT_SHIFT, 4);

    // Add the output offset AFTER the rescale (unlike gemmlowp_offset_contribution_quantize_down)
    in_s32 += (int4)RESULT_OFFSET;

    // Saturating cast clamps to the [0..255] QASYMM8 range
    uchar4 res = convert_uchar4_sat(in_s32);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
#endif // defined(RESULT_OFFSET) && defined(RESULT_MULTIPLIER) && defined(RESULT_SHIFT)
Gian Marco05288a22017-11-21 10:57:50 +00003748#endif // defined(K_OFFSET)
3749
3750#if defined(RESULT_OFFSET) && defined(RESULT_MULT_INT) && defined(RESULT_SHIFT)
3751/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
3752 *
3753 * This kernel takes a final int32 accumulator value and processes it to obtain the final QASYMM8 value.
3754 * The following computations will be performed by the kernel:
3755 *
3756 * -# Add offset terms to final result
3757 * -# Multiply each entry of result by result_mult_int
3758 * -# Add bias to final result (if -DADD_BIAS is passed at compile time)
3759 * -# Shift the int32 accumulator by result_shift
3760 * -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
3761 * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
3762 *
3763 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -RESULT_MULT_INT and -DRESULT_SHIFT
3764 *
3765 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
3766 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
3767 * These values can be used to implement "rectified linear unit" activation functions
3768 *
3769 * @param[in] src_ptr Pointer to the source tensor. Supported data type: S32
3770 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
3771 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3772 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
3773 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3774 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
3775 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3776 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003777 * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
3778 * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes)
3779 * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
3780 * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
Gian Marco05288a22017-11-21 10:57:50 +00003781 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8
3782 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3783 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3784 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3785 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3786 * @param[in] dst_stride_z Stride of the source tensor in Z dimension (in bytes)
3787 * @param[in] dst_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3788 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
3789 */
__kernel void gemmlowp_output_stage_quantize_down(TENSOR3D_DECLARATION(src),
#if defined(ADD_BIAS)
                                                  VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                  TENSOR3D_DECLARATION(dst))
{
    // Each work-item quantizes 4 consecutive int32 accumulators along x
    const int xpos = get_global_id(0) * 4;
    const int ypos = get_global_id(1);
    const int zpos = get_global_id(2);

    // Source holds int32 values (4 bytes each); destination holds QASYMM8 (1 byte each)
    __global uchar *in_addr  = src_ptr + src_offset_first_element_in_bytes + xpos * sizeof(int) + ypos * src_stride_y + zpos * src_stride_z;
    __global uchar *out_addr = dst_ptr + dst_offset_first_element_in_bytes + xpos + ypos * dst_stride_y + zpos * dst_stride_z;

    int4 acc = vload4(0, (__global int *)in_addr);

#if defined(ADD_BIAS)
    // Fold the 4 per-channel int32 bias values into the accumulators
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + xpos * sizeof(int);

    acc += vload4(0, (__global int *)bias_addr);
#endif // defined(ADD_BIAS)

    // Output stage: add offset, multiply by the integer scale, arithmetic shift right
    acc += (int4)RESULT_OFFSET;
    acc *= RESULT_MULT_INT;
    acc >>= RESULT_SHIFT;

    // Saturate to the [0..255] QASYMM8 range, then clamp to the optional activation bounds
    uchar4 quantized = convert_uchar4_sat(acc);

#if defined(MIN_BOUND)
    quantized = max(quantized, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    quantized = min(quantized, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(quantized, 0, out_addr);
}
Gian Marco58c57942017-11-28 09:10:03 +00003835#endif // defined(RESULT_OFFSET) && defined(RESULT_MULT_INT) && defined(RESULT_SHIFT)
3836
3837#if defined(RESULT_OFFSET_AFTER_SHIFT) && defined(RESULT_FIXEDPOINT_MULTIPLIER) && defined(RESULT_SHIFT)
/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), and processes it to obtain the final QASYMM8 value.
 * The following computations will be performed by the kernel (in this order):
 *
 * -# Add bias to final result if bias tensor is not a nullptr
 * -# Compute fixed point multiplication between each entry of input by result_fixedpoint_multiplier
 * -# Round to nearest division by a power-of-two using result_shift
 * -# Add offset to each result
 * -# Clamp the value between the specified min and max bounds
 * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
 *
 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET_AFTER_SHIFT, -DRESULT_FIXEDPOINT_MULTIPLIER and -DRESULT_SHIFT
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  src_ptr                               Pointer to the source tensor. Supported data type: S32
 * @param[in]  src_stride_x                          Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                          Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                            src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                            src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes     The offset of the first element in the source tensor
 * @param[in]  biases_ptr                            (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
 * @param[in]  biases_stride_x                       (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                         (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes  (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                               Pointer to the destination tensor. Supported data type: QASYMM8
 * @param[in]  dst_stride_x                          Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                            dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                            dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                            dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_output_stage_quantize_down_fixedpoint(TENSOR3D_DECLARATION(src),
#if defined(ADD_BIAS)
                                                             VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                             TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses; each work-item handles 4 values along X
    int x = get_global_id(0) * 4;
    int y = get_global_id(1);
    int z = get_global_id(2);

    // Source holds int32 accumulators (4 bytes/element)
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z;

    // Destination holds QASYMM8 values (1 byte/element), hence no sizeof scaling on x
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    int4 input_values = vload4(0, (__global int *)src_addr);

#if defined(ADD_BIAS)
    // Add the int32 bias before the fixed-point rescale
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);

    int4 biases_values = vload4(0, (__global int *)bias_addr);
    input_values += (int4)biases_values;
#endif // defined(ADD_BIAS)

    // Fixed-point multiply by RESULT_FIXEDPOINT_MULTIPLIER, then round-to-nearest
    // division by 2^RESULT_SHIFT (macro from helpers_asymm.h, vector width 4).
    // NOTE(review): the LESS_THAN_ONE path presumably requires RESULT_SHIFT >= 0 — confirm against the macro definition.
    input_values = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(input_values, RESULT_FIXEDPOINT_MULTIPLIER, RESULT_SHIFT, 4);

    // Add the output offset after the shift (quantized zero-point)
    input_values += (int4)RESULT_OFFSET_AFTER_SHIFT;

    // Saturate into the QASYMM8 [0..255] range
    uchar4 res = convert_uchar4_sat(input_values);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
Chunosov5124be52017-11-22 20:42:13 +07003920#endif // defined(RESULT_OFFSET_AFTER_SHIFT) && defined(RESULT_FIXEDPOINT_MULTIPLIER) && defined(RESULT_SHIFT)
Georgios Pinitas51e53a32018-10-22 13:49:08 +01003921
3922#if defined(REAL_MULTIPLIER) && defined(OUTPUT_OFFSET)
3923/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
3924 *
3925 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), and processes it to obtain the final QASYMM8 value.
3926 * The following computations will be performed by the kernel:
3927 *
3928 * -# Compute fixed point multiplication between each entry of input by result_fixedpoint_multiplier
3929 * -# Add bias to final result if bias tensor is not a nullptr
3930 * -# Requantize
3931 * -# Add offset to each result
3932 * -# Clamp the value between the specified min and max bounds
3933 * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
3934 *
3935 * @attention The offset and scalar scale factor must be passed at compile time using -DRESULT_OFFSET, -DREAL_MULTIPLIER
3936 *
3937 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
3938 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
3939 * These values can be used to implement "rectified linear unit" activation functions
3940 *
3941 * @param[in] src_ptr Pointer to the source tensor. Supported data type: S32
3942 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
3943 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3944 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
3945 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3946 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
3947 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3948 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
3949 * @param[in] biases_ptr Pointer to the biases tensor. Supported data type: same as @p src_ptr
3950 * @param[in] biases_stride_x Stride of the biases tensor in X dimension (in bytes)
3951 * @param[in] biases_step_x biases_stride_x * number of elements along X processed per workitem(in bytes)
3952 * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases tensor
3953 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8
3954 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3955 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3956 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3957 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3958 * @param[in] dst_stride_z Stride of the source tensor in Z dimension (in bytes)
3959 * @param[in] dst_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3960 * @param[in] dst_stride_w Stride of the source tensor in W dimension (in bytes)
3961 * @param[in] dst_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
3962 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
3963 */
3964__kernel void gemmlowp_output_stage_quantize_down_float(TENSOR3D_DECLARATION(src),
3965#if defined(ADD_BIAS)
3966 VECTOR_DECLARATION(biases),
3967#endif // defined(ADD_BIAS)
3968#if defined(DST_HEIGHT)
3969 TENSOR4D_DECLARATION(dst))
3970#else // defined(DST_HEIGHT)
3971 TENSOR3D_DECLARATION(dst))
3972#endif // defined(DST_HEIGHT)
3973{
3974 // Compute source and destination addresses
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00003975 int x = get_global_id(0) * 4;
3976 int y = get_global_id(1);
3977 int z = get_global_id(2);
Georgios Pinitas51e53a32018-10-22 13:49:08 +01003978
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00003979 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z;
Georgios Pinitas51e53a32018-10-22 13:49:08 +01003980
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00003981 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;
3982
3983 int4 input_values = vload4(0, (__global int *)src_addr);
Georgios Pinitas51e53a32018-10-22 13:49:08 +01003984
3985#if defined(ADD_BIAS)
3986 // Add bias
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00003987 __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);
3988
3989 int4 biases_values = vload4(0, (__global int *)bias_addr);
3990 input_values += (int4)biases_values;
Georgios Pinitas51e53a32018-10-22 13:49:08 +01003991#endif // defined(ADD_BIAS)
3992
3993 // Convert to float
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00003994 float16 input_values_f = convert_float4(input_values);
Georgios Pinitas51e53a32018-10-22 13:49:08 +01003995 input_values_f = round(input_values_f * (float)REAL_MULTIPLIER + (float)OUTPUT_OFFSET);
3996
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00003997 uchar4 res = convert_uchar4_sat(input_values_f);
Georgios Pinitas51e53a32018-10-22 13:49:08 +01003998
3999#if defined(MIN_BOUND)
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00004000 res = max(res, (uchar4)MIN_BOUND);
Georgios Pinitas51e53a32018-10-22 13:49:08 +01004001#endif // defined(MIN_BOUND)
4002#if defined(MAX_BOUND)
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00004003 res = min(res, (uchar4)MAX_BOUND);
Georgios Pinitas51e53a32018-10-22 13:49:08 +01004004#endif // defined(MAX_BOUND)
4005
4006 // Store the result
Gian Marco Iodice0c54a622018-10-30 12:20:03 +00004007 vstore4(res, 0, dst_addr);
Georgios Pinitas51e53a32018-10-22 13:49:08 +01004008}
Gian Marco Iodice2ec6c1e2019-04-09 12:03:05 +01004009#endif // defined(REAL_MULTIPLIER) && defined(OUTPUT_OFFSET)