blob: b1ba8e03777e465d89a27ce5abfe441b99c94f0c [file] [log] [blame]
Gian Marco05288a22017-11-21 10:57:50 +00001/*
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00002 * Copyright (c) 2017-2019 ARM Limited.
Gian Marco05288a22017-11-21 10:57:50 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
Georgios Pinitas45bcc3a2017-11-29 11:06:49 +000025#include "helpers_asymm.h"
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +000026#include "repeat.h"
Gian Marco05288a22017-11-21 10:57:50 +000027
#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
// ARM_DOT(x, y, val): add the 8-bit dot product of x and y into the accumulator val.
// Only defined when the cl_arm_integer_dot_product_int8 extension is enabled and available.
#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
// Accumulating variant: arm_dot_acc receives the running value as its third argument
#define ARM_DOT(x, y, val) (val) = arm_dot_acc((x), (y), (val));
#else // !(defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8))
// Non-accumulating variant: compute the dot product, then add it separately
#define ARM_DOT(x, y, val) (val) += arm_dot((x), (y));
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Giorgio Arenac50da382018-07-26 15:50:09 +010035
#if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel before running the matrix multiplication
 *
 * Each work-item computes a 4x4 block of the S32 destination matrix.
 *
 * @note The number of columns of the reshaped matrix B (i.e. the length, in elements, of one row of src1) must be passed at compile time using -DCOLS_B: e.g. -DCOLS_B=1024
 * @note The multiplication factor for the width of the 1xW transposed block (mult_transpose1xW_width) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (e.g. -DTRANSPOSE1XW_WIDTH_STEP=2).
 *       The transposed block is therefore 4 * TRANSPOSE1XW_WIDTH_STEP elements wide.
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note If matrix B has fewer dimensions than matrix A, -DMATRIX_B_DEPTH can be passed so that matrix B is not slid past its own depth along Z.
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: S32
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemmlowp_mm_interleaved_transposed_midgard(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // TRANSPOSE1XW_WIDTH_STEP consecutive work-items along X share one row of reshaped B,
    // and MULT_INTERLEAVE4X4_HEIGHT consecutive work-items along Y share one row of reshaped A
    const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
    const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-element slice inside the shared row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
    __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr_b += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Compute end row address for matrix B (taken before the per-work-item slice offset is applied)
    __global uchar *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    int4 c00 = 0;
    int4 c10 = 0;
    int4 c20 = 0;
    int4 c30 = 0;

    // Main loop: two steps along K per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * TRANSPOSE1XW_WIDTH_STEP)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        int4 a0 = convert_int4(vload4(0, src_addr_a));
        int4 b0 = convert_int4(vload4(0, src_addr_b));

        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;

        a0 = convert_int4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b0 = convert_int4(vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP));

        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;
    }

    // Left-over loop: one step along K per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        int4 a0 = convert_int4(vload4(0, src_addr_a));
        int4 b0 = convert_int4(vload4(0, src_addr_b));

        c00 += (int4)a0.s0 * b0;
        c10 += (int4)a0.s1 * b0;
        c20 += (int4)a0.s2 * b0;
        c30 += (int4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane; the scalar bound is splatted explicitly to match the vector overload of min
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
Gian Marco05288a22017-11-21 10:57:50 +0000197
198#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
199#define VECTOR_UCHAR VEC_DATA_TYPE(uchar, NUM_ELEMS_PROCESSED_PER_THREAD_X)
200#define VECTOR_UINT VEC_DATA_TYPE(uint, NUM_ELEMS_PROCESSED_PER_THREAD_X)
201#define VECTOR_INT VEC_DATA_TYPE(int, NUM_ELEMS_PROCESSED_PER_THREAD_X)
202/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
203 *
204 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
205 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100206 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
207 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
208 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
209 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
210 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
211 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
212 *
Gian Marco05288a22017-11-21 10:57:50 +0000213 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
214 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
215 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
216 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
217 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
218 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
219 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
220 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
221 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
222 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
223 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
224 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
225 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
226 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
227 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
228 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
229 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
230 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100231 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
232 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
233 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
234 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
235 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco05288a22017-11-21 10:57:50 +0000236 */
Gian Marco7b4d5472018-01-10 15:56:30 +0000237__kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0),
238 IMAGE_DECLARATION(src1),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100239 IMAGE_DECLARATION(dst),
240 uint src0_stride_z,
241 uint src1_stride_z,
242 uint dst_stride_z
243#if defined(REINTERPRET_INPUT_AS_3D)
244 ,
245 uint src_cross_plane_pad
246#endif // REINTERPRET_INPUT_AS_3D
247#if defined(REINTERPRET_OUTPUT_AS_3D)
248 ,
249 uint dst_cross_plane_pad
250#endif // REINTERPRET_OUTPUT_AS_3D
251 )
Gian Marco05288a22017-11-21 10:57:50 +0000252{
253 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
254
255 // Compute starting address for matrix A and Matrix B
256 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
257
258 // Update address for the matrix A
259 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
260
261 // Update address for the matrix B
262 src_addr.s1 += idx;
263
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100264#if defined(REINTERPRET_INPUT_AS_3D)
265 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
266 // in order to take into account the presence of possible cross plane paddings
267 //
268 // | |
269 // | plane0 |
270 // | |
271 // |__________________|
272 // |******************|
273 // | cross_plane_pad |
274 // |******************|
275 // | |
276 // | plane1 |
277 // | |
278 // |__________________|
279
280 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
281 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
282 zin = min(DEPTH_GEMM3D - 1, zin);
283
284 // Add offset due to the cross plane paddings
285 zin *= (src_cross_plane_pad * src0_stride_y);
286
287 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
288 // multiply src0_stride_z by DEPTH_GEMM3D
289 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
290
291#else // defined(REINTERPRET_INPUT_AS_3D)
292
293 // Add offset for batched GEMM
294 src_addr.s0 += get_global_id(2) * src0_stride_z;
295
296#endif // defined(REINTERPRET_INPUT_AS_3D)
297
298#if defined(MATRIX_B_DEPTH)
299 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
300 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
301#else // defined(MATRIX_B_DEPTH)
302 src_addr.s1 += get_global_id(2) * src1_stride_z;
303#endif // defined(MATRIX_B_DEPTH)
304
Gian Marco05288a22017-11-21 10:57:50 +0000305 int end_row_vec_a = src_addr.s0 + COLS_A;
306
307 VECTOR_UINT acc0 = 0;
308#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
309 VECTOR_UINT acc1 = 0;
310#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
311#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
312 VECTOR_UINT acc2 = 0;
313#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
314#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
315 VECTOR_UINT acc3 = 0;
316#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000317#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
318 VECTOR_UINT acc4 = 0;
319#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000320
321 for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
322 {
323 // Load values from matrix A
324 uchar2 a0 = vload2(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
325#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
326 uchar2 a1 = vload2(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
327#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
328#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
329 uchar2 a2 = vload2(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
330#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
331#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
332 uchar2 a3 = vload2(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
333#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000334#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
335 uchar2 a4 = vload2(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
336#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000337 // Load values from matrix B
338 VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
339 VECTOR_UCHAR b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1 + src1_stride_y);
340
341 // Accumulate
342 acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0.s0;
343 acc0 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a0.s1;
344#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
345 acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1.s0;
346 acc1 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a1.s1;
347#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
348#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
349 acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2.s0;
350 acc2 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a2.s1;
351#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
352#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
353 acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3.s0;
354 acc3 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a3.s1;
355#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000356#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
357 acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4.s0;
358 acc4 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a4.s1;
359#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000360 }
361
362 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
363 {
364 // Load values from matrix A
365 uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
366#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
367 uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
368#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
369#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
370 uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
371#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
372#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
373 uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
374#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000375#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
376 uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
377#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000378 // Load values from matrix B
379 VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
380
381 // Accumulate
382 acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0;
383#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
384 acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1;
385#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
386#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
387 acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2;
388#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
389#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
390 acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3;
391#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000392#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
393 acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4;
394#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Gian Marco05288a22017-11-21 10:57:50 +0000395 }
396
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100397 const int z = get_global_id(2);
398
Gian Marco05288a22017-11-21 10:57:50 +0000399 // Compute destination address
400 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
401
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100402#if defined(REINTERPRET_OUTPUT_AS_3D)
403 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
404 // in order to take into account the presence of possible cross plane paddings
405 //
406 // | |
407 // | plane0 |
408 // | |
409 // |__________________|
410 // |******************|
411 // | cross_plane_pad |
412 // |******************|
413 // | |
414 // | plane1 |
415 // | |
416 // |__________________|
417
418 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
419 uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
420 zout = min(DEPTH_GEMM3D - 1, zout);
421
422 // Add offset due to the cross plane paddings
423 zout *= (dst_cross_plane_pad * dst_stride_y);
424
425 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
426 // multiply dst_stride_z by DEPTH_GEMM3D
427 dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
428
Gian Marco05288a22017-11-21 10:57:50 +0000429 // Store the result
430 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100431 (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
Gian Marco05288a22017-11-21 10:57:50 +0000432#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
433 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100434 (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
Gian Marco05288a22017-11-21 10:57:50 +0000435#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
436#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
437 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100438 (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
Gian Marco05288a22017-11-21 10:57:50 +0000439#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
440#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
441 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100442 (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
Gian Marco05288a22017-11-21 10:57:50 +0000443#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco7b4d5472018-01-10 15:56:30 +0000444#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
445 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100446 (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
Gian Marco7b4d5472018-01-10 15:56:30 +0000447#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100448
449#else // defined(REINTERPRET_OUTPUT_AS_3D)
450 // Add offset for batched GEMM
451 dst.ptr += z * dst_stride_z;
452
453 // Store the result
454 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
455 (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
456#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
457 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
458 (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
459#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
460#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
461 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
462 (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
463#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
464#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
465 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
466 (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
467#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
468#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
469 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
470 (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
471#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
472#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco7b4d5472018-01-10 15:56:30 +0000473}
474
475/** OpenCL kernel optimized for Bifrost architectures that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
476 *
477 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
478 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100479 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
480 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
481 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
482 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
483 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
484 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
485 *
Gian Marco7b4d5472018-01-10 15:56:30 +0000486 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
487 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
488 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
489 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
490 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
491 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
492 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
493 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
494 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
495 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
496 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
497 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
498 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
499 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
500 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
501 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
502 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
503 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100504 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
505 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
506 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
507 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
508 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco7b4d5472018-01-10 15:56:30 +0000509 */
__kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
                                  IMAGE_DECLARATION(src1),
                                  IMAGE_DECLARATION(dst),
                                  uint src0_stride_z,
                                  uint src1_stride_z,
                                  uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                  ,
                                  uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                  ,
                                  uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                 )
{
    // Each work-item accumulates a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows (at most 5, see the
    // #if ladders below) by 4 columns of the uint8 x uint8 product and stores it as int4 rows.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D.
    // A uint8 is used (instead of a uint4) so that every one of the up to 5 rows loaded below gets its own plane index.
    uint8 zin = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Byte offset of each row of the A tile relative to src_addr.s0: the row's own stride offset plus the
    // cross plane padding of the planes above it. Without this the padding computed above would be ignored.
    const uint8 a_row_offset = zin + ((uint8)(0, 1, 2, 3, 4, 5, 6, 7)) * src0_stride_y;

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Byte offset of each row of the A tile relative to src_addr.s0
    const uint8 a_row_offset = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7)) * src0_stride_y;

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // src_addr.s0 walks along the K dimension of matrix A; the loops stop once COLS_A bytes have been consumed
    int end_row_vec_a = src_addr.s0 + COLS_A;

    uint acc00 = 0;
    uint acc01 = 0;
    uint acc02 = 0;
    uint acc03 = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    uint acc10 = 0;
    uint acc11 = 0;
    uint acc12 = 0;
    uint acc13 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    uint acc20 = 0;
    uint acc21 = 0;
    uint acc22 = 0;
    uint acc23 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    uint acc30 = 0;
    uint acc31 = 0;
    uint acc32 = 0;
    uint acc33 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    uint acc40 = 0;
    uint acc41 = 0;
    uint acc42 = 0;
    uint acc43 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4

    // Main loop: consume 4 elements of K per iteration (4 bytes of each A row, 4 rows of B)
    for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
    {
        // Load values from matrix A
        uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + a_row_offset.s0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + a_row_offset.s1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + a_row_offset.s2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + a_row_offset.s3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + a_row_offset.s4);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        // Load values from matrix B
        uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
        uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
        uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
        uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);

        {
            // Accumulate. A single 8-bit x 8-bit product (max 255 * 255) fits in a ushort.
            ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a0.s0;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a0.s0;

            ushort tmp4 = (ushort)b1.s0 * (ushort)a0.s1;
            ushort tmp5 = (ushort)b1.s1 * (ushort)a0.s1;
            ushort tmp6 = (ushort)b1.s2 * (ushort)a0.s1;
            ushort tmp7 = (ushort)b1.s3 * (ushort)a0.s1;

            ushort tmp8 = (ushort)b2.s0 * (ushort)a0.s2;
            ushort tmp9 = (ushort)b2.s1 * (ushort)a0.s2;
            ushort tmpA = (ushort)b2.s2 * (ushort)a0.s2;
            ushort tmpB = (ushort)b2.s3 * (ushort)a0.s2;

            ushort tmpC = (ushort)b3.s0 * (ushort)a0.s3;
            ushort tmpD = (ushort)b3.s1 * (ushort)a0.s3;
            ushort tmpE = (ushort)b3.s2 * (ushort)a0.s3;
            ushort tmpF = (ushort)b3.s3 * (ushort)a0.s3;

            acc00 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
            acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
            acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
            acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
        }
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a1.s0;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a1.s0;

            ushort tmp4 = (ushort)b1.s0 * (ushort)a1.s1;
            ushort tmp5 = (ushort)b1.s1 * (ushort)a1.s1;
            ushort tmp6 = (ushort)b1.s2 * (ushort)a1.s1;
            ushort tmp7 = (ushort)b1.s3 * (ushort)a1.s1;

            ushort tmp8 = (ushort)b2.s0 * (ushort)a1.s2;
            ushort tmp9 = (ushort)b2.s1 * (ushort)a1.s2;
            ushort tmpA = (ushort)b2.s2 * (ushort)a1.s2;
            ushort tmpB = (ushort)b2.s3 * (ushort)a1.s2;

            ushort tmpC = (ushort)b3.s0 * (ushort)a1.s3;
            ushort tmpD = (ushort)b3.s1 * (ushort)a1.s3;
            ushort tmpE = (ushort)b3.s2 * (ushort)a1.s3;
            ushort tmpF = (ushort)b3.s3 * (ushort)a1.s3;

            acc10 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
            acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
            acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
            acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
        }
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a2.s0;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a2.s0;

            ushort tmp4 = (ushort)b1.s0 * (ushort)a2.s1;
            ushort tmp5 = (ushort)b1.s1 * (ushort)a2.s1;
            ushort tmp6 = (ushort)b1.s2 * (ushort)a2.s1;
            ushort tmp7 = (ushort)b1.s3 * (ushort)a2.s1;

            ushort tmp8 = (ushort)b2.s0 * (ushort)a2.s2;
            ushort tmp9 = (ushort)b2.s1 * (ushort)a2.s2;
            ushort tmpA = (ushort)b2.s2 * (ushort)a2.s2;
            ushort tmpB = (ushort)b2.s3 * (ushort)a2.s2;

            ushort tmpC = (ushort)b3.s0 * (ushort)a2.s3;
            ushort tmpD = (ushort)b3.s1 * (ushort)a2.s3;
            ushort tmpE = (ushort)b3.s2 * (ushort)a2.s3;
            ushort tmpF = (ushort)b3.s3 * (ushort)a2.s3;

            acc20 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
            acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
            acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
            acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
        }
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a3.s0;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a3.s0;

            ushort tmp4 = (ushort)b1.s0 * (ushort)a3.s1;
            ushort tmp5 = (ushort)b1.s1 * (ushort)a3.s1;
            ushort tmp6 = (ushort)b1.s2 * (ushort)a3.s1;
            ushort tmp7 = (ushort)b1.s3 * (ushort)a3.s1;

            ushort tmp8 = (ushort)b2.s0 * (ushort)a3.s2;
            ushort tmp9 = (ushort)b2.s1 * (ushort)a3.s2;
            ushort tmpA = (ushort)b2.s2 * (ushort)a3.s2;
            ushort tmpB = (ushort)b2.s3 * (ushort)a3.s2;

            ushort tmpC = (ushort)b3.s0 * (ushort)a3.s3;
            ushort tmpD = (ushort)b3.s1 * (ushort)a3.s3;
            ushort tmpE = (ushort)b3.s2 * (ushort)a3.s3;
            ushort tmpF = (ushort)b3.s3 * (ushort)a3.s3;

            acc30 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
            acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
            acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
            acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
        }
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a4.s0;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a4.s0;

            ushort tmp4 = (ushort)b1.s0 * (ushort)a4.s1;
            ushort tmp5 = (ushort)b1.s1 * (ushort)a4.s1;
            ushort tmp6 = (ushort)b1.s2 * (ushort)a4.s1;
            ushort tmp7 = (ushort)b1.s3 * (ushort)a4.s1;

            ushort tmp8 = (ushort)b2.s0 * (ushort)a4.s2;
            ushort tmp9 = (ushort)b2.s1 * (ushort)a4.s2;
            ushort tmpA = (ushort)b2.s2 * (ushort)a4.s2;
            ushort tmpB = (ushort)b2.s3 * (ushort)a4.s2;

            ushort tmpC = (ushort)b3.s0 * (ushort)a4.s3;
            ushort tmpD = (ushort)b3.s1 * (ushort)a4.s3;
            ushort tmpE = (ushort)b3.s2 * (ushort)a4.s3;
            ushort tmpF = (ushort)b3.s3 * (ushort)a4.s3;

            acc40 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
            acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
            acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
            acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
        }
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    }

    // Leftover loop: consume the remaining (COLS_A % 4) elements of K one at a time
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
    {
        // Load values from matrix A
        uchar a0 = *(src0_ptr + src_addr.s0 + a_row_offset.s0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        uchar a1 = *(src0_ptr + src_addr.s0 + a_row_offset.s1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        uchar a2 = *(src0_ptr + src_addr.s0 + a_row_offset.s2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        uchar a3 = *(src0_ptr + src_addr.s0 + a_row_offset.s3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        uchar a4 = *(src0_ptr + src_addr.s0 + a_row_offset.s4);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        // Load values from matrix B
        uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);

        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a0;

            acc00 += ((uint)tmp0);
            acc01 += ((uint)tmp1);
            acc02 += ((uint)tmp2);
            acc03 += ((uint)tmp3);
        }
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a1;

            acc10 += ((uint)tmp0);
            acc11 += ((uint)tmp1);
            acc12 += ((uint)tmp2);
            acc13 += ((uint)tmp3);
        }
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a2;

            acc20 += ((uint)tmp0);
            acc21 += ((uint)tmp1);
            acc22 += ((uint)tmp2);
            acc23 += ((uint)tmp3);
        }
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a3;

            acc30 += ((uint)tmp0);
            acc31 += ((uint)tmp1);
            acc32 += ((uint)tmp2);
            acc33 += ((uint)tmp3);
        }
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
        {
            // Accumulate
            ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
            ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
            ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
            ushort tmp3 = (ushort)b0.s3 * (ushort)a4;

            acc40 += ((uint)tmp0);
            acc41 += ((uint)tmp1);
            acc42 += ((uint)tmp2);
            acc43 += ((uint)tmp3);
        }
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    }

    const int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the result
    vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst.ptr += z * dst_stride_z;

    // Store the result
    vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Giorgio Arena6200fa42018-07-06 17:06:36 +0100933
Georgios Pinitasdaa38552018-08-28 17:43:18 +0100934#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Giorgio Arena6200fa42018-07-06 17:06:36 +0100935/** OpenCL kernel optimized to use dot product that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
936 *
937 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
938 *
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100939 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
940 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
941 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
942 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
943 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
944 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
945 *
Giorgio Arena6200fa42018-07-06 17:06:36 +0100946 * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
947 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
948 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
949 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
950 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
951 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
952 * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
953 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
954 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
955 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
956 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
957 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
958 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
959 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
960 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
961 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
962 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
963 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100964 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
965 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
966 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
967 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
968 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +0100969 */
970__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
971 IMAGE_DECLARATION(src1),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100972 IMAGE_DECLARATION(dst),
973 uint src0_stride_z,
974 uint src1_stride_z,
975 uint dst_stride_z
976#if defined(REINTERPRET_INPUT_AS_3D)
977 ,
978 uint src_cross_plane_pad
979#endif // REINTERPRET_INPUT_AS_3D
980#if defined(REINTERPRET_OUTPUT_AS_3D)
981 ,
982 uint dst_cross_plane_pad
983#endif // REINTERPRET_OUTPUT_AS_3D)
984 )
Giorgio Arena6200fa42018-07-06 17:06:36 +0100985{
986 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
987
988 // Compute starting address for matrix A and Matrix B
989 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
990
991 // Update address for the matrix A
992 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
993
994 // Update address for the matrix B
995 src_addr.s1 += idx;
996
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100997#if defined(REINTERPRET_INPUT_AS_3D)
998 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
999 // in order to take into account the presence of possible cross plane paddings
1000 //
1001 // | |
1002 // | plane0 |
1003 // | |
1004 // |__________________|
1005 // |******************|
1006 // | cross_plane_pad |
1007 // |******************|
1008 // | |
1009 // | plane1 |
1010 // | |
1011 // |__________________|
1012
1013 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1014 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1015 zin = min(DEPTH_GEMM3D - 1, zin);
1016
1017 // Add offset due to the cross plane paddings
1018 zin *= (src_cross_plane_pad * src0_stride_y);
1019
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001020 zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y;
1021
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001022 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1023 // multiply src0_stride_z by DEPTH_GEMM3D
1024 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1025
1026#else // defined(REINTERPRET_INPUT_AS_3D)
1027
1028 // Add offset for batched GEMM
1029 src_addr.s0 += get_global_id(2) * src0_stride_z;
1030
1031#endif // defined(REINTERPRET_INPUT_AS_3D)
1032
1033#if defined(MATRIX_B_DEPTH)
1034 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1035 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1036#else // defined(MATRIX_B_DEPTH)
1037 src_addr.s1 += get_global_id(2) * src1_stride_z;
1038#endif // defined(MATRIX_B_DEPTH)
1039
Giorgio Arena6200fa42018-07-06 17:06:36 +01001040 uint acc00 = 0;
1041 uint acc01 = 0;
1042 uint acc02 = 0;
1043 uint acc03 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001044 uint acc04 = 0;
1045 uint acc05 = 0;
1046 uint acc06 = 0;
1047 uint acc07 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001048#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1049 uint acc10 = 0;
1050 uint acc11 = 0;
1051 uint acc12 = 0;
1052 uint acc13 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001053 uint acc14 = 0;
1054 uint acc15 = 0;
1055 uint acc16 = 0;
1056 uint acc17 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001057#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1058#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1059 uint acc20 = 0;
1060 uint acc21 = 0;
1061 uint acc22 = 0;
1062 uint acc23 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001063 uint acc24 = 0;
1064 uint acc25 = 0;
1065 uint acc26 = 0;
1066 uint acc27 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001067#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1068#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1069 uint acc30 = 0;
1070 uint acc31 = 0;
1071 uint acc32 = 0;
1072 uint acc33 = 0;
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001073 uint acc34 = 0;
1074 uint acc35 = 0;
1075 uint acc36 = 0;
1076 uint acc37 = 0;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001077#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Giorgio Arena6200fa42018-07-06 17:06:36 +01001078
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001079 // A and B src indices get incremented at the same time.
1080 int i = 0;
1081 for(; i <= ((int)COLS_A - 8); i += 8)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001082 {
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001083#if defined(REINTERPRET_INPUT_AS_3D)
1084 // Load values from matrix A and matrix B
1085 uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001086#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001087 uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001088#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1089#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001090 uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001091#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1092#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001093 uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001094#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001095#else // defined(REINTERPRET_INPUT_AS_3D)
1096 // Load values from matrix A and matrix B
1097 uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1098#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1099 uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1100#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1101#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1102 uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1103#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1104#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1105 uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1106#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1107#endif // defined(REINTERPRET_INPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001108
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001109 uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
1110 uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
1111 uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
1112 uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
1113 src_addr.s1 += 4 * src1_stride_y;
1114
1115 ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
1116 ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
1117 ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
1118 ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
1119 ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
1120 ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
1121 ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
1122 ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
1123
Giorgio Arena6200fa42018-07-06 17:06:36 +01001124#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001125 ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
1126 ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
1127 ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
1128 ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
1129 ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
1130 ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
1131 ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
1132 ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001133#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1134#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001135 ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
1136 ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
1137 ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
1138 ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
1139 ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
1140 ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
1141 ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
1142 ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001143#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1144#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001145 ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
1146 ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
1147 ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
1148 ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
1149 ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
1150 ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
1151 ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
1152 ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
Giorgio Arena6200fa42018-07-06 17:06:36 +01001153#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001154
1155 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
1156 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
1157 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
1158 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
1159 src_addr.s1 += 4 * src1_stride_y;
1160
1161 ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
1162 ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
1163 ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
1164 ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
1165 ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
1166 ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
1167 ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
1168 ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
1169
1170#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1171 ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
1172 ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
1173 ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
1174 ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
1175 ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
1176 ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
1177 ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
1178 ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
1179#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1180#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1181 ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
1182 ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
1183 ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
1184 ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
1185 ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
1186 ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
1187 ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
1188 ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
1189#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1190#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1191 ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
1192 ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
1193 ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
1194 ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
1195 ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
1196 ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
1197 ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
1198 ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
1199#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1200
1201 src_addr.s0 += 8;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001202 }
1203
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001204 for(; i < (int)COLS_A; ++i)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001205 {
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001206#if defined(REINTERPRET_INPUT_AS_3D)
Giorgio Arena6200fa42018-07-06 17:06:36 +01001207 // Load values from matrix A
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001208 uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001209#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001210 uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001211#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1212#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001213 uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001214#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1215#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001216 uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001217#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001218#else // defined(REINTERPRET_INPUT_AS_3D)
1219 // Load values from matrix A
1220 uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1221#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1222 uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1223#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1224#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1225 uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1226#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1227#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1228 uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1229#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1230#endif // defined(REINTERPRET_INPUT_AS_3D)
1231
Giorgio Arena6200fa42018-07-06 17:06:36 +01001232 // Load values from matrix B
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001233 uchar8 b0 = vload8(0, src1_ptr + src_addr.s1);
1234 src_addr.s1 += src1_stride_y;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001235
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001236 acc00 += (uint)a0 * b0.s0;
1237 acc01 += (uint)a0 * b0.s1;
1238 acc02 += (uint)a0 * b0.s2;
1239 acc03 += (uint)a0 * b0.s3;
1240 acc04 += (uint)a0 * b0.s4;
1241 acc05 += (uint)a0 * b0.s5;
1242 acc06 += (uint)a0 * b0.s6;
1243 acc07 += (uint)a0 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001244
Giorgio Arena6200fa42018-07-06 17:06:36 +01001245#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001246 acc10 += (uint)a1 * b0.s0;
1247 acc11 += (uint)a1 * b0.s1;
1248 acc12 += (uint)a1 * b0.s2;
1249 acc13 += (uint)a1 * b0.s3;
1250 acc14 += (uint)a1 * b0.s4;
1251 acc15 += (uint)a1 * b0.s5;
1252 acc16 += (uint)a1 * b0.s6;
1253 acc17 += (uint)a1 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001254#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1255#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001256 acc20 += (uint)a2 * b0.s0;
1257 acc21 += (uint)a2 * b0.s1;
1258 acc22 += (uint)a2 * b0.s2;
1259 acc23 += (uint)a2 * b0.s3;
1260 acc24 += (uint)a2 * b0.s4;
1261 acc25 += (uint)a2 * b0.s5;
1262 acc26 += (uint)a2 * b0.s6;
1263 acc27 += (uint)a2 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001264#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1265#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001266 acc30 += (uint)a3 * b0.s0;
1267 acc31 += (uint)a3 * b0.s1;
1268 acc32 += (uint)a3 * b0.s2;
1269 acc33 += (uint)a3 * b0.s3;
1270 acc34 += (uint)a3 * b0.s4;
1271 acc35 += (uint)a3 * b0.s5;
1272 acc36 += (uint)a3 * b0.s6;
1273 acc37 += (uint)a3 * b0.s7;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001274#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Giorgio Arena6200fa42018-07-06 17:06:36 +01001275
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001276 src_addr.s0 += 1;
Giorgio Arena6200fa42018-07-06 17:06:36 +01001277 }
1278
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001279 int z = get_global_id(2);
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001280
Giorgio Arena6200fa42018-07-06 17:06:36 +01001281 // Compute destination address
1282 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1283
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001284 // Compute dst address
1285 __global uchar *dst_addr = dst.ptr;
1286
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001287#if defined(REINTERPRET_OUTPUT_AS_3D)
1288 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1289 // in order to take into account the presence of possible cross plane paddings
1290 //
1291 // | |
1292 // | plane0 |
1293 // | |
1294 // |__________________|
1295 // |******************|
1296 // | cross_plane_pad |
1297 // |******************|
1298 // | |
1299 // | plane1 |
1300 // | |
1301 // |__________________|
1302
1303 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001304 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001305 zout = min(DEPTH_GEMM3D - 1, zout);
1306
1307 // Add offset due to the cross plane paddings
1308 zout *= (dst_cross_plane_pad * dst_stride_y);
1309
1310 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1311 // multiply dst_stride_z by DEPTH_GEMM3D
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001312 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001313
Giorgio Arena6200fa42018-07-06 17:06:36 +01001314 // Store the result
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001315 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
1316 vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001317#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001318 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
1319 vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001320#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1321#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001322 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
1323 vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001324#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1325#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001326 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
1327 vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
Giorgio Arena6200fa42018-07-06 17:06:36 +01001328#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001329
1330#else // defined(REINTERPRET_OUTPUT_AS_3D)
1331 // Add offset for batched GEMM
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001332 dst_addr += z * dst_stride_z;
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001333
1334 // Store the result
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001335 vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y));
1336 vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001337#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001338 vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y));
1339 vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001340#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1341#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001342 vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y));
1343 vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001344#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1345#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001346 vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
1347 vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +01001348#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice4b908652018-10-18 10:21:02 +01001349#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1350}
Georgios Pinitasdaa38552018-08-28 17:43:18 +01001351#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Gian Marco05288a22017-11-21 10:57:50 +00001352#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
1353
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001354#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(M) && defined(N)
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00001355
#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

/* ARM_DOT_K0(a, b, c)
 * Accumulates into c the dot product of the two K0-lane uchar vectors a and b,
 * built on top of the 4-lane ARM_DOT primitive defined at the top of this file.
 *  - K0 < 4 : both operands are zero-extended to 4 lanes (the padding lanes add nothing).
 *  - K0 > 4 : the operands are split into consecutive 4-lane halves/quarters, accumulated in order.
 */
#if K0 == 2
#define ARM_DOT_K0(a, b, c)                                         \
    ({                                                              \
        ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \
    })
#elif K0 == 3
#define ARM_DOT_K0(a, b, c)                                       \
    ({                                                            \
        ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \
    })
#elif K0 == 4
#define ARM_DOT_K0(a, b, c) \
    ({                      \
        ARM_DOT(a, b, c);   \
    })
#elif K0 == 8
/* .lo == lanes 0..3, .hi == lanes 4..7 */
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        ARM_DOT(a.lo, b.lo, c); \
        ARM_DOT(a.hi, b.hi, c); \
    })
#elif K0 == 16
/* .lo.lo == lanes 0..3, .lo.hi == 4..7, .hi.lo == 8..11, .hi.hi == 12..15 */
#define ARM_DOT_K0(a, b, c)           \
    ({                                \
        ARM_DOT(a.lo.lo, b.lo.lo, c); \
        ARM_DOT(a.lo.hi, b.lo.hi, c); \
        ARM_DOT(a.hi.lo, b.hi.lo, c); \
        ARM_DOT(a.hi.hi, b.hi.hi, c); \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0

#else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

/* ARM_DOT_K0(a, b, c)
 * Scalar fallback: c += sum over lanes i of (uint)a.si * b.si.
 * Every product is widened to uint before summation; the sum is taken modulo 2^32,
 * exactly as a sequence of individual "+=" statements would be.
 */
#if K0 == 2
#define ARM_DOT_K0(a, b, c)                         \
    ({                                              \
        c += (uint)a.s0 * b.s0 + (uint)a.s1 * b.s1; \
    })
#elif K0 == 3
#define ARM_DOT_K0(a, b, c)                         \
    ({                                              \
        c += (uint)a.s0 * b.s0 + (uint)a.s1 * b.s1  \
             + (uint)a.s2 * b.s2;                   \
    })
#elif K0 == 4
#define ARM_DOT_K0(a, b, c)                         \
    ({                                              \
        c += (uint)a.s0 * b.s0 + (uint)a.s1 * b.s1  \
             + (uint)a.s2 * b.s2 + (uint)a.s3 * b.s3; \
    })
#elif K0 == 8
#define ARM_DOT_K0(a, b, c)                           \
    ({                                                \
        c += (uint)a.s0 * b.s0 + (uint)a.s1 * b.s1    \
             + (uint)a.s2 * b.s2 + (uint)a.s3 * b.s3  \
             + (uint)a.s4 * b.s4 + (uint)a.s5 * b.s5  \
             + (uint)a.s6 * b.s6 + (uint)a.s7 * b.s7; \
    })
#elif K0 == 16
#define ARM_DOT_K0(a, b, c)                           \
    ({                                                \
        c += (uint)a.s0 * b.s0 + (uint)a.s1 * b.s1    \
             + (uint)a.s2 * b.s2 + (uint)a.s3 * b.s3  \
             + (uint)a.s4 * b.s4 + (uint)a.s5 * b.s5  \
             + (uint)a.s6 * b.s6 + (uint)a.s7 * b.s7  \
             + (uint)a.s8 * b.s8 + (uint)a.s9 * b.s9  \
             + (uint)a.sA * b.sA + (uint)a.sB * b.sB  \
             + (uint)a.sC * b.sC + (uint)a.sD * b.sD  \
             + (uint)a.sE * b.sE + (uint)a.sF * b.sF; \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0

#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
1451
/* ARM_DOT_K0XN0(a, b, c)
 * Multiplies one K0-lane LHS row vector `a` against N0 RHS vectors named b0, b1, ..., b(N0-1)
 * (the names are formed by token-pasting the lane suffix onto `b`), accumulating result j into
 * lane c.sj of the N0-lane accumulator `c`. Each lane is handled by ARM_DOT_K0, in ascending order.
 */
#if N0 == 16
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
        ARM_DOT_K0((a), (b##8), (c.s8)); \
        ARM_DOT_K0((a), (b##9), (c.s9)); \
        ARM_DOT_K0((a), (b##A), (c.sA)); \
        ARM_DOT_K0((a), (b##B), (c.sB)); \
        ARM_DOT_K0((a), (b##C), (c.sC)); \
        ARM_DOT_K0((a), (b##D), (c.sD)); \
        ARM_DOT_K0((a), (b##E), (c.sE)); \
        ARM_DOT_K0((a), (b##F), (c.sF)); \
    })
#elif N0 == 8
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
    })
#elif N0 == 4
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
    })
#elif N0 == 3
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
    })
#elif N0 == 2
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions
1508
/** This OpenCL kernel computes the matrix multiplication between 2 matrices with QASYMM8 data type.
 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
 *
 * @note The products are accumulated in 32-bit unsigned integers and the M0xN0 result block is stored
 *       as S32 (saturating conversion from uint to int).
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
 * @note If the RHS matrix has 3 dimensions while the LHS matrix has more than 3, -DMATRIX_B_DEPTH (the depth of the RHS matrix)
 *       can be passed at compile time: the RHS batch is then selected as get_global_id(2) % MATRIX_B_DEPTH.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *  - V0 >= 1
 *  - H0 >= 1
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the LHS matrix NOT reshaped (i.e. M)
 *
 * @param[in]  lhs_ptr                           Pointer to the LHS reshaped matrix. Supported data type: QASYMM8
 * @param[in]  lhs_stride_x                      Stride of the LHS reshaped matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                        lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                      Stride of the LHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                        lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
 * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                        rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                        rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data type: S32
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  k                                 Number of columns in LHS matrix and rows in RHS matrix not reshaped.
 * @param[in]  lhs_stride_z                      Stride of the LHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
                                                IMAGE_DECLARATION(rhs),
                                                IMAGE_DECLARATION(dst),
                                                uint k,
                                                uint lhs_stride_z,
                                                uint rhs_stride_z,
                                                uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                ,
                                                uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                               )
{
    // Size (in elements, which are bytes for uchar) of one M0xK0 LHS block
#define LHS_BLOCK_SIZE ((K0) * (M0))

    // LHS_OFFSET_X : offset of this work-item's block within a reshaped LHS row
    // LHS_STEP_X   : distance between two consecutive rows of the M0xK0 block
    // LHS_STEP_LOOP: extra multiplier applied when advancing to the next K0 slice
#if defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (K0)
#define LHS_STEP_X ((K0) * (V0))
#define LHS_STEP_LOOP (1)
#else // defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
#define LHS_STEP_X (K0)
#define LHS_STEP_LOOP (V0)
#endif // defined(LHS_INTERLEAVE)

    // Size (in elements) of one K0xN0 RHS block
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X (same meaning as the LHS macros above, for the N0 RHS vectors)
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (K0)
#define RHS_STEP_X ((K0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (K0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

#if defined(DUMMY_WORK_ITEMS)
    // Work-items launched only to round up the NDRange have no output block to compute
    if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X + (get_global_id(1) / V0) * (uint)lhs_stride_y
                               + (get_global_id(2) * lhs_stride_z);

    // Compute RHS matrix address
    __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X + (get_global_id(0) / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_addr += get_global_id(2) * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    // Unsigned counter: k is uint, so this avoids a signed/unsigned comparison
    for(uint i = 0; i < k; i += K0)
    {
        // Load the M0 rows (K0 elements each) of the current LHS block.
        // Row r starts r * LHS_STEP_X bytes after lhs_addr.
        VEC_DATA_TYPE(uchar, K0)
        a0 = VLOAD(K0)(0, lhs_addr + 0 * LHS_STEP_X);
#if M0 > 1
        VEC_DATA_TYPE(uchar, K0)
        a1 = VLOAD(K0)(0, lhs_addr + 1 * LHS_STEP_X);
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(uchar, K0)
        a2 = VLOAD(K0)(0, lhs_addr + 2 * LHS_STEP_X);
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(uchar, K0)
        a3 = VLOAD(K0)(0, lhs_addr + 3 * LHS_STEP_X);
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(uchar, K0)
        a4 = VLOAD(K0)(0, lhs_addr + 4 * LHS_STEP_X);
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(uchar, K0)
        a5 = VLOAD(K0)(0, lhs_addr + 5 * LHS_STEP_X);
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(uchar, K0)
        a6 = VLOAD(K0)(0, lhs_addr + 6 * LHS_STEP_X);
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(uchar, K0)
        a7 = VLOAD(K0)(0, lhs_addr + 7 * LHS_STEP_X);
#endif // M0 > 7

        // Load the N0 (transposed) RHS vectors of K0 elements each
        VEC_DATA_TYPE(uchar, K0)
        b0 = VLOAD(K0)(0, rhs_addr + 0 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b1 = VLOAD(K0)(0, rhs_addr + 1 * RHS_STEP_X);
#if N0 > 2
        VEC_DATA_TYPE(uchar, K0)
        b2 = VLOAD(K0)(0, rhs_addr + 2 * RHS_STEP_X);
#endif // N0 > 2
#if N0 > 3
        VEC_DATA_TYPE(uchar, K0)
        b3 = VLOAD(K0)(0, rhs_addr + 3 * RHS_STEP_X);
#endif // N0 > 3
#if N0 > 4
        VEC_DATA_TYPE(uchar, K0)
        b4 = VLOAD(K0)(0, rhs_addr + 4 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b5 = VLOAD(K0)(0, rhs_addr + 5 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b6 = VLOAD(K0)(0, rhs_addr + 6 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b7 = VLOAD(K0)(0, rhs_addr + 7 * RHS_STEP_X);
#endif // N0 > 4
#if N0 > 8
        VEC_DATA_TYPE(uchar, K0)
        b8 = VLOAD(K0)(0, rhs_addr + 8 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        b9 = VLOAD(K0)(0, rhs_addr + 9 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bA = VLOAD(K0)(0, rhs_addr + 10 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bB = VLOAD(K0)(0, rhs_addr + 11 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bC = VLOAD(K0)(0, rhs_addr + 12 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bD = VLOAD(K0)(0, rhs_addr + 13 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bE = VLOAD(K0)(0, rhs_addr + 14 * RHS_STEP_X);
        VEC_DATA_TYPE(uchar, K0)
        bF = VLOAD(K0)(0, rhs_addr + 15 * RHS_STEP_X);
#endif // N0 > 8

        // Accumulate: row r of the LHS block against all N0 RHS vectors, into accumulator c<r>
        ARM_DOT_K0XN0(a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(a7, b, c7);
#endif // M0 > 7

        // Advance to the next K0 slice of both reshaped matrices
        lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP);
        rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP);
    }

    // The output is S32: the X offset is expressed in units of sizeof(int)
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(int)) + (get_global_id(1) * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
    zout0 *= (dst_cross_plane_pad * dst_stride_y);
#if M0 > 1
    zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
    zout1 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 1
#if M0 > 2
    zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
    zout2 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 2
#if M0 > 3
    zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
    zout3 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 3
#if M0 > 4
    zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
    zout4 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 4
#if M0 > 5
    zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
    zout5 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 5
#if M0 > 6
    zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
    zout6 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 6
#if M0 > 7
    zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
    zout7 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 7

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += get_global_id(2) * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Store output block (uint accumulators saturated to int)
    VSTORE(N0)
    (CONVERT_SAT(c0, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout0));
#if M0 > 1
    VSTORE(N0)
    (CONVERT_SAT(c1, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout1));
#endif // M0 > 1
#if M0 > 2
    VSTORE(N0)
    (CONVERT_SAT(c2, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout2));
#endif // M0 > 2
#if M0 > 3
    VSTORE(N0)
    (CONVERT_SAT(c3, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout3));
#endif // M0 > 3
#if M0 > 4
    VSTORE(N0)
    (CONVERT_SAT(c4, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 4 * dst_stride_y + zout4));
#endif // M0 > 4
#if M0 > 5
    VSTORE(N0)
    (CONVERT_SAT(c5, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 5 * dst_stride_y + zout5));
#endif // M0 > 5
#if M0 > 6
    VSTORE(N0)
    (CONVERT_SAT(c6, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 6 * dst_stride_y + zout6));
#endif // M0 > 6
#if M0 > 7
    VSTORE(N0)
    (CONVERT_SAT(c7, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 7 * dst_stride_y + zout7));
#endif // M0 > 7

    // Release every kernel-local macro so none of them leaks into the rest of the translation unit
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X
#undef LHS_STEP_LOOP
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
1842
1843#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001844/** This OpenCL kernel computes the matrix multiplication between 2 matrices with QASYMM8 data type using the dot8 instruction.
Gian Marco Iodicedb63b9c2019-01-17 09:47:04 +00001845 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1846 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1847 *
1848 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
1849 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
1850 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1851 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
1852 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1853 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1854 * - M0 = 2, 3, 4, 5, 6, 7, 8
1855 * - N0 = 2, 3, 4, 8, 16
1856 * - K0 = 2, 3, 4, 8, 16
1857 *
1858 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1859 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1860 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1861 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1862 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1863 *
1864 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: QASYMM8
1865 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1866 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1867 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1868 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1869 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1870 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1871 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1872 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1873 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1874 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1875 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1876 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1877 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1878 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1879 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1880 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1881 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1882 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1883 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1884 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1885 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1886 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1887 */
1888__kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t_dot8(IMAGE_DECLARATION(lhs),
1889 IMAGE_DECLARATION(rhs),
1890 IMAGE_DECLARATION(dst),
1891 uint k,
1892 uint lhs_stride_z,
1893 uint rhs_stride_z,
1894 uint dst_stride_z
1895#if defined(REINTERPRET_OUTPUT_AS_3D)
1896 ,
1897 uint dst_cross_plane_pad
1898#endif // REINTERPRET_OUTPUT_AS_3D
1899 )
1900{
1901 // Note: ARM_DOT_K0XN0 is generated with the dot8 instruction
1902 gemmlowp_mm_reshaped_lhs_nt_rhs_t(lhs_ptr,
1903 lhs_stride_x,
1904 lhs_step_x,
1905 lhs_stride_y,
1906 lhs_step_y,
1907 lhs_offset_first_element_in_bytes,
1908 rhs_ptr,
1909 rhs_stride_x,
1910 rhs_step_x,
1911 rhs_stride_y,
1912 rhs_step_y,
1913 rhs_offset_first_element_in_bytes,
1914 dst_ptr,
1915 dst_stride_x,
1916 dst_step_x,
1917 dst_stride_y,
1918 dst_step_y,
1919 dst_offset_first_element_in_bytes,
1920 k,
1921 lhs_stride_z,
1922 rhs_stride_z,
1923 dst_stride_z
1924#if defined(REINTERPRET_OUTPUT_AS_3D)
1925 ,
1926 dst_cross_plane_pad
1927#endif // REINTERPRET_OUTPUT_AS_3D
1928 );
1929}
1930#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
1931#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K)
1932
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001933#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(K)
1934
// Token-pastes a and b after its arguments have been substituted (used as CONCAT(ARM_DOT, k0) with k0 already expanded)
#define CONCAT(a, b) a##b

// ARM_DOTn(a, b, c): accumulate the n-element dot product of the uchar vectors a and b into the uint accumulator c.

#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

// dot8 path: operands shorter than 4 lanes are zero-padded to a uchar4, so the padding lanes add nothing to c
#define ARM_DOT1(a, b, c)                                           \
    ({                                                              \
        ARM_DOT((uchar4)(a, (uchar3)0), (uchar4)(b, (uchar3)0), c); \
    })
#define ARM_DOT2(a, b, c)                                           \
    ({                                                              \
        ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \
    })
#define ARM_DOT3(a, b, c)                                         \
    ({                                                            \
        ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \
    })
#define ARM_DOT4(a, b, c) \
    ({                    \
        ARM_DOT(a, b, c); \
    })

#else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

// Portable fallback: widen each lane product to uint and add it to c, one lane at a time
#define ARM_DOT1(a, b, c)          \
    ({                             \
        c += (uint)a.s0 * b.s0;    \
    })
#define ARM_DOT2(a, b, c)          \
    ({                             \
        ARM_DOT1(a, b, c);         \
        c += (uint)a.s1 * b.s1;    \
    })
#define ARM_DOT3(a, b, c)          \
    ({                             \
        ARM_DOT2(a, b, c);         \
        c += (uint)a.s2 * b.s2;    \
    })
#define ARM_DOT4(a, b, c)          \
    ({                             \
        ARM_DOT3(a, b, c);         \
        c += (uint)a.s3 * b.s3;    \
    })

#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

// 8- and 16-element dot products split the operands into .lo/.hi halves. This single definition is shared by both
// ARM_DOT4 implementations above (it was previously duplicated verbatim in each branch).
#define ARM_DOT8(a, b, c)            \
    ({                               \
        ARM_DOT4((a.lo), (b.lo), c); \
        ARM_DOT4((a.hi), (b.hi), c); \
    })
#define ARM_DOT16(a, b, c)           \
    ({                               \
        ARM_DOT8((a.lo), (b.lo), c); \
        ARM_DOT8((a.hi), (b.hi), c); \
    })
1998
1999#if N0 == 2
2000#define ARM_DOT_K0XN0(k0, a, b, c) \
2001 ({ \
2002 CONCAT(ARM_DOT, k0) \
2003 ((a), (b##0), (c.s0)); \
2004 CONCAT(ARM_DOT, k0) \
2005 ((a), (b##1), (c.s1)); \
2006 })
2007#elif N0 == 3 // N0 == 3
2008#define ARM_DOT_K0XN0(k0, a, b, c) \
2009 ({ \
2010 CONCAT(ARM_DOT, k0) \
2011 ((a), (b##0), (c.s0)); \
2012 CONCAT(ARM_DOT, k0) \
2013 ((a), (b##1), (c.s1)); \
2014 CONCAT(ARM_DOT, k0) \
2015 ((a), (b##2), (c.s2)); \
2016 })
2017#elif N0 == 4 // N0 == 4
2018#define ARM_DOT_K0XN0(k0, a, b, c) \
2019 ({ \
2020 CONCAT(ARM_DOT, k0) \
2021 ((a), (b##0), (c.s0)); \
2022 CONCAT(ARM_DOT, k0) \
2023 ((a), (b##1), (c.s1)); \
2024 CONCAT(ARM_DOT, k0) \
2025 ((a), (b##2), (c.s2)); \
2026 CONCAT(ARM_DOT, k0) \
2027 ((a), (b##3), (c.s3)); \
2028 })
2029#elif N0 == 8 // N0 == 8
2030#define ARM_DOT_K0XN0(k0, a, b, c) \
2031 ({ \
2032 CONCAT(ARM_DOT, k0) \
2033 ((a), (b##0), (c.s0)); \
2034 CONCAT(ARM_DOT, k0) \
2035 ((a), (b##1), (c.s1)); \
2036 CONCAT(ARM_DOT, k0) \
2037 ((a), (b##2), (c.s2)); \
2038 CONCAT(ARM_DOT, k0) \
2039 ((a), (b##3), (c.s3)); \
2040 CONCAT(ARM_DOT, k0) \
2041 ((a), (b##4), (c.s4)); \
2042 CONCAT(ARM_DOT, k0) \
2043 ((a), (b##5), (c.s5)); \
2044 CONCAT(ARM_DOT, k0) \
2045 ((a), (b##6), (c.s6)); \
2046 CONCAT(ARM_DOT, k0) \
2047 ((a), (b##7), (c.s7)); \
2048 })
2049#elif N0 == 16 // N0 == 16
2050#define ARM_DOT_K0XN0(k0, a, b, c) \
2051 ({ \
2052 CONCAT(ARM_DOT, k0) \
2053 ((a), (b##0), (c.s0)); \
2054 CONCAT(ARM_DOT, k0) \
2055 ((a), (b##1), (c.s1)); \
2056 CONCAT(ARM_DOT, k0) \
2057 ((a), (b##2), (c.s2)); \
2058 CONCAT(ARM_DOT, k0) \
2059 ((a), (b##3), (c.s3)); \
2060 CONCAT(ARM_DOT, k0) \
2061 ((a), (b##4), (c.s4)); \
2062 CONCAT(ARM_DOT, k0) \
2063 ((a), (b##5), (c.s5)); \
2064 CONCAT(ARM_DOT, k0) \
2065 ((a), (b##6), (c.s6)); \
2066 CONCAT(ARM_DOT, k0) \
2067 ((a), (b##7), (c.s7)); \
2068 CONCAT(ARM_DOT, k0) \
2069 ((a), (b##8), (c.s8)); \
2070 CONCAT(ARM_DOT, k0) \
2071 ((a), (b##9), (c.s9)); \
2072 CONCAT(ARM_DOT, k0) \
2073 ((a), (b##A), (c.sA)); \
2074 CONCAT(ARM_DOT, k0) \
2075 ((a), (b##B), (c.sB)); \
2076 CONCAT(ARM_DOT, k0) \
2077 ((a), (b##C), (c.sC)); \
2078 CONCAT(ARM_DOT, k0) \
2079 ((a), (b##D), (c.sD)); \
2080 CONCAT(ARM_DOT, k0) \
2081 ((a), (b##E), (c.sE)); \
2082 CONCAT(ARM_DOT, k0) \
2083 ((a), (b##F), (c.sF)); \
2084 })
2085#else // N0 not supported
2086#error "N0 value not supported"
2087#endif // N0 conditions
2088
2089/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2090 * The LHS matrix is NOT reshaped
2091 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
2092 *
2093 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
2094 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
2095 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
2096 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
2097 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
2098 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2099 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
2100 * - N0 = 2, 3, 4, 8, 16
2101 * - K0 = 2, 3, 4, 8, 16
2102 * - H0 >= 1
2103 *
2104 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2105 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2106 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2107 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2108 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2109 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
2110 *
2111 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2112 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2113 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2114 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2115 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2116 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2117 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2118 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2119 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2120 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2121 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2122 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2123 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2124 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2125 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2126 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2127 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2128 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2129 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2130 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2131 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2132 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
2133 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2134 */
2135__kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
2136 IMAGE_DECLARATION(rhs),
2137 IMAGE_DECLARATION(dst),
2138 uint lhs_stride_z,
2139 uint rhs_stride_z,
2140 uint dst_stride_z
2141#if defined(REINTERPRET_INPUT_AS_3D)
2142 ,
2143 uint lhs_cross_plane_pad
2144#endif // REINTERPRET_INPUT_AS_3D
2145#if defined(REINTERPRET_OUTPUT_AS_3D)
2146 ,
2147 uint dst_cross_plane_pad
2148#endif // REINTERPRET_OUTPUT_AS_3D
2149 )
2150{
2151 // Block size
2152#define RHS_BLOCK_SIZE ((K0) * (N0))
2153
2154 // RHS offset and step X
2155#if defined(RHS_INTERLEAVE)
2156#define RHS_OFFSET_X (K0)
2157#define RHS_STEP_X ((K0) * (H0))
2158#define RHS_STEP_LOOP (1)
2159#else // defined(RHS_INTERLEAVE)
2160#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2161#define RHS_STEP_X (K0)
2162#define RHS_STEP_LOOP (H0)
2163#endif // defined(RHS_INTERLEAVE)
2164
2165 uint x = get_global_id(0);
2166 uint y = get_global_id(1);
2167 uint z = get_global_id(2);
2168
Gian Marco Iodice86cfffe2019-04-02 11:02:20 +01002169#if defined(DUMMY_WORK_ITEMS)
2170 if((x * N0 >= N) || (y * M0 >= M))
2171 {
2172 return;
2173 }
2174#endif // defined(DUMMY_WORK_ITEMS)
2175
Gian Marco Iodice62251f72019-03-11 16:07:12 +00002176 // Compute LHS matrix address
2177 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
2178
2179 // Compute RHS matrix address
2180 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X + (x / (uint)H0) * rhs_stride_y;
2181
2182#if defined(MATRIX_B_DEPTH)
2183 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2184 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2185#else // defined(MATRIX_B_DEPTH)
2186 rhs_offset += z * rhs_stride_z;
2187#endif // defined(MATRIX_B_DEPTH)
2188
2189 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
2190
2191#if defined(REINTERPRET_INPUT_AS_3D)
2192 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2193 // in order to take into account the presence of possible cross plane paddings
2194 //
2195 // | |
2196 // | plane0 |
2197 // | |
2198 // |__________________|
2199 // |******************|
2200 // | cross_plane_pad |
2201 // |******************|
2202 // | |
2203 // | plane1 |
2204 // | |
2205 // |__________________|
2206
2207 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2208 zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2209 zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
2210 zin0 *= (lhs_cross_plane_pad * lhs_stride_y);
2211#if M0 > 1
2212 zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2213 zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
2214 zin1 *= (lhs_cross_plane_pad * lhs_stride_y);
2215#endif // M0 > 1
2216#if M0 > 2
2217 zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2218 zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
2219 zin2 *= (lhs_cross_plane_pad * lhs_stride_y);
2220#endif // M0 > 2
2221#if M0 > 3
2222 zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2223 zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
2224 zin3 *= (lhs_cross_plane_pad * lhs_stride_y);
2225#endif // M0 > 3
2226#if M0 > 4
2227 zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2228 zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
2229 zin4 *= (lhs_cross_plane_pad * lhs_stride_y);
2230#endif // M0 > 4
2231#if M0 > 5
2232 zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2233 zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
2234 zin5 *= (lhs_cross_plane_pad * lhs_stride_y);
2235#endif // M0 > 5
2236#if M0 > 6
2237 zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2238 zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
2239 zin6 *= (lhs_cross_plane_pad * lhs_stride_y);
2240#endif // M0 > 6
2241#if M0 > 7
2242 zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2243 zin7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
2244 zin7 *= (lhs_cross_plane_pad * lhs_stride_y);
2245#endif // M0 > 7
2246
2247 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2248 // multiply lhs_stride_z by DEPTH_GEMM3D
2249 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2250
2251#else // defined(REINTERPRET_INPUT_AS_3D)
2252
2253 // Add offset for batched GEMM
2254 lhs_offset += z * lhs_stride_z;
2255
2256#endif // defined(REINTERPRET_INPUT_AS_3D)
2257
2258 // Initialize the accumulators
2259 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
2260
2261 for(int i = 0; i < K; i += K0)
2262 {
2263 // Supported cases (M0, K0):
2264 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2265 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2266 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2267 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2268 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2269 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2270 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2271 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2272 // Load values from LHS matrix
2273 VEC_DATA_TYPE(uchar, K0)
2274 a0 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0);
2275#if M0 > 1
2276 VEC_DATA_TYPE(uchar, K0)
2277 a1 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1);
2278#endif // M0 > 1
2279#if M0 > 2
2280 VEC_DATA_TYPE(uchar, K0)
2281 a2 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2);
2282#endif // M0 > 2
2283#if M0 > 3
2284 VEC_DATA_TYPE(uchar, K0)
2285 a3 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3);
2286#endif // M0 > 3
2287#if M0 > 4
2288 VEC_DATA_TYPE(uchar, K0)
2289 a4 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4);
2290#endif // M0 > 4
2291#if M0 > 5
2292 VEC_DATA_TYPE(uchar, K0)
2293 a5 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5);
2294#endif // M0 > 5
2295#if M0 > 6
2296 VEC_DATA_TYPE(uchar, K0)
2297 a6 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6);
2298#endif // M0 > 6
2299#if M0 > 7
2300 VEC_DATA_TYPE(uchar, K0)
2301 a7 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7);
2302#endif // M0 > 7
2303
2304 // Load values from RHS matrix
2305 VEC_DATA_TYPE(uchar, K0)
2306 b0 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 0 * RHS_STEP_X);
2307 VEC_DATA_TYPE(uchar, K0)
2308 b1 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 1 * RHS_STEP_X);
2309#if N0 > 2
2310 VEC_DATA_TYPE(uchar, K0)
2311 b2 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 2 * RHS_STEP_X);
2312#endif // N0 > 2
2313#if N0 > 3
2314 VEC_DATA_TYPE(uchar, K0)
2315 b3 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 3 * RHS_STEP_X);
2316#endif // N0 > 3
2317#if N0 > 4
2318 VEC_DATA_TYPE(uchar, K0)
2319 b4 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 4 * RHS_STEP_X);
2320 VEC_DATA_TYPE(uchar, K0)
2321 b5 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 5 * RHS_STEP_X);
2322 VEC_DATA_TYPE(uchar, K0)
2323 b6 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 6 * RHS_STEP_X);
2324 VEC_DATA_TYPE(uchar, K0)
2325 b7 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 7 * RHS_STEP_X);
2326#endif // N0 > 4
2327#if N0 > 8
2328 VEC_DATA_TYPE(uchar, K0)
2329 b8 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 8 * RHS_STEP_X);
2330 VEC_DATA_TYPE(uchar, K0)
2331 b9 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 9 * RHS_STEP_X);
2332 VEC_DATA_TYPE(uchar, K0)
2333 bA = VLOAD(K0)(0, rhs_ptr + rhs_offset + 10 * RHS_STEP_X);
2334 VEC_DATA_TYPE(uchar, K0)
2335 bB = VLOAD(K0)(0, rhs_ptr + rhs_offset + 11 * RHS_STEP_X);
2336 VEC_DATA_TYPE(uchar, K0)
2337 bC = VLOAD(K0)(0, rhs_ptr + rhs_offset + 12 * RHS_STEP_X);
2338 VEC_DATA_TYPE(uchar, K0)
2339 bD = VLOAD(K0)(0, rhs_ptr + rhs_offset + 13 * RHS_STEP_X);
2340 VEC_DATA_TYPE(uchar, K0)
2341 bE = VLOAD(K0)(0, rhs_ptr + rhs_offset + 14 * RHS_STEP_X);
2342 VEC_DATA_TYPE(uchar, K0)
2343 bF = VLOAD(K0)(0, rhs_ptr + rhs_offset + 15 * RHS_STEP_X);
2344#endif // N0 > 8
2345
2346 // Accumulate
2347 ARM_DOT_K0XN0(K0, a0, b, c0);
2348#if M0 > 1
2349 ARM_DOT_K0XN0(K0, a1, b, c1);
2350#endif // M0 > 1
2351#if M0 > 2
2352 ARM_DOT_K0XN0(K0, a2, b, c2);
2353#endif // M0 > 2
2354#if M0 > 3
2355 ARM_DOT_K0XN0(K0, a3, b, c3);
2356#endif // M0 > 3
2357#if M0 > 4
2358 ARM_DOT_K0XN0(K0, a4, b, c4);
2359#endif // M0 > 4
2360#if M0 > 5
2361 ARM_DOT_K0XN0(K0, a5, b, c5);
2362#endif // M0 > 5
2363#if M0 > 6
2364 ARM_DOT_K0XN0(K0, a6, b, c6);
2365#endif // M0 > 6
2366#if M0 > 7
2367 ARM_DOT_K0XN0(K0, a7, b, c7);
2368#endif // M0 > 7
2369
2370 lhs_offset += K0;
2371 rhs_offset += N0 * RHS_STEP_X * RHS_STEP_LOOP;
2372 }
2373
2374 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0) * sizeof(int) + (y * (uint)M0 * dst_stride_y);
2375
2376 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
2377
2378#if defined(REINTERPRET_OUTPUT_AS_3D)
2379 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2380 // in order to take into account the presence of possible cross plane paddings
2381 //
2382 // | |
2383 // | plane0 |
2384 // | |
2385 // |__________________|
2386 // |******************|
2387 // | cross_plane_pad |
2388 // |******************|
2389 // | |
2390 // | plane1 |
2391 // | |
2392 // |__________________|
2393
2394 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2395 zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2396 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
2397 zout0 *= (dst_cross_plane_pad * dst_stride_y);
2398#if M0 > 1
2399 zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2400 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
2401 zout1 *= (dst_cross_plane_pad * dst_stride_y);
2402#endif // M0 > 1
2403#if M0 > 2
2404 zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2405 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
2406 zout2 *= (dst_cross_plane_pad * dst_stride_y);
2407#endif // M0 > 2
2408#if M0 > 3
2409 zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2410 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
2411 zout3 *= (dst_cross_plane_pad * dst_stride_y);
2412#endif // M0 > 3
2413#if M0 > 4
2414 zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2415 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
2416 zout4 *= (dst_cross_plane_pad * dst_stride_y);
2417#endif // M0 > 4
2418#if M0 > 5
2419 zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2420 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
2421 zout5 *= (dst_cross_plane_pad * dst_stride_y);
2422#endif // M0 > 5
2423#if M0 > 6
2424 zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2425 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
2426 zout6 *= (dst_cross_plane_pad * dst_stride_y);
2427#endif // M0 > 6
2428#if M0 > 7
2429 zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2430 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
2431 zout7 *= (dst_cross_plane_pad * dst_stride_y);
2432#endif // M0 > 7
2433
2434 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2435 // multiply dst_stride_z by DEPTH_GEMM3D
2436 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2437
2438#else // defined(REINTERPRET_OUTPUT_AS_3D)
2439
2440 // Add offset for batched GEMM
2441 dst_addr += z * dst_stride_z;
2442
2443#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2444
2445 // Store output block
2446 VSTORE(N0)
2447 (CONVERT_SAT(c0, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout0));
2448#if M0 > 1
2449 VSTORE(N0)
2450 (CONVERT_SAT(c1, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout1));
2451#endif // M0 > 1
2452#if M0 > 2
2453 VSTORE(N0)
2454 (CONVERT_SAT(c2, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout2));
2455#endif // M0 > 2
2456#if M0 > 3
2457 VSTORE(N0)
2458 (CONVERT_SAT(c3, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout3));
2459#endif // M0 > 3
2460#if M0 > 4
2461 VSTORE(N0)
2462 (CONVERT_SAT(c4, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 4 * dst_stride_y + zout4));
2463#endif // M0 > 4
2464#if M0 > 5
2465 VSTORE(N0)
2466 (CONVERT_SAT(c5, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 5 * dst_stride_y + zout5));
2467#endif // M0 > 5
2468#if M0 > 6
2469 VSTORE(N0)
2470 (CONVERT_SAT(c6, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 6 * dst_stride_y + zout6));
2471#endif // M0 > 6
2472#if M0 > 7
2473 VSTORE(N0)
2474 (CONVERT_SAT(c7, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 7 * dst_stride_y + zout7));
2475#endif // M0 > 7
2476
2477#undef RHS_BLOCK_SIZE
2478#undef RHS_OFFSET_X
2479#undef RHS_STEP_X
2480}
2481#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(K)
2482
Gian Marco05288a22017-11-21 10:57:50 +00002483#if defined(COLS_A)
2484/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A.
2485 *
2486 * @note This stage is needed to handle the offset of matrix product
2487 * https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
2488 *
2489 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
2490 *
2491 * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8
2492 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
2493 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2494 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
2495 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2496 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2497 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
2498 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
2499 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: S32
2500 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2501 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2502 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2503 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2504 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2505 */
2506__kernel void gemmlowp_matrix_a_reduction(TENSOR3D_DECLARATION(src),
2507 IMAGE_DECLARATION(dst))
2508{
2509 // Compute source and destination addresses
2510 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
2511 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2512
2513 uint4 sum_row_u32 = (uint4)0;
2514 uint sum_row = 0;
2515
2516 __global const uchar *matrix_a = (__global const uchar *)(src.ptr + get_global_id(0) * src_stride_y + get_global_id(1) * src_stride_z);
2517
2518 int i = 0;
2519
2520 // This for loop performs 16 accumulations
2521 for(; i <= ((int)COLS_A - 16); i += 16)
2522 {
2523 const uchar16 a0_u8 = vload16(0, matrix_a + i);
2524
2525 sum_row_u32 += convert_uint4(a0_u8.s0123) + convert_uint4(a0_u8.s4567) + convert_uint4(a0_u8.s89AB) + convert_uint4(a0_u8.sCDEF);
2526 }
2527
2528 // This for loop performs the leftover accumulations
2529 for(; i < COLS_A; ++i)
2530 {
2531 sum_row += matrix_a[i];
2532 }
2533
2534 sum_row += sum_row_u32.s0 + sum_row_u32.s1 + sum_row_u32.s2 + sum_row_u32.s3;
2535
2536 *((__global int *)dst.ptr) = (int)sum_row;
2537}
Gian Marco Iodice4b908652018-10-18 10:21:02 +01002538
2539#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
2540/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A using the arm dot product instruction
2541 *
2542 * @note This stage is needed to handle the offset of matrix product
2543 * https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
2544 *
2545 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
2546 *
2547 * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8
2548 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
2549 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2550 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
2551 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2552 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2553 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
2554 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
2555 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: S32
2556 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2557 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2558 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2559 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2560 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2561 */
2562__kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
2563 IMAGE_DECLARATION(dst))
2564{
2565 // Compute source and destination addresses
2566 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
2567 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2568
2569 uint sum_row = 0;
2570
2571 __global const uchar *matrix_a = (__global const uchar *)(src.ptr + get_global_id(0) * src_stride_y + get_global_id(1) * src_stride_z);
2572
2573 int i = 0;
2574
2575 // This for loop performs 16 accumulations
2576 for(; i <= ((int)COLS_A - 32); i += 32)
2577 {
2578 uchar16 a0_u8 = vload16(0, matrix_a + i);
2579
2580 sum_row += arm_dot(a0_u8.s0123, (uchar4)(1));
2581 sum_row += arm_dot(a0_u8.s4567, (uchar4)(1));
2582 sum_row += arm_dot(a0_u8.s89AB, (uchar4)(1));
2583 sum_row += arm_dot(a0_u8.sCDEF, (uchar4)(1));
2584
2585 a0_u8 = vload16(1, matrix_a + i);
2586
2587 sum_row += arm_dot(a0_u8.s0123, (uchar4)(1));
2588 sum_row += arm_dot(a0_u8.s4567, (uchar4)(1));
2589 sum_row += arm_dot(a0_u8.s89AB, (uchar4)(1));
2590 sum_row += arm_dot(a0_u8.sCDEF, (uchar4)(1));
2591 }
2592
2593 // This for loop performs the leftover accumulations
2594 for(; i < COLS_A; ++i)
2595 {
2596 sum_row += matrix_a[i];
2597 }
2598
2599 *((__global int *)dst.ptr) = (int)sum_row;
2600}
2601#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
Gian Marco05288a22017-11-21 10:57:50 +00002602#endif // defined(COLS_A)
2603
2604#if defined(COLS_B) && defined(ROWS_B)
/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each column of Matrix B.
 *
 * @note This stage is needed to handle the offset of matrix product
 *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
 *
 * @attention The number of matrix B columns and rows needs to be passed at compile time using -DCOLS_B and -DROWS_B
 *
 * Each work-item reads 16 consecutive uchar values from each of the ROWS_B rows (row pitch src_stride_y),
 * starting at src.ptr + get_global_id(1) * src_stride_z, and writes the 16 resulting column sums as int32.
 * The sums are accumulated in 32-bit unsigned integers.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data type: S32
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_matrix_b_reduction(TENSOR3D_DECLARATION(src),
                                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Image    dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // 16 column sums, accumulated in 32 bits so that summing many uchar rows does not wrap
    uint16 sum_col_u32 = (uint16)0;

    __global const uchar *matrix_b = (__global const uchar *)(src.ptr + get_global_id(1) * src_stride_z);

    int i = 0;
    // This for loop performs 4 accumulations
    for(; i <= ((int)ROWS_B - 4); i += 4)
    {
        const uchar16 b0_u8 = vload16(0, matrix_b + 0 * src_stride_y);
        const uchar16 b1_u8 = vload16(0, matrix_b + 1 * src_stride_y);
        const uchar16 b2_u8 = vload16(0, matrix_b + 2 * src_stride_y);
        const uchar16 b3_u8 = vload16(0, matrix_b + 3 * src_stride_y);

        sum_col_u32 += convert_uint16(b0_u8) + convert_uint16(b1_u8) + convert_uint16(b2_u8) + convert_uint16(b3_u8);

        matrix_b += 4 * src_stride_y;
    }

    // This for loop performs the leftover accumulations
    for(; i < (int)ROWS_B; ++i)
    {
        const uchar16 b0_u8 = vload16(0, matrix_b);

        sum_col_u32 += convert_uint16(b0_u8);

        matrix_b += src_stride_y;
    }

    vstore16(convert_int16(sum_col_u32), 0, (__global int *)dst.ptr);
}
2664#endif // defined(COLS_B) && defined(ROWS_B)
2665
2666#if defined(K_OFFSET)
Gian Marco Iodice4b908652018-10-18 10:21:02 +01002667
/* Helper function used to calculate the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel.
 *
 * This function does not read the int32 accumulator (the output of @ref CLGEMMLowpMatrixMultiplyKernel) itself.
 * It returns the term that the caller adds to the 4 consecutive accumulator values starting at column x:
 *
 *  offset_term[k] = K_OFFSET +
 *                   (sum_col[k] * A_OFFSET) +   (only if -DA_OFFSET is passed)
 *                   (sum_row[y] * B_OFFSET) +   (only if -DB_OFFSET is passed)
 *                   biases[k]                   (only if -DADD_BIAS is passed)
 *
 *  for k in [x, x + 3]. The sum_row value is a single int32 broadcast to all 4 lanes.
 *
 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (e.g. -DK_OFFSET=1200)
 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (e.g. -DA_OFFSET=1)
 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (e.g. -DB_OFFSET=6)
 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note If -DDEPTH_INPUT3D is passed, the batch index used to address sum_col/sum_row is z / DEPTH_INPUT3D (otherwise it is z).
 *       If -DHEIGHT_INPUT3D is passed as well, the sum_row element read is additionally offset by
 *       (z % DEPTH_INPUT3D) * HEIGHT_INPUT3D int32 elements (presumably to support a GEMM output reinterpreted as 3D - TODO confirm against host code)
 *
 * @param[in] x                                     get_global_id(0) * 4
 * @param[in] y                                     get_global_id(1)
 * @param[in] z                                     get_global_id(2)
 * @param[in] sum_col_ptr                           (Optional) Pointer to the sum_col tensor. Supported data type: S32
 * @param[in] sum_col_stride_x                      (Optional) Stride of the sum_col tensor in X dimension (in bytes)
 * @param[in] sum_col_step_x                        (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] sum_col_stride_y                      (Optional) Stride of the sum_col tensor in Y dimension (in bytes)
 * @param[in] sum_col_step_y                        (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the sum_col tensor
 * @param[in] sum_row_ptr                           (Optional) Pointer to the sum_row tensor. Supported data type: S32
 * @param[in] sum_row_stride_x                      (Optional) Stride of the sum_row tensor in X dimension (in bytes)
 * @param[in] sum_row_step_x                        (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] sum_row_stride_y                      (Optional) Stride of the sum_row tensor in Y dimension (in bytes)
 * @param[in] sum_row_step_y                        (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the sum_row tensor
 * @param[in] biases_ptr                            (Optional) Pointer to the biases tensor. Supported data type: S32
 * @param[in] biases_stride_x                       (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in] biases_step_x                         (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] biases_offset_first_element_in_bytes  (Optional) The offset of the first element in the biases tensor
 *
 * @return The int32 offset contribution for the 4 output elements starting at column x
 */
inline int4 offset_contribution(
    int x,
    int y,
    int z
#if defined(A_OFFSET)
    ,
    IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
    ,
    IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
    ,
    VECTOR_DECLARATION(biases)
#endif // defined(ADD_BIAS)
)
{
    int4 a_offset_s32 = (int4)0;
    int4 b_offset_s32 = (int4)0;

    // With a 3D input, DEPTH_INPUT3D consecutive z slices share the same batch of sum_col/sum_row
    int batch_id = z;
#if defined(DEPTH_INPUT3D)
    batch_id /= (int)DEPTH_INPUT3D;
#endif // defined(DEPTH_INPUT3D)

#if defined(A_OFFSET)
    // Compute the offset contribution due to A_OFFSET: one sum_col value per output column
    __global uchar *sum_col_addr = sum_col_ptr + sum_col_offset_first_element_in_bytes + x * sizeof(int);

#if defined(SUM_COL_HAS_BATCHES)
    a_offset_s32 = vload4(0, (__global int *)(sum_col_addr + batch_id * sum_col_stride_y));
#else  // defined(SUM_COL_HAS_BATCHES)
    a_offset_s32 = vload4(0, (__global int *)sum_col_addr);
#endif // defined(SUM_COL_HAS_BATCHES)

    a_offset_s32 *= (int4)A_OFFSET;
#endif // defined(A_OFFSET)

#if defined(B_OFFSET)
    // Compute the offset contribution due to B_OFFSET: a single sum_row value broadcast to all 4 output columns
    __global uchar *sum_row_addr = sum_row_ptr + sum_row_offset_first_element_in_bytes + y * sizeof(int);

#if defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 = (int4) * (((__global int *)(sum_row_addr + batch_id * sum_row_stride_y)) + (z % (int)DEPTH_INPUT3D) * (int)HEIGHT_INPUT3D);
#else  // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 = (int4) * (((__global int *)(sum_row_addr + batch_id * sum_row_stride_y)));
#endif // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D)
    b_offset_s32 *= (int4)B_OFFSET;
#endif // defined(B_OFFSET)

#if defined(ADD_BIAS)
    // Add bias: one int32 bias per output column, accumulated into the B term of the returned sum
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);

    int4 biases_values = vload4(0, (__global int *)bias_addr);
    b_offset_s32 += biases_values;
#endif // defined(ADD_BIAS)

    return (int4)K_OFFSET + a_offset_s32 + b_offset_s32;
}
2761
/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel. The computation is performed in-place
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel),
 * and adds to it the offset contribution of matrix A and matrix B in-place.
 * Each work-item processes 4 consecutive int32 values along X.
 *
 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (e.g. -DK_OFFSET=1200)
 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (e.g. -DA_OFFSET=1)
 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (e.g. -DB_OFFSET=6)
 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 *
 * The final result is:
 *
 *  mm_result[i][k] = mm_result[i][k] +
 *                    (sum_col[k] * A_OFFSET) +
 *                    (sum_row[i] * B_OFFSET) +
 *                    (K_OFFSET) +
 *                    (biases[k])                 (only if -DADD_BIAS is passed)
 *
 * @param[in,out] mm_result_ptr                           Pointer to the GEMM result tensor, updated in-place. Supported data type: S32
 * @param[in]     mm_result_stride_x                      Stride of the GEMM result tensor in X dimension (in bytes)
 * @param[in]     mm_result_step_x                        mm_result_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     mm_result_stride_y                      Stride of the GEMM result tensor in Y dimension (in bytes)
 * @param[in]     mm_result_step_y                        mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     mm_result_stride_z                      Stride of the GEMM result tensor in Z dimension (in bytes)
 * @param[in]     mm_result_step_z                        mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     mm_result_offset_first_element_in_bytes The offset of the first element in the GEMM result tensor
 * @param[in]     sum_col_ptr                             (Optional) Pointer to the sum_col tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]     sum_col_stride_x                        (Optional) Stride of the sum_col tensor in X dimension (in bytes)
 * @param[in]     sum_col_step_x                          (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     sum_col_stride_y                        (Optional) Stride of the sum_col tensor in Y dimension (in bytes)
 * @param[in]     sum_col_step_y                          (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     sum_col_offset_first_element_in_bytes   (Optional) The offset of the first element in the sum_col tensor
 * @param[in]     sum_row_ptr                             (Optional) Pointer to the sum_row tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]     sum_row_stride_x                        (Optional) Stride of the sum_row tensor in X dimension (in bytes)
 * @param[in]     sum_row_step_x                          (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     sum_row_stride_y                        (Optional) Stride of the sum_row tensor in Y dimension (in bytes)
 * @param[in]     sum_row_step_y                          (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     sum_row_offset_first_element_in_bytes   (Optional) The offset of the first element in the sum_row tensor
 * @param[in]     biases_ptr                              (Optional) Pointer to the biases tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]     biases_stride_x                         (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]     biases_step_x                           (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     biases_offset_first_element_in_bytes    (Optional) The offset of the first element in the biases tensor
 */
__kernel void gemmlowp_offset_contribution(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                           ,
                                           IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                           ,
                                           IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                                           ,
                                           VECTOR_DECLARATION(biases)
#endif // defined(ADD_BIAS)
                                          )
{
    // Each work-item handles 4 consecutive int32 accumulators along X
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // Compute offset contribution
    int4 offset_term_s32 = offset_contribution(
                               x, y, z
#if defined(A_OFFSET)
                               ,
                               sum_col_ptr,
                               sum_col_stride_x,
                               sum_col_step_x,
                               sum_col_stride_y,
                               sum_col_step_y,
                               sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                               ,
                               sum_row_ptr,
                               sum_row_stride_x,
                               sum_row_step_x,
                               sum_row_stride_y,
                               sum_row_step_y,
                               sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                               ,
                               biases_ptr,
                               biases_stride_x,
                               biases_step_x,
                               biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                           );

    __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    int4 in_s32 = vload4(0, (__global int *)mm_result_addr);

    // Add the offset terms to GEMM's result
    in_s32 += offset_term_s32;

    // Store the result with the offset contribution (in-place)
    vstore4(in_s32, 0, (__global int *)mm_result_addr);
}
Gian Marco Iodice4b908652018-10-18 10:21:02 +01002863
2864#if defined(RESULT_OFFSET) && defined(RESULT_MULTIPLIER) && defined(RESULT_SHIFT)
/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel and it quantizes down to uint8.
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), adds to it the offset contribution of matrix A and matrix B and quantizes to uint8 through the output stage.
 * Each work-item processes 4 consecutive values along X.
 *
 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (e.g. -DK_OFFSET=1200)
 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (e.g. -DA_OFFSET=1)
 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (e.g. -DB_OFFSET=6)
 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches
 *
 * The result before the output stage is:
 *
 *  mm_result[i][k] = mm_result[i][k] +
 *                    (sum_col[k] * A_OFFSET) +
 *                    (sum_row[i] * B_OFFSET) +
 *                    (K_OFFSET) +
 *                    (biases[k])                 (only if -DADD_BIAS is passed)
 *
 * This result is quantized down to uint8 using the output stage. The output stage computes the following operations, in this order:
 *
 *  -# Add RESULT_OFFSET to each entry
 *  -# Multiply each entry by RESULT_MULTIPLIER
 *  -# Shift each entry right by RESULT_SHIFT
 *  -# Saturate the resulting int32 values to the [0..255] range and cast to QASYMM8
 *  -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
 *
 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -DRESULT_MULTIPLIER and -DRESULT_SHIFT
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  mm_result_ptr                           Pointer to the source tensor. Supported data type: S32
 * @param[in]  mm_result_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  mm_result_step_x                        mm_result_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  mm_result_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  mm_result_step_y                        mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  mm_result_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  mm_result_step_z                        mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  sum_col_ptr                             (Optional) Pointer to the sum_col tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  sum_col_stride_x                        (Optional) Stride of the sum_col tensor in X dimension (in bytes)
 * @param[in]  sum_col_step_x                          (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_col_stride_y                        (Optional) Stride of the sum_col tensor in Y dimension (in bytes)
 * @param[in]  sum_col_step_y                          (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_col_offset_first_element_in_bytes   (Optional) The offset of the first element in the sum_col tensor
 * @param[in]  sum_row_ptr                             (Optional) Pointer to the sum_row tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  sum_row_stride_x                        (Optional) Stride of the sum_row tensor in X dimension (in bytes)
 * @param[in]  sum_row_step_x                          (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_row_stride_y                        (Optional) Stride of the sum_row tensor in Y dimension (in bytes)
 * @param[in]  sum_row_step_y                          (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_row_offset_first_element_in_bytes   (Optional) The offset of the first element in the sum_row tensor
 * @param[in]  biases_ptr                              (Optional) Pointer to the biases tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  biases_stride_x                         (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                           (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes    (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                                 Pointer to the destination tensor. Supported data type: QASYMM8
 * @param[in]  dst_stride_x                            Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                              dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                            Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                              dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                            Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                              dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes       The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_offset_contribution_quantize_down(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                                         ,
                                                         IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                                         ,
                                                         IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
                                                         ,
#if defined(ADD_BIAS)
                                                         VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                         TENSOR3D_DECLARATION(dst))
{
    // Each work-item handles 4 consecutive values along X
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // The destination is uint8, so x is already a byte offset
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    // Compute offset contribution
    int4 offset_term_s32 = offset_contribution(
                               x, y, z
#if defined(A_OFFSET)
                               ,
                               sum_col_ptr,
                               sum_col_stride_x,
                               sum_col_step_x,
                               sum_col_stride_y,
                               sum_col_step_y,
                               sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                               ,
                               sum_row_ptr,
                               sum_row_stride_x,
                               sum_row_step_x,
                               sum_row_stride_y,
                               sum_row_step_y,
                               sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                               ,
                               biases_ptr,
                               biases_stride_x,
                               biases_step_x,
                               biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                           );

    __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    int4 in_s32 = vload4(0, (__global int *)mm_result_addr);

    // Add the offset terms (and bias, if enabled) to GEMM's result
    in_s32 += offset_term_s32;

    // -------------- OUTPUT STAGE

    // Add the output offset
    in_s32 += (int4)RESULT_OFFSET;

    // Multiply by the integer multiplier and shift
    in_s32 *= RESULT_MULTIPLIER;

    in_s32 >>= RESULT_SHIFT;

    // Saturate to [0..255]
    uchar4 res = convert_uchar4_sat(in_s32);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
3010
/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel and it quantizes down to uint8.
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), adds to it the offset contribution of matrix A and matrix B and quantizes to uint8 through the output stage.
 * Each work-item processes 4 consecutive values along X.
 *
 * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (e.g. -DK_OFFSET=1200)
 * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (e.g. -DA_OFFSET=1)
 * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (e.g. -DB_OFFSET=6)
 * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches
 *
 * The result before the output stage is:
 *
 *  mm_result[i][k] = mm_result[i][k] +
 *                    (sum_col[k] * A_OFFSET) +
 *                    (sum_row[i] * B_OFFSET) +
 *                    (K_OFFSET) +
 *                    (biases[k])                 (only if -DADD_BIAS is passed)
 *
 * This result is quantized down to uint8 using the output stage. The output stage computes the following operations, in this order:
 *
 *  -# Compute fixed point multiplication between each entry of input by RESULT_MULTIPLIER
 *  -# Round to nearest division by a power-of-two using RESULT_SHIFT
 *  -# Add RESULT_OFFSET to each result
 *  -# Saturate the resulting int32 values to the [0..255] range and cast to QASYMM8
 *  -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
 *
 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -DRESULT_MULTIPLIER and -DRESULT_SHIFT
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  mm_result_ptr                           Pointer to the source tensor. Supported data type: S32
 * @param[in]  mm_result_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  mm_result_step_x                        mm_result_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  mm_result_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  mm_result_step_y                        mm_result_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  mm_result_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  mm_result_step_z                        mm_result_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  sum_col_ptr                             (Optional) Pointer to the sum_col tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  sum_col_stride_x                        (Optional) Stride of the sum_col tensor in X dimension (in bytes)
 * @param[in]  sum_col_step_x                          (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_col_stride_y                        (Optional) Stride of the sum_col tensor in Y dimension (in bytes)
 * @param[in]  sum_col_step_y                          (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_col_offset_first_element_in_bytes   (Optional) The offset of the first element in the sum_col tensor
 * @param[in]  sum_row_ptr                             (Optional) Pointer to the sum_row tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  sum_row_stride_x                        (Optional) Stride of the sum_row tensor in X dimension (in bytes)
 * @param[in]  sum_row_step_x                          (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_row_stride_y                        (Optional) Stride of the sum_row tensor in Y dimension (in bytes)
 * @param[in]  sum_row_step_y                          (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_row_offset_first_element_in_bytes   (Optional) The offset of the first element in the sum_row tensor
 * @param[in]  biases_ptr                              (Optional) Pointer to the biases tensor. Supported data type: same as @p mm_result_ptr
 * @param[in]  biases_stride_x                         (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                           (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes    (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                                 Pointer to the destination tensor. Supported data type: QASYMM8
 * @param[in]  dst_stride_x                            Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                              dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                            Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                              dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                            Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                              dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes       The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_offset_contribution_quantize_down_fixedpoint(TENSOR3D_DECLARATION(mm_result)
#if defined(A_OFFSET)
                                                                    ,
                                                                    IMAGE_DECLARATION(sum_col)
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                                                                    ,
                                                                    IMAGE_DECLARATION(sum_row)
#endif // defined(B_OFFSET)
                                                                    ,
#if defined(ADD_BIAS)
                                                                    VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                                    TENSOR3D_DECLARATION(dst))
{
    // Each work-item handles 4 consecutive values along X
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // Compute offset contribution
    int4 offset_term_s32 = offset_contribution(
                               x, y, z
#if defined(A_OFFSET)
                               ,
                               sum_col_ptr,
                               sum_col_stride_x,
                               sum_col_step_x,
                               sum_col_stride_y,
                               sum_col_step_y,
                               sum_col_offset_first_element_in_bytes
#endif // defined(A_OFFSET)
#if defined(B_OFFSET)
                               ,
                               sum_row_ptr,
                               sum_row_stride_x,
                               sum_row_step_x,
                               sum_row_stride_y,
                               sum_row_step_y,
                               sum_row_offset_first_element_in_bytes
#endif // defined(B_OFFSET)
#if defined(ADD_BIAS)
                               ,
                               biases_ptr,
                               biases_stride_x,
                               biases_step_x,
                               biases_offset_first_element_in_bytes
#endif // defined(ADD_BIAS)
                           );

    __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z;

    // The destination is uint8, so x is already a byte offset
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    int4 in_s32 = vload4(0, (__global int *)mm_result_addr);

    // Add the offset terms (and bias, if enabled) to GEMM's result
    in_s32 += offset_term_s32;

    // -------------- OUTPUT STAGE

    // Fixed point multiplication by RESULT_MULTIPLIER followed by the rounding shift by RESULT_SHIFT
    in_s32 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(in_s32, RESULT_MULTIPLIER, RESULT_SHIFT, 4);

    // Add the output offset
    in_s32 += (int4)RESULT_OFFSET;

    // Saturate to [0..255]
    uchar4 res = convert_uchar4_sat(in_s32);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
3153}
3154#endif // defined(K_OFFSET) && defined(RESULT_OFFSET) && defined(RESULT_MULTIPLIER) && defined(RESULT_SHIFT)
Gian Marco05288a22017-11-21 10:57:50 +00003155#endif // defined(K_OFFSET)
3156
3157#if defined(RESULT_OFFSET) && defined(RESULT_MULT_INT) && defined(RESULT_SHIFT)
3158/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
3159 *
3160 * This kernel takes a final int32 accumulator value and processes it to obtain the final QASYMM8 value.
3161 * The following computations will be performed by the kernel:
3162 *
3163 * -# Add offset terms to final result
3164 * -# Multiply each entry of result by result_mult_int
3165 * -# Add bias to final result (if -DADD_BIAS is passed at compile time)
3166 * -# Shift the int32 accumulator by result_shift
3167 * -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
3168 * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8.
3169 *
3170 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -RESULT_MULT_INT and -DRESULT_SHIFT
3171 *
3172 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
3173 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
3174 * These values can be used to implement "rectified linear unit" activation functions
3175 *
3176 * @param[in] src_ptr Pointer to the source tensor. Supported data type: S32
3177 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
3178 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3179 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
3180 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3181 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
3182 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3183 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003184 * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
3185 * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes)
3186 * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
3187 * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
Gian Marco05288a22017-11-21 10:57:50 +00003188 * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8
3189 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3190 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3191 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3192 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3193 * @param[in] dst_stride_z Stride of the source tensor in Z dimension (in bytes)
3194 * @param[in] dst_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3195 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
3196 */
3197__kernel void gemmlowp_output_stage_quantize_down(TENSOR3D_DECLARATION(src),
3198#if defined(ADD_BIAS)
3199 VECTOR_DECLARATION(biases),
3200#endif // defined(ADD_BIAS)
3201 TENSOR3D_DECLARATION(dst))
3202{
3203 // Compute source and destination addresses
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003204 int x = get_global_id(0) * 4;
3205 int y = get_global_id(1);
3206 int z = get_global_id(2);
Gian Marco05288a22017-11-21 10:57:50 +00003207
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003208 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z;
Gian Marco05288a22017-11-21 10:57:50 +00003209
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003210 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;
3211
3212 int4 input_values = vload4(0, (__global int *)src_addr);
Gian Marco58c57942017-11-28 09:10:03 +00003213
Gian Marco05288a22017-11-21 10:57:50 +00003214#if defined(ADD_BIAS)
3215 // Add bias
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003216 __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);
3217
3218 int4 biases_values = vload4(0, (__global int *)bias_addr);
3219 input_values += (int4)biases_values;
Gian Marco05288a22017-11-21 10:57:50 +00003220#endif // defined(ADD_BIAS)
3221
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003222 // Add the offset terms to GEMM's result
3223 input_values += (int4)RESULT_OFFSET;
3224
Georgios Pinitas45bcc3a2017-11-29 11:06:49 +00003225 // Multiply by result_mult_int and shift
Gian Marco58c57942017-11-28 09:10:03 +00003226 input_values *= RESULT_MULT_INT;
Gian Marco05288a22017-11-21 10:57:50 +00003227
Gian Marco58c57942017-11-28 09:10:03 +00003228 input_values >>= RESULT_SHIFT;
Gian Marco05288a22017-11-21 10:57:50 +00003229
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003230 uchar4 res = convert_uchar4_sat(input_values);
Gian Marco05288a22017-11-21 10:57:50 +00003231
3232#if defined(MIN_BOUND)
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003233 res = max(res, (uchar4)MIN_BOUND);
Gian Marco05288a22017-11-21 10:57:50 +00003234#endif // defined(MIN_BOUND)
3235#if defined(MAX_BOUND)
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003236 res = min(res, (uchar4)MAX_BOUND);
Gian Marco05288a22017-11-21 10:57:50 +00003237#endif // defined(MAX_BOUND)
3238
3239 // Store the result
Gian Marco Iodice4b908652018-10-18 10:21:02 +01003240 vstore4(res, 0, dst_addr);
Gian Marco05288a22017-11-21 10:57:50 +00003241}
Gian Marco58c57942017-11-28 09:10:03 +00003242#endif // defined(RESULT_OFFSET) && defined(RESULT_MULT_INT) && defined(RESULT_SHIFT)
3243
#if defined(RESULT_OFFSET_AFTER_SHIFT) && defined(RESULT_FIXEDPOINT_MULTIPLIER) && defined(RESULT_SHIFT)
/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), and processes it to obtain the final QASYMM8 value.
 * Each work-item processes 4 consecutive elements along the X dimension.
 * The following computations will be performed by the kernel, in this order:
 *
 * -# Add bias to final result (if -DADD_BIAS is passed at compile time)
 * -# Compute fixed point multiplication between each entry of input by result_fixedpoint_multiplier
 * -# Round to nearest division by a power-of-two using result_shift
 * -# Add offset to each result
 * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8
 * -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
 *
 * The multiplication and the rounding shift are both performed by ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE (helpers_asymm.h).
 *
 * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET_AFTER_SHIFT, -DRESULT_FIXEDPOINT_MULTIPLIER and -DRESULT_SHIFT
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 *
 * @param[in]  src_ptr                              Pointer to the source tensor. Supported data type: S32
 * @param[in]  src_stride_x                         Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                           src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                         Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                           src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                         Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                           src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes    The offset of the first element in the source tensor
 * @param[in]  biases_ptr                           (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
 * @param[in]  biases_stride_x                      (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                        (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                              Pointer to the destination tensor Supported data type: QASYMM8
 * @param[in]  dst_stride_x                         Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                           dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                         Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                           dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                         Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                           dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes    The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_output_stage_quantize_down_fixedpoint(TENSOR3D_DECLARATION(src),
#if defined(ADD_BIAS)
                                                             VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
                                                             TENSOR3D_DECLARATION(dst))
{
    // Each work-item handles 4 consecutive elements along X
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // The source stores int32 values (x scaled by sizeof(int)); the destination stores one byte per element
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z;
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    int4 input_values = vload4(0, (__global int *)src_addr);

#if defined(ADD_BIAS)
    // Add bias: the bias vector is indexed only by X, so it is shared by every row/plane
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);

    int4 biases_values = vload4(0, (__global int *)bias_addr);
    input_values += biases_values;
#endif // defined(ADD_BIAS)

    // Fixed-point multiply by result_fixedpoint_multiplier, then rounding shift right by result_shift
    input_values = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(input_values, RESULT_FIXEDPOINT_MULTIPLIER, RESULT_SHIFT, 4);

    // Add the output offset; it is applied after the shift, unlike in gemmlowp_output_stage_quantize_down
    input_values += (int4)RESULT_OFFSET_AFTER_SHIFT;

    // Saturate to [0..255] before applying the optional activation bounds
    uchar4 res = convert_uchar4_sat(input_values);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
#endif // defined(RESULT_OFFSET_AFTER_SHIFT) && defined(RESULT_FIXEDPOINT_MULTIPLIER) && defined(RESULT_SHIFT)
Georgios Pinitas51e53a32018-10-22 13:49:08 +01003328
#if defined(REAL_MULTIPLIER) && defined(OUTPUT_OFFSET)
/** This OpenCL kernel is used to quantize down the int32 accumulator values of GEMMLowp to QASYMM8
 *
 * This kernel takes a final int32 accumulator value (the output of @ref CLGEMMLowpMatrixMultiplyKernel), and processes it to obtain the final QASYMM8 value.
 * Each work-item processes 4 consecutive elements along the X dimension.
 * The following computations will be performed by the kernel, in this order:
 *
 * -# Add bias to final result (if -DADD_BIAS is passed at compile time)
 * -# Convert the int32 values to float and multiply each entry by real_multiplier
 * -# Add output_offset to each result and round to the nearest integer
 * -# Clamp the resulting values to the [0..255] range and cast to QASYMM8
 * -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time)
 *
 * @attention The offset and scalar scale factor must be passed at compile time using -DOUTPUT_OFFSET and -DREAL_MULTIPLIER
 *
 * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time
 * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND.
 *       These values can be used to implement "rectified linear unit" activation functions
 * @note If -DDST_HEIGHT is passed at compile time, the destination is declared as a 4D tensor (adding dst_stride_w and dst_step_w).
 *       The destination address is still computed from the X, Y and Z coordinates only.
 *
 * @param[in]  src_ptr                              Pointer to the source tensor. Supported data type: S32
 * @param[in]  src_stride_x                         Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                           src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                         Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                           src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                         Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                           src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes    The offset of the first element in the source tensor
 * @param[in]  biases_ptr                           (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr
 * @param[in]  biases_stride_x                      (Optional) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  biases_step_x                        (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor
 * @param[out] dst_ptr                              Pointer to the destination tensor Supported data type: QASYMM8
 * @param[in]  dst_stride_x                         Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                           dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                         Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                           dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                         Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                           dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_stride_w                         (Only if -DDST_HEIGHT is passed) Stride of the destination tensor in W dimension (in bytes)
 * @param[in]  dst_step_w                           (Only if -DDST_HEIGHT is passed) dst_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes    The offset of the first element in the destination tensor
 */
__kernel void gemmlowp_output_stage_quantize_down_float(TENSOR3D_DECLARATION(src),
#if defined(ADD_BIAS)
                                                        VECTOR_DECLARATION(biases),
#endif // defined(ADD_BIAS)
#if defined(DST_HEIGHT)
                                                        TENSOR4D_DECLARATION(dst))
#else  // defined(DST_HEIGHT)
                                                        TENSOR3D_DECLARATION(dst))
#endif // defined(DST_HEIGHT)
{
    // Each work-item handles 4 consecutive elements along X
    const int x = get_global_id(0) * 4;
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // The source stores int32 values (x scaled by sizeof(int)); the destination stores one byte per element
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z;
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z;

    int4 input_values = vload4(0, (__global int *)src_addr);

#if defined(ADD_BIAS)
    // Add bias: the bias vector is indexed only by X, so it is shared by every row/plane
    __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int);

    int4 biases_values = vload4(0, (__global int *)bias_addr);
    input_values += biases_values;
#endif // defined(ADD_BIAS)

    // Convert to float and requantize. The vector width must stay 4 to match input_values
    // and convert_uchar4_sat: OpenCL has no implicit conversion between vector widths.
    float4 input_values_f = convert_float4(input_values);
    input_values_f        = round(input_values_f * (float)REAL_MULTIPLIER + (float)OUTPUT_OFFSET);

    // Saturate to [0..255] before applying the optional activation bounds
    uchar4 res = convert_uchar4_sat(input_values_f);

#if defined(MIN_BOUND)
    res = max(res, (uchar4)MIN_BOUND);
#endif // defined(MIN_BOUND)
#if defined(MAX_BOUND)
    res = min(res, (uchar4)MAX_BOUND);
#endif // defined(MAX_BOUND)

    // Store the result
    vstore4(res, 0, dst_addr);
}
#endif // defined(REAL_MULTIPLIER) && defined(OUTPUT_OFFSET)