blob: e969e847d75af6da363d2cb6984d2a95cad3df20 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco Iodice368da832017-07-03 12:33:49 +010026#ifdef FIXED_POINT_POSITION
27#include "fixed_point.h"
28#endif // FIXED_POINT_POSITION
29
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// gemm_transpose1xW only moves elements from src to dst (VLOAD/VSTORE, no arithmetic), so an
// unsigned integer type with the same width as one element is enough for every supported data type.
#if !defined(ELEMENT_SIZE)
#error "ELEMENT_SIZE must be passed at compile time (i.e. -DELEMENT_SIZE=4)"
#elif ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE not in {1, 2, 4}
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (i.e. -DTRANSPOSE_W=4)
 * @note The size in bytes of one element must be passed at compile time using -DELEMENT_SIZE (i.e. -DELEMENT_SIZE=4). Supported values: 1, 2 and 4
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 *
 * Each work-item (x, y, z) copies the TRANSPOSE_W consecutive elements found at its source position into dst, with X and Y swapped:
 * they are written to dst row (x / MULT_TRANSPOSE1XW_WIDTH) of plane z, starting at element
 * (y * MULT_TRANSPOSE1XW_WIDTH + x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W of that row.
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for Matrix B - source
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix B transposed - destination. X and Y are swapped:
    // - MULT_TRANSPOSE1XW_WIDTH consecutive work-items along X write into the same dst row (x / MULT_TRANSPOSE1XW_WIDTH)
    // - inside that row, each y owns a chunk of TRANSPOSE_W * MULT_TRANSPOSE1XW_WIDTH elements, and each
    //   work-item of the group fills its own TRANSPOSE_W-wide slot (x % MULT_TRANSPOSE1XW_WIDTH) of the chunk
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
                             (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010088
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * Each work-item (x, y, z) reads a 4x4 block (4 consecutive src rows, 4 elements each) and writes its 4 columns,
 * one vstore4 per column, into dst row (y / MULT_INTERLEAVE4X4_HEIGHT) of plane z.
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst))
{
    // Work-item indices
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for source tensor
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix A interleaved - destination:
    // - MULT_INTERLEAVE4X4_HEIGHT consecutive work-items along Y write into the same dst row (y / MULT_INTERLEAVE4X4_HEIGHT)
    // - inside that row, each x owns a chunk of 16 * MULT_INTERLEAVE4X4_HEIGHT elements
    // - each work-item of the group starts at its own 4-element slot (y % MULT_INTERLEAVE4X4_HEIGHT) of the chunk
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    __global uchar *input_ptr = src.ptr;

    // Destination base pointer for this work-item, in elements of DATA_TYPE
    __global DATA_TYPE *output_ptr = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

    // Load a 4x4 block from Matrix A: 4 consecutive rows, 4 consecutive elements per row
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));

    // Store the block column by column. Column i is written 4 * i * MULT_INTERLEAVE4X4_HEIGHT elements
    // past output_ptr, leaving the gaps in between for the other work-items of the same group.
    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, output_ptr + 0 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, output_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, output_ptr + 8 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, output_ptr + 12 * MULT_INTERLEAVE4X4_HEIGHT);
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100157
Gian Marco36a0a462018-01-12 10:21:40 +0000158#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * @note The number of columns of the reshaped matrix B (the length, in elements, of one src1 row; it bounds the K loop) and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *       (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped, i.e. number of rows of the 2D output
 *
 * Each work-item accumulates a 4x4 block of float results and stores it as four vstore4 on four consecutive dst rows.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  pad_bottom                         Number of padding rows at the bottom of each output plane; it is multiplied by dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint pad_bottom
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Row of the reshaped B (x) and of the interleaved A (y) shared by a group of work-items
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item's 4-element slot inside the shared row of A and of B
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (computed before applying offset_row_b)
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // Main loop: two K steps per iteration. One K step spans 4 * MULT_INTERLEAVE4X4_HEIGHT floats of A
    // and 4 * MULT_TRANSPOSE1XW_WIDTH floats of B in the reshaped layouts.
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Left-over loop: one K step per iteration until the end of the B row
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible bottom paddings
    //
    // |             |
    // |   plane0    |
    // |             |
    // |_____________|
    // |*************|
    // | pad_bottom  |
    // |*************|
    // |             |
    // |   plane1    |
    // |             |
    // |_____________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the bottom paddings: each preceding plane contributes pad_bottom rows
    zout *= (pad_bottom * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
335
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000336/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100338 *
Gian Marco19835e52018-01-30 13:35:54 +0000339 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000342 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
343 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
344 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100345 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000346 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
347 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
348 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
349 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
350 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
351 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100352 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
353 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
354 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
355 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
356 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
357 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100358 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100359 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
360 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
361 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
362 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
363 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100364 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100365 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000366 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100367 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000368 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100369 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000370 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
371 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
372 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
373 * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100374 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100375__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
376 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000377 IMAGE_DECLARATION(dst),
378 uint src0_stride_z,
379 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000380 uint dst_stride_z
381#if defined(REINTERPRET_OUTPUT_AS_3D)
382 ,
383 uint pad_bottom
384#endif // REINTERPRET_OUTPUT_AS_3D
385 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100386{
Gian Marco36a0a462018-01-12 10:21:40 +0000387 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
388 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000389 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +0000390
391 // Offset
392 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
393 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
394
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100395 // src_addr_a = address of matrix A
396 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000397 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
398 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
399
400#if defined(MATRIX_B_DEPTH)
401 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
402 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
403#else // defined(MATRIX_B_DEPTH)
404 src1_addr_in_bytes += z * src1_stride_z;
405#endif // defined(MATRIX_B_DEPTH)
406
407 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
408 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100409
Gian Marco36a0a462018-01-12 10:21:40 +0000410 src_addr_a += offset_row_a;
411 src_addr_b += offset_row_b;
412
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100413 // Reset accumulators
414 float c00 = 0.0f;
415 float c01 = 0.0f;
416 float c02 = 0.0f;
417 float c03 = 0.0f;
418 float c10 = 0.0f;
419 float c11 = 0.0f;
420 float c12 = 0.0f;
421 float c13 = 0.0f;
422 float c20 = 0.0f;
423 float c21 = 0.0f;
424 float c22 = 0.0f;
425 float c23 = 0.0f;
426 float c30 = 0.0f;
427 float c31 = 0.0f;
428 float c32 = 0.0f;
429 float c33 = 0.0f;
430
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100431#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
432
433 int i = 0;
434 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100435 {
436 // Load values from matrix A (interleaved) and matrix B (transposed)
437 float4 a0 = vload4(0, src_addr_a);
438 float4 b0 = vload4(0, src_addr_b);
439
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100440 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
441 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100442
443 c00 = fma(a0.s0, b0.s0, c00);
444 c01 = fma(a0.s0, b0.s1, c01);
445 c02 = fma(a0.s0, b0.s2, c02);
446 c03 = fma(a0.s0, b0.s3, c03);
447
448 c10 = fma(a0.s1, b0.s0, c10);
449 c11 = fma(a0.s1, b0.s1, c11);
450 c12 = fma(a0.s1, b0.s2, c12);
451 c13 = fma(a0.s1, b0.s3, c13);
452
453 c20 = fma(a0.s2, b0.s0, c20);
454 c21 = fma(a0.s2, b0.s1, c21);
455 c22 = fma(a0.s2, b0.s2, c22);
456 c23 = fma(a0.s2, b0.s3, c23);
457
458 c30 = fma(a0.s3, b0.s0, c30);
459 c31 = fma(a0.s3, b0.s1, c31);
460 c32 = fma(a0.s3, b0.s2, c32);
461 c33 = fma(a0.s3, b0.s3, c33);
462
463 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100464 a0 = vload4(0, src_addr_a);
465 b0 = vload4(0, src_addr_b);
466
467 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
468 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100469
470 c00 = fma(a0.s0, b0.s0, c00);
471 c01 = fma(a0.s0, b0.s1, c01);
472 c02 = fma(a0.s0, b0.s2, c02);
473 c03 = fma(a0.s0, b0.s3, c03);
474
475 c10 = fma(a0.s1, b0.s0, c10);
476 c11 = fma(a0.s1, b0.s1, c11);
477 c12 = fma(a0.s1, b0.s2, c12);
478 c13 = fma(a0.s1, b0.s3, c13);
479
480 c20 = fma(a0.s2, b0.s0, c20);
481 c21 = fma(a0.s2, b0.s1, c21);
482 c22 = fma(a0.s2, b0.s2, c22);
483 c23 = fma(a0.s2, b0.s3, c23);
484
485 c30 = fma(a0.s3, b0.s0, c30);
486 c31 = fma(a0.s3, b0.s1, c31);
487 c32 = fma(a0.s3, b0.s2, c32);
488 c33 = fma(a0.s3, b0.s3, c33);
489
490 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100491 a0 = vload4(0, src_addr_a);
492 b0 = vload4(0, src_addr_b);
493
494 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
495 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
496
497 c00 = fma(a0.s0, b0.s0, c00);
498 c01 = fma(a0.s0, b0.s1, c01);
499 c02 = fma(a0.s0, b0.s2, c02);
500 c03 = fma(a0.s0, b0.s3, c03);
501
502 c10 = fma(a0.s1, b0.s0, c10);
503 c11 = fma(a0.s1, b0.s1, c11);
504 c12 = fma(a0.s1, b0.s2, c12);
505 c13 = fma(a0.s1, b0.s3, c13);
506
507 c20 = fma(a0.s2, b0.s0, c20);
508 c21 = fma(a0.s2, b0.s1, c21);
509 c22 = fma(a0.s2, b0.s2, c22);
510 c23 = fma(a0.s2, b0.s3, c23);
511
512 c30 = fma(a0.s3, b0.s0, c30);
513 c31 = fma(a0.s3, b0.s1, c31);
514 c32 = fma(a0.s3, b0.s2, c32);
515 c33 = fma(a0.s3, b0.s3, c33);
516
517 // Load values from matrix A (interleaved) and matrix B (transposed)
518 a0 = vload4(0, src_addr_a);
519 b0 = vload4(0, src_addr_b);
520
521 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
522 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100523
524 c00 = fma(a0.s0, b0.s0, c00);
525 c01 = fma(a0.s0, b0.s1, c01);
526 c02 = fma(a0.s0, b0.s2, c02);
527 c03 = fma(a0.s0, b0.s3, c03);
528
529 c10 = fma(a0.s1, b0.s0, c10);
530 c11 = fma(a0.s1, b0.s1, c11);
531 c12 = fma(a0.s1, b0.s2, c12);
532 c13 = fma(a0.s1, b0.s3, c13);
533
534 c20 = fma(a0.s2, b0.s0, c20);
535 c21 = fma(a0.s2, b0.s1, c21);
536 c22 = fma(a0.s2, b0.s2, c22);
537 c23 = fma(a0.s2, b0.s3, c23);
538
539 c30 = fma(a0.s3, b0.s0, c30);
540 c31 = fma(a0.s3, b0.s1, c31);
541 c32 = fma(a0.s3, b0.s2, c32);
542 c33 = fma(a0.s3, b0.s3, c33);
543 }
544
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100545 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100546 {
547 // Load values from matrix A (interleaved) and matrix B (transposed)
548 float4 a0 = vload4(0, src_addr_a);
549 float4 b0 = vload4(0, src_addr_b);
550
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100551 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
552 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
553
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100554 c00 = fma(a0.s0, b0.s0, c00);
555 c01 = fma(a0.s0, b0.s1, c01);
556 c02 = fma(a0.s0, b0.s2, c02);
557 c03 = fma(a0.s0, b0.s3, c03);
558
559 c10 = fma(a0.s1, b0.s0, c10);
560 c11 = fma(a0.s1, b0.s1, c11);
561 c12 = fma(a0.s1, b0.s2, c12);
562 c13 = fma(a0.s1, b0.s3, c13);
563
564 c20 = fma(a0.s2, b0.s0, c20);
565 c21 = fma(a0.s2, b0.s1, c21);
566 c22 = fma(a0.s2, b0.s2, c22);
567 c23 = fma(a0.s2, b0.s3, c23);
568
569 c30 = fma(a0.s3, b0.s0, c30);
570 c31 = fma(a0.s3, b0.s1, c31);
571 c32 = fma(a0.s3, b0.s2, c32);
572 c33 = fma(a0.s3, b0.s3, c33);
573 }
574
575 // Compute destination address
576 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
577
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000578#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100579 // Multiply by the weight of matrix product
580 c00 = c00 * ALPHA;
581 c01 = c01 * ALPHA;
582 c02 = c02 * ALPHA;
583 c03 = c03 * ALPHA;
584 c10 = c10 * ALPHA;
585 c11 = c11 * ALPHA;
586 c12 = c12 * ALPHA;
587 c13 = c13 * ALPHA;
588 c20 = c20 * ALPHA;
589 c21 = c21 * ALPHA;
590 c22 = c22 * ALPHA;
591 c23 = c23 * ALPHA;
592 c30 = c30 * ALPHA;
593 c31 = c31 * ALPHA;
594 c32 = c32 * ALPHA;
595 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000596#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100597
Gian Marcoae2af742018-02-15 12:35:44 +0000598 // Compute dst address
599 __global uchar *dst_addr = offset(&dst, 0, 0);
600
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000601#if defined(REINTERPRET_OUTPUT_AS_3D)
602 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
603 // in order to take into account the presence of possible bottom paddings
604 //
605 // | |
606 // | plane0 |
607 // | |
608 // |_____________|
609 // |*************|
610 // | pad_bottom |
611 // |*************|
612 // | |
613 // | plane1 |
614 // | |
615 // |_____________|
616
617 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
618 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
619 zout = min(DEPTH_GEMM3D - 1, zout);
620
621 // Add offset due to the bottom paddings
622 zout *= (pad_bottom * dst_stride_y);
623
624 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
625 // multiply dst_stride_z by DEPTH_GEMM3D
626 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
627
628 // Store 4x4 block
629 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
630 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
631 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
632 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
633
634#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000635 // Add offset for batched GEMM
636 dst_addr += z * dst_stride_z;
637
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100638 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000639 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
640 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
641 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
642 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000643#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100644}
645
Georgios Pinitas84225582018-05-14 12:00:05 +0100646// Undefine local defines
647#undef COLS_MTX_B
648
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100649#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) for F16 data.
 * Both operands must already be reshaped: matrix A with @ref gemm_interleave4x4_16bit and matrix B with
 * @ref gemm_transpose1x8. Every work-item accumulates a 4x8 block of the output and writes it to dst.
 *
 * Compile-time options:
 * @note -DCOLS_B (mandatory): the number of columns of matrix B
 * @note -DALPHA (optional): the weight of the matrix product; when defined, the result is scaled by it
 * @note -DMULT_TRANSPOSE1XW_WIDTH (mandatory): the multiplication factor for the transposition width (mult_transpose1xW_width), e.g. -DMULT_TRANSPOSE1XW_WIDTH=2
 * @note -DMULT_INTERLEAVE4X4_HEIGHT (mandatory): the multiplication factor for the height of the 4x4 interleaved block, e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2
 * @note -DMATRIX_B_DEPTH (optional): the number of channels of matrix B, e.g. -DMATRIX_B_DEPTH=16. Pass it when matrix B has 3 dimensions
 *       and matrix A more than 3: the batch index used for matrix B is then wrapped (z % MATRIX_B_DEPTH) to avoid out-of-bounds reads.
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication
 *       (2D Winograd) with multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note To store the output as a 3D tensor (e.g. the output of a convolution layer), pass:
 *       -# REINTERPRET_OUTPUT_AS_3D: enables the 3D reinterpretation of the output
 *       -# HEIGHT_GEMM3D: the height of the 3D output tensor
 *       -# DEPTH_GEMM3D: the depth of the 3D output tensor
 *       Output row m is placed in plane min(m / HEIGHT_GEMM3D, DEPTH_GEMM3D - 1), so HEIGHT_GEMM3D * DEPTH_GEMM3D is expected
 *       to equal the number of rows (M) of the non-reshaped matrix A.
 *
 * @param[in]  src0_ptr                           Pointer to the reshaped matrix A. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of matrix A in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of matrix A in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes Offset of the first element of matrix A
 * @param[in]  src1_ptr                           Pointer to the reshaped matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of matrix B in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of matrix B in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes Offset of the first element of matrix B
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  Offset of the first element of the destination matrix
 * @param[in]  src0_stride_z                      Stride of matrix A in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of matrix B in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  pad_bottom                         Bottom padding of each output plane, in rows (only if REINTERPRET_OUTPUT_AS_3D is defined)
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint pad_bottom
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // MULT_TRANSPOSE1XW_WIDTH work-items along x share one row of reshaped B,
    // MULT_INTERLEAVE4X4_HEIGHT work-items along y share one row of reshaped A
    const int row_b = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int row_a = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Position of this work-item's 4-element (A) / 8-element (B) slot inside the shared row
    const int slot_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int slot_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Distance, in elements, between two consecutive k-steps of the same work-item
    const int k_stride_a = 4 * MULT_INTERLEAVE4X4_HEIGHT;
    const int k_stride_b = 8 * MULT_TRANSPOSE1XW_WIDTH;

    int a_bytes = batch * src0_stride_z + row_a * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_bytes = row_b * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has fewer batches than matrix A: wrap the batch index instead of sliding past its last channel
    b_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *a_ptr = (__global half *)(src0_ptr + a_bytes);
    __global half *b_ptr = (__global half *)(src1_ptr + b_bytes);

    // End of the shared row of B; taken before the per-work-item slot is applied
    __global half *b_end = b_ptr + COLS_B;

    a_ptr += slot_a;
    b_ptr += slot_b;

    half8 acc0 = 0.0f;
    half8 acc1 = 0.0f;
    half8 acc2 = 0.0f;
    half8 acc3 = 0.0f;

    // Two k-steps per iteration while at least two full blocks of B remain
    while(b_ptr <= (b_end - 2 * k_stride_b))
    {
        half4 a_vals = vload4(0, a_ptr);
        half8 b_vals = vload8(0, b_ptr);

        acc0 += (half8)a_vals.s0 * b_vals;
        acc1 += (half8)a_vals.s1 * b_vals;
        acc2 += (half8)a_vals.s2 * b_vals;
        acc3 += (half8)a_vals.s3 * b_vals;

        a_vals = vload4(0, a_ptr + k_stride_a);
        b_vals = vload8(0, b_ptr + k_stride_b);

        acc0 += (half8)a_vals.s0 * b_vals;
        acc1 += (half8)a_vals.s1 * b_vals;
        acc2 += (half8)a_vals.s2 * b_vals;
        acc3 += (half8)a_vals.s3 * b_vals;

        a_ptr += 2 * k_stride_a;
        b_ptr += 2 * k_stride_b;
    }

    // Remaining k-steps, one at a time
    while(b_ptr < b_end)
    {
        const half4 a_vals = vload4(0, a_ptr);
        const half8 b_vals = vload8(0, b_ptr);

        acc0 += (half8)a_vals.s0 * b_vals;
        acc1 += (half8)a_vals.s1 * b_vals;
        acc2 += (half8)a_vals.s2 * b_vals;
        acc3 += (half8)a_vals.s3 * b_vals;

        a_ptr += k_stride_a;
        b_ptr += k_stride_b;
    }

#if defined(ALPHA)
    // Apply the weight of the matrix product
    acc0 *= (half8)ALPHA;
    acc1 *= (half8)ALPHA;
    acc2 *= (half8)ALPHA;
    acc3 *= (half8)ALPHA;
#endif // defined(ALPHA)

    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *out = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is stored into a 3D tensor: every time a row crosses into a new plane along z,
    // the bottom padding of the previous planes has to be skipped
    //
    //  |             |
    //  |   plane0    |
    //  |             |
    //  |_____________|
    //  |*************|
    //  | pad_bottom  |
    //  |*************|
    //  |             |
    //  |   plane1    |
    //  |             |
    //  |_____________|

    // Plane of each of the 4 output rows: row index (get_global_id(1) * 4 + r) divided by HEIGHT_GEMM3D
    uint4 row_pad = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    row_pad = min(row_pad, (uint4)(DEPTH_GEMM3D - 1));

    // Convert the plane index into the byte offset introduced by the bottom paddings
    row_pad *= (pad_bottom * dst_stride_y);

    // Batches live in the fourth dimension, hence dst_stride_z is scaled by DEPTH_GEMM3D
    out += batch * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Plain 2D output: no padding offset between rows
    const uint4 row_pad = (uint4)0;

    out += batch * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Store the 4x8 block
    vstore8(acc0, 0, (__global half *)(out + 0 * dst_stride_y + row_pad.s0));
    vstore8(acc1, 0, (__global half *)(out + 1 * dst_stride_y + row_pad.s1));
    vstore8(acc2, 0, (__global half *)(out + 2 * dst_stride_y + row_pad.s2));
    vstore8(acc3, 0, (__global half *)(out + 3 * dst_stride_y + row_pad.s3));
}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100826
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * Each work-item computes a 4x8 block of the output. The reduction along K is unrolled by 4; when MULT_INTERLEAVE4X4_HEIGHT == 1
 * two consecutive 4-element blocks of the interleaved matrix A are contiguous, so they are fetched with a single vload8.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows (M) of matrix A NOT reshaped, since the plane of output row m is m / HEIGHT_GEMM3D
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  pad_bottom                         Bottom paddings in unit of elements along Y, i.e. rows (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint pad_bottom
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item's slot inside the row of reshaped A (4 elements) and reshaped B (8 elements)
    // shared with the other MULT_INTERLEAVE4X4_HEIGHT / MULT_TRANSPOSE1XW_WIDTH work-items
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // No end-of-row pointer is needed: the reduction below is driven by the iteration count COLS_MTX_B
    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

    // Number of k-steps per work-item: each k-step multiplies one 4-element block of A by one 8-element block of B
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    int i = 0;
    // Main loop: 4 k-steps per iteration
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // With MULT_INTERLEAVE4X4_HEIGHT == 1 two consecutive k-steps of matrix A are contiguous:
        // a single vload8 provides the values for two k-steps (s0-s3, then s4-s7)

        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Left-over k-steps (COLS_MTX_B not a multiple of 4)
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible bottom paddings
    //
    //  |             |
    //  |   plane0    |
    //  |             |
    //  |_____________|
    //  |*************|
    //  | pad_bottom  |
    //  |*************|
    //  |             |
    //  |   plane1    |
    //  |             |
    //  |_____________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp with the spec'd vector overload min(uint4, uint4)
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the bottom paddings
    zout *= (pad_bottom * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

// Undefine local defines
#undef COLS_MTX_B
1085
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01001086#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001087
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001088#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_8bit and @ref gemm_transpose1x16 before running the matrix multiplication
 *
 * Each work-item computes an output block of 4 rows x 16 columns. The partial products are accumulated in 16 bit
 * (short8) accumulators and saturated back to 8 bit before the optional alpha scaling and the store.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note ALPHA must be passed in 8 bit fixed point format
 * @note The fixed point position must be passed at compile time using -DFIXED_POINT_POSITION
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 */
__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z)
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item inside the interleaved (A) and transposed (B) blocks
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global char *src_addr_a = (__global char *)(src0_ptr + src0_addr_in_bytes);
    __global char *src_addr_b = (__global char *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before the per-work-item offset is applied)
    __global char *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators. They are 16 bit wide (short8), wider than the 8 bit inputs.
    short8 c00 = 0;
    short8 c10 = 0;
    short8 c20 = 0;
    short8 c30 = 0;
    short8 c01 = 0;
    short8 c11 = 0;
    short8 c21 = 0;
    short8 c31 = 0;

    // This for loop performs 1 accumulation for each iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        char4  a0 = vload4(0, src_addr_a);
        char16 b0 = vload16(0, src_addr_b);

        // Columns 0-7 of the output block
        c00 = mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
        c10 = mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
        c20 = mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
        c30 = mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);

        // Columns 8-15 of the output block
        c01 = mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c11 = mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c21 = mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c31 = mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Saturate the 16 bit accumulators back to 8 bit fixed point (one 16-wide row each)
    char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
    char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
    char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
    char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00_qs8 = mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c10_qs8 = mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c20_qs8 = mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block: 4 rows x 16 columns
    vstore16(c00_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
    vstore16(c10_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
    vstore16(c20_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
    vstore16(c30_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001213
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * Each work-item computes an output block of 4 rows x 8 columns. The partial products are accumulated in 32 bit
 * (int8) accumulators and saturated back to 16 bit before the optional alpha scaling and the store.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note ALPHA must be passed in 16 bit fixed point format
 * @note The fixed point position must be passed at compile time using -DFIXED_POINT_POSITION
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 */
__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
                                                  uint dst_stride_z)
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item inside the interleaved (A) and transposed (B) blocks
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global short *src_addr_a = (__global short *)(src0_ptr + src0_addr_in_bytes);
    __global short *src_addr_b = (__global short *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before the per-work-item offset is applied)
    __global short *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators. They are 32 bit wide (int8), wider than the 16 bit inputs.
    int8 c00 = 0;
    int8 c10 = 0;
    int8 c20 = 0;
    int8 c30 = 0;

    // This for loop performs 1 accumulation for each iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        short4 a0 = vload4(0, src_addr_a);
        short8 b0 = vload8(0, src_addr_b);

        c00 = mlal_sat_qs16x8(c00, (short8)a0.s0, b0, FIXED_POINT_POSITION);
        c10 = mlal_sat_qs16x8(c10, (short8)a0.s1, b0, FIXED_POINT_POSITION);
        c20 = mlal_sat_qs16x8(c20, (short8)a0.s2, b0, FIXED_POINT_POSITION);
        c30 = mlal_sat_qs16x8(c30, (short8)a0.s3, b0, FIXED_POINT_POSITION);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Saturate the 32 bit accumulators back to 16 bit fixed point
    short8 c00_qs16 = convert_short8_sat(c00);
    short8 c10_qs16 = convert_short8_sat(c10);
    short8 c20_qs16 = convert_short8_sat(c20);
    short8 c30_qs16 = convert_short8_sat(c30);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00_qs16 = mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c10_qs16 = mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c20_qs16 = mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block: 4 rows x 8 columns
    vstore8(c00_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
}
1329#endif // defined(FIXED_POINT_POSITION)
Gian Marco36a0a462018-01-12 10:21:40 +00001330#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001331
// The non-reshaped GEMM kernel below needs COLS_A and the per-thread block sizes to be defined at compile time
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
// Vector type holding NUM_ELEMS_PROCESSED_PER_THREAD_X elements of DATA_TYPE (one output row of a work-item)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Each work-item computes an output block of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x NUM_ELEMS_PROCESSED_PER_THREAD_X columns.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Only up to 4 rows per work-item are computed (accumulators acc0..acc3), so NUM_ELEMS_PROCESSED_PER_THREAD_Y must be in the range [1, 4]
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  pad_bottom                         Bottom paddings in number of rows of the output (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
                                     uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint pad_bottom
#endif // REINTERPRET_OUTPUT_AS_3D
                                    )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // src_addr.s0 = byte offset into matrix A, src_addr.s1 = byte offset into matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // Reset accumulators (one per output row handled by this work-item)
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume two columns of matrix A (and two rows of matrix B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: consume the last column of matrix A when COLS_A is odd
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(ALPHA)
    // Multiply by the weight of matrix-matrix product
    acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    int z = get_global_id(2);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible bottom paddings
    //
    //  |             |
    //  |   plane0    |
    //  |             |
    //  |_____________|
    //  |*************|
    //  |  pad_bottom |
    //  |*************|
    //  |             |
    //  |   plane1    |
    //  |             |
    //  |_____________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the bottom paddings
    zout *= (pad_bottom * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store output block
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store output block
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001572#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001573
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01001574/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001575 *
1576 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1577 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1578 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1579 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1580 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001581 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1582 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001583 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001584 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1585 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1586 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1587 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1588 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1589 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001590 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
1591 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1592 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1593 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1594 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1595 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1596 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1597 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1598 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1599 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1600 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1601 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1602 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1603 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1604 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1605 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1606 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1607 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001608 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1609 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1610 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1611 * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001612 */
1613__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1614 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001615 IMAGE_DECLARATION(dst),
1616 uint src0_stride_z,
1617 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001618 uint dst_stride_z
1619#if defined(REINTERPRET_OUTPUT_AS_3D)
1620 ,
1621 uint pad_bottom
1622#endif // REINTERPRET_OUTPUT_AS_3D
1623 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001624{
1625 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1626
1627 // Compute starting address for matrix A and matrix B
1628 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1629
1630 // Update address for matrix A
1631 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1632
1633 // Update address for matrix B
1634 src_addr.s1 += idx * sizeof(float);
1635
Gian Marcoae2af742018-02-15 12:35:44 +00001636 // Add offset for batched GEMM
1637 src_addr.s0 += get_global_id(2) * src0_stride_z;
1638
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001639#if defined(MATRIX_B_DEPTH)
1640 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1641 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1642#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001643 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001644#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001645
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001646 // Initialize accumulators
1647 float acc00 = 0.0f;
1648 float acc01 = 0.0f;
1649 float acc02 = 0.0f;
1650 float acc03 = 0.0f;
1651
1652#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1653 float acc10 = 0.0f;
1654 float acc11 = 0.0f;
1655 float acc12 = 0.0f;
1656 float acc13 = 0.0f;
1657#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1658
1659#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1660 float acc20 = 0.0f;
1661 float acc21 = 0.0f;
1662 float acc22 = 0.0f;
1663 float acc23 = 0.0f;
1664#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1665
1666#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1667 float acc30 = 0.0f;
1668 float acc31 = 0.0f;
1669 float acc32 = 0.0f;
1670 float acc33 = 0.0f;
1671#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1672
1673 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001674 int i = 0;
1675 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001676 {
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001677 // Load values from matrix A and matrix B
1678 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001679#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001680 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001681#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1682#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001683 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001684#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1685#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001686 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001687#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001688 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1689 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001690
1691 // Multiply and accumulate
1692 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001693 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001694 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001695 acc03 = fma(a0.s0, b0.s3, acc03);
1696
1697#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001698
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001699 acc10 = fma(a1.s0, b0.s0, acc10);
1700 acc11 = fma(a1.s0, b0.s1, acc11);
1701 acc12 = fma(a1.s0, b0.s2, acc12);
1702 acc13 = fma(a1.s0, b0.s3, acc13);
1703
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001704#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1705#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001706
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001707 acc20 = fma(a2.s0, b0.s0, acc20);
1708 acc21 = fma(a2.s0, b0.s1, acc21);
1709 acc22 = fma(a2.s0, b0.s2, acc22);
1710 acc23 = fma(a2.s0, b0.s3, acc23);
1711
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001712#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1713#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001714
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001715 acc30 = fma(a3.s0, b0.s0, acc30);
1716 acc31 = fma(a3.s0, b0.s1, acc31);
1717 acc32 = fma(a3.s0, b0.s2, acc32);
1718 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001719#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001720
1721 // Load values from matrix A and matrix B
1722 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1723 src_addr.s1 += src1_stride_y;
1724
1725 // Multiply and accumulate
1726 acc00 = fma(a0.s1, b0.s0, acc00);
1727 acc01 = fma(a0.s1, b0.s1, acc01);
1728 acc02 = fma(a0.s1, b0.s2, acc02);
1729 acc03 = fma(a0.s1, b0.s3, acc03);
1730
1731#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1732
1733 acc10 = fma(a1.s1, b0.s0, acc10);
1734 acc11 = fma(a1.s1, b0.s1, acc11);
1735 acc12 = fma(a1.s1, b0.s2, acc12);
1736 acc13 = fma(a1.s1, b0.s3, acc13);
1737
1738#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1739#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1740
1741 acc20 = fma(a2.s1, b0.s0, acc20);
1742 acc21 = fma(a2.s1, b0.s1, acc21);
1743 acc22 = fma(a2.s1, b0.s2, acc22);
1744 acc23 = fma(a2.s1, b0.s3, acc23);
1745
1746#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1747#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1748
1749 acc30 = fma(a3.s1, b0.s0, acc30);
1750 acc31 = fma(a3.s1, b0.s1, acc31);
1751 acc32 = fma(a3.s1, b0.s2, acc32);
1752 acc33 = fma(a3.s1, b0.s3, acc33);
1753#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1754
1755 // Load values from matrix A and matrix B
1756 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1757 src_addr.s1 += src1_stride_y;
1758
1759 // Multiply and accumulate
1760 acc00 = fma(a0.s2, b0.s0, acc00);
1761 acc01 = fma(a0.s2, b0.s1, acc01);
1762 acc02 = fma(a0.s2, b0.s2, acc02);
1763 acc03 = fma(a0.s2, b0.s3, acc03);
1764
1765#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1766
1767 acc10 = fma(a1.s2, b0.s0, acc10);
1768 acc11 = fma(a1.s2, b0.s1, acc11);
1769 acc12 = fma(a1.s2, b0.s2, acc12);
1770 acc13 = fma(a1.s2, b0.s3, acc13);
1771
1772#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1773#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1774
1775 acc20 = fma(a2.s2, b0.s0, acc20);
1776 acc21 = fma(a2.s2, b0.s1, acc21);
1777 acc22 = fma(a2.s2, b0.s2, acc22);
1778 acc23 = fma(a2.s2, b0.s3, acc23);
1779
1780#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1781#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1782
1783 acc30 = fma(a3.s2, b0.s0, acc30);
1784 acc31 = fma(a3.s2, b0.s1, acc31);
1785 acc32 = fma(a3.s2, b0.s2, acc32);
1786 acc33 = fma(a3.s2, b0.s3, acc33);
1787#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1788
1789 // Load values from matrix A and matrix B
1790 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1791 src_addr.s1 += src1_stride_y;
1792
1793 // Multiply and accumulate
1794 acc00 = fma(a0.s3, b0.s0, acc00);
1795 acc01 = fma(a0.s3, b0.s1, acc01);
1796 acc02 = fma(a0.s3, b0.s2, acc02);
1797 acc03 = fma(a0.s3, b0.s3, acc03);
1798
1799#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1800
1801 acc10 = fma(a1.s3, b0.s0, acc10);
1802 acc11 = fma(a1.s3, b0.s1, acc11);
1803 acc12 = fma(a1.s3, b0.s2, acc12);
1804 acc13 = fma(a1.s3, b0.s3, acc13);
1805
1806#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1807#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1808
1809 acc20 = fma(a2.s3, b0.s0, acc20);
1810 acc21 = fma(a2.s3, b0.s1, acc21);
1811 acc22 = fma(a2.s3, b0.s2, acc22);
1812 acc23 = fma(a2.s3, b0.s3, acc23);
1813
1814#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1815#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1816
1817 acc30 = fma(a3.s3, b0.s0, acc30);
1818 acc31 = fma(a3.s3, b0.s1, acc31);
1819 acc32 = fma(a3.s3, b0.s2, acc32);
1820 acc33 = fma(a3.s3, b0.s3, acc33);
1821#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1822
1823 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001824 }
1825
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001826 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001827 {
1828 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001829 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001830#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1831 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1832#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1833#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1834 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1835#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1836#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1837 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1838#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1839 // Load values from matrix B
1840 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001841 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001842
1843 // Multiply and accumulate
1844 acc00 = fma(a0, b0.s0, acc00);
1845 acc01 = fma(a0, b0.s1, acc01);
1846 acc02 = fma(a0, b0.s2, acc02);
1847 acc03 = fma(a0, b0.s3, acc03);
1848#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1849 acc10 = fma(a1, b0.s0, acc10);
1850 acc11 = fma(a1, b0.s1, acc11);
1851 acc12 = fma(a1, b0.s2, acc12);
1852 acc13 = fma(a1, b0.s3, acc13);
1853#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1854#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1855 acc20 = fma(a2, b0.s0, acc20);
1856 acc21 = fma(a2, b0.s1, acc21);
1857 acc22 = fma(a2, b0.s2, acc22);
1858 acc23 = fma(a2, b0.s3, acc23);
1859#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1860#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1861 acc30 = fma(a3, b0.s0, acc30);
1862 acc31 = fma(a3, b0.s1, acc31);
1863 acc32 = fma(a3, b0.s2, acc32);
1864 acc33 = fma(a3, b0.s3, acc33);
1865#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001866
1867 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001868 }
1869
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001870 int z = get_global_id(2);
1871
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001872 // Compute destination address
1873 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1874
1875 // Multiply by the weight of matrix-matrix product and store the result
1876#if defined(ALPHA)
1877 acc00 = acc00 * ALPHA;
1878 acc01 = acc01 * ALPHA;
1879 acc02 = acc02 * ALPHA;
1880 acc03 = acc03 * ALPHA;
1881#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001882#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001883 acc10 = acc10 * ALPHA;
1884 acc11 = acc11 * ALPHA;
1885 acc12 = acc12 * ALPHA;
1886 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001887#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
1888#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001889 acc20 = acc20 * ALPHA;
1890 acc21 = acc21 * ALPHA;
1891 acc22 = acc22 * ALPHA;
1892 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001893#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
1894#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001895 acc30 = acc30 * ALPHA;
1896 acc31 = acc31 * ALPHA;
1897 acc32 = acc32 * ALPHA;
1898 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001899#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
1900
1901 // Compute dst address
1902 __global uchar *dst_addr = offset(&dst, 0, 0);
1903
1904#if defined(REINTERPRET_OUTPUT_AS_3D)
1905 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1906 // in order to take into account the presence of possible bottom paddings
1907 //
1908 // | |
1909 // | plane0 |
1910 // | |
1911 // |_____________|
1912 // |*************|
1913 // | pad_bottom |
1914 // |*************|
1915 // | |
1916 // | plane1 |
1917 // | |
1918 // |_____________|
1919
1920 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1921 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1922 zout = min(DEPTH_GEMM3D - 1, zout);
1923
1924 // Add offset due to the bottom paddings
1925 zout *= (pad_bottom * dst_stride_y);
1926
1927 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1928 // multiply dst_stride_z by DEPTH_GEMM3D
1929 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1930
1931 // Store the output block
1932 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
1933#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1934 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
1935#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1936#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1937 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
1938#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1939#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1940 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001941#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001942
1943#else // defined(REINTERPRET_OUTPUT_AS_3D)
1944 // Add offset for batched GEMM
1945 dst_addr += z * dst_stride_z;
1946
1947 // Store the output block
1948 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
1949#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1950 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
1951#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1952#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1953 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
1954#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1955#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1956 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
1957#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1958#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001959}
1960
1961/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1962 *
1963 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1964 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
1965 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1966 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
1967 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1968 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001969 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1970 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001971 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001972 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1973 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1974 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1975 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1976 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1977 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001978 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
1979 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1980 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1981 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1982 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1983 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1984 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1985 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1986 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1987 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1988 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1989 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1990 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1991 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1992 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1993 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1994 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1995 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001996 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1997 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1998 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1999 * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002000 */
2001__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
2002 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002003 IMAGE_DECLARATION(dst),
2004 uint src0_stride_z,
2005 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002006 uint dst_stride_z
2007#if defined(REINTERPRET_OUTPUT_AS_3D)
2008 ,
2009 uint pad_bottom
2010#endif // REINTERPRET_OUTPUT_AS_3D
2011 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002012{
2013 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2014 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2015
2016 // Compute starting address for matrix A and Matrix B
2017 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2018
2019 // Update address for the matrix A
2020 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2021
2022 // Update address for the matrix B
2023 src_addr.s1 += idx * sizeof(float);
2024
Gian Marcoae2af742018-02-15 12:35:44 +00002025 // Add offset for batched GEMM
2026 src_addr.s0 += get_global_id(2) * src0_stride_z;
2027
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002028#if defined(MATRIX_B_DEPTH)
2029 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2030 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2031#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002032 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002033#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002034
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002035 // Initialize accumulators
2036 float acc00 = 0.0f;
2037 float acc01 = 0.0f;
2038
2039#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2040 float acc10 = 0.0f;
2041 float acc11 = 0.0f;
2042#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2043#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2044 float acc20 = 0.0f;
2045 float acc21 = 0.0f;
2046#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2047#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2048 float acc30 = 0.0f;
2049 float acc31 = 0.0f;
2050#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2051
2052 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002053 int i = 0;
2054 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002055 {
2056 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002057 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002058
2059 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002060 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2061 src_addr.s1 += src1_stride_y;
2062 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2063 src_addr.s1 += src1_stride_y;
2064 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2065 src_addr.s1 += src1_stride_y;
2066 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2067 src_addr.s1 += src1_stride_y;
2068 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2069 src_addr.s1 += src1_stride_y;
2070 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2071 src_addr.s1 += src1_stride_y;
2072 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2073 src_addr.s1 += src1_stride_y;
2074 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2075 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002076
2077 // Multiply and accumulate
2078 acc00 = fma(a0.s0, b0.s0, acc00);
2079 acc00 = fma(a0.s1, b1.s0, acc00);
2080 acc00 = fma(a0.s2, b2.s0, acc00);
2081 acc00 = fma(a0.s3, b3.s0, acc00);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002082 acc00 = fma(a0.s4, b4.s0, acc00);
2083 acc00 = fma(a0.s5, b5.s0, acc00);
2084 acc00 = fma(a0.s6, b6.s0, acc00);
2085 acc00 = fma(a0.s7, b7.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002086
2087 acc01 = fma(a0.s0, b0.s1, acc01);
2088 acc01 = fma(a0.s1, b1.s1, acc01);
2089 acc01 = fma(a0.s2, b2.s1, acc01);
2090 acc01 = fma(a0.s3, b3.s1, acc01);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002091 acc01 = fma(a0.s4, b4.s1, acc01);
2092 acc01 = fma(a0.s5, b5.s1, acc01);
2093 acc01 = fma(a0.s6, b6.s1, acc01);
2094 acc01 = fma(a0.s7, b7.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002095
2096#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002097 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002098 acc10 = fma(a0.s0, b0.s0, acc10);
2099 acc10 = fma(a0.s1, b1.s0, acc10);
2100 acc10 = fma(a0.s2, b2.s0, acc10);
2101 acc10 = fma(a0.s3, b3.s0, acc10);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002102 acc10 = fma(a0.s4, b4.s0, acc10);
2103 acc10 = fma(a0.s5, b5.s0, acc10);
2104 acc10 = fma(a0.s6, b6.s0, acc10);
2105 acc10 = fma(a0.s7, b7.s0, acc10);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002106
2107 acc11 = fma(a0.s0, b0.s1, acc11);
2108 acc11 = fma(a0.s1, b1.s1, acc11);
2109 acc11 = fma(a0.s2, b2.s1, acc11);
2110 acc11 = fma(a0.s3, b3.s1, acc11);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002111 acc11 = fma(a0.s4, b4.s1, acc11);
2112 acc11 = fma(a0.s5, b5.s1, acc11);
2113 acc11 = fma(a0.s6, b6.s1, acc11);
2114 acc11 = fma(a0.s7, b7.s1, acc11);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002115#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2116#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002117 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002118 acc20 = fma(a0.s0, b0.s0, acc20);
2119 acc20 = fma(a0.s1, b1.s0, acc20);
2120 acc20 = fma(a0.s2, b2.s0, acc20);
2121 acc20 = fma(a0.s3, b3.s0, acc20);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002122 acc20 = fma(a0.s4, b4.s0, acc20);
2123 acc20 = fma(a0.s5, b5.s0, acc20);
2124 acc20 = fma(a0.s6, b6.s0, acc20);
2125 acc20 = fma(a0.s7, b7.s0, acc20);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002126
2127 acc21 = fma(a0.s0, b0.s1, acc21);
2128 acc21 = fma(a0.s1, b1.s1, acc21);
2129 acc21 = fma(a0.s2, b2.s1, acc21);
2130 acc21 = fma(a0.s3, b3.s1, acc21);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002131 acc21 = fma(a0.s4, b4.s1, acc21);
2132 acc21 = fma(a0.s5, b5.s1, acc21);
2133 acc21 = fma(a0.s6, b6.s1, acc21);
2134 acc21 = fma(a0.s7, b7.s1, acc21);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002135#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2136#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002137 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002138 acc30 = fma(a0.s0, b0.s0, acc30);
2139 acc30 = fma(a0.s1, b1.s0, acc30);
2140 acc30 = fma(a0.s2, b2.s0, acc30);
2141 acc30 = fma(a0.s3, b3.s0, acc30);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002142 acc30 = fma(a0.s4, b4.s0, acc30);
2143 acc30 = fma(a0.s5, b5.s0, acc30);
2144 acc30 = fma(a0.s6, b6.s0, acc30);
2145 acc30 = fma(a0.s7, b7.s0, acc30);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002146
2147 acc31 = fma(a0.s0, b0.s1, acc31);
2148 acc31 = fma(a0.s1, b1.s1, acc31);
2149 acc31 = fma(a0.s2, b2.s1, acc31);
2150 acc31 = fma(a0.s3, b3.s1, acc31);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002151 acc31 = fma(a0.s4, b4.s1, acc31);
2152 acc31 = fma(a0.s5, b5.s1, acc31);
2153 acc31 = fma(a0.s6, b6.s1, acc31);
2154 acc31 = fma(a0.s7, b7.s1, acc31);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002155#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002156
2157 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002158 }
2159 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002160 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002161 {
2162 // Load values from matrix A
2163 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2164#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2165 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2166#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2167#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2168 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2169#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2170#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2171 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2172#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2173 // Load values from matrix B
2174 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002175 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002176
2177 // Multiply and accumulate
2178 acc00 = fma(a0, b0.s0, acc00);
2179 acc01 = fma(a0, b0.s1, acc01);
2180#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2181 acc10 = fma(a1, b0.s0, acc10);
2182 acc11 = fma(a1, b0.s1, acc11);
2183#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2184#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2185 acc20 = fma(a2, b0.s0, acc20);
2186 acc21 = fma(a2, b0.s1, acc21);
2187#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2188#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2189 acc30 = fma(a3, b0.s0, acc30);
2190 acc31 = fma(a3, b0.s1, acc31);
2191#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002192
2193 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002194 }
2195
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002196 // Multiply by the weight of matrix-matrix product and store the result
2197#if defined(ALPHA)
2198 acc00 = acc00 * ALPHA;
2199 acc01 = acc01 * ALPHA;
2200#endif // defined(ALPHA)
2201#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2202 acc10 = acc10 * ALPHA;
2203 acc11 = acc11 * ALPHA;
2204#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2205#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2206 acc20 = acc20 * ALPHA;
2207 acc21 = acc21 * ALPHA;
2208#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2209#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2210 acc30 = acc30 * ALPHA;
2211 acc31 = acc31 * ALPHA;
2212#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2213
2214 int z = get_global_id(2);
2215
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002216 // Compute destination address
2217 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2218
Gian Marcoae2af742018-02-15 12:35:44 +00002219 // Compute dst address
2220 __global uchar *dst_addr = offset(&dst, 0, 0);
2221
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002222#if defined(REINTERPRET_OUTPUT_AS_3D)
2223 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2224 // in order to take into account the presence of possible bottom paddings
2225 //
2226 // | |
2227 // | plane0 |
2228 // | |
2229 // |_____________|
2230 // |*************|
2231 // | pad_bottom |
2232 // |*************|
2233 // | |
2234 // | plane1 |
2235 // | |
2236 // |_____________|
Gian Marcoae2af742018-02-15 12:35:44 +00002237
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002238 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2239 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2240 zout = min(DEPTH_GEMM3D - 1, zout);
2241
2242 // Add offset due to the bottom paddings
2243 zout *= (pad_bottom * dst_stride_y);
2244
2245 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2246 // multiply dst_stride_z by DEPTH_GEMM3D
2247 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2248
2249 // Store the output block
2250 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002251#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002252 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002253#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2254#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002255 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002256#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2257#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002258 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002259#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002260
2261#else // defined(REINTERPRET_OUTPUT_AS_3D)
2262 // Add offset for batched GEMM
2263 dst_addr += z * dst_stride_z;
2264
2265 // Store the output block
2266 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2267#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2268 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2269#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2270#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2271 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2272#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2273#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2274 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
2275#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2276#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002277}
2278
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item always loads and stores 8 consecutive columns (half8) of matrix B / dst, starting at column
 *       get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X, so -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8 is required for the
 *       work-items to tile the columns without overlaps or gaps.
 *       Only NUM_ELEMS_PROCESSED_PER_THREAD_Y values from 1 to 4 are handled (at most 4 rows are computed per work-item).
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  pad_bottom                         Number of bottom padding rows after each output plane, in elements (only if REINTERPRET_OUTPUT_AS_3D is defined)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint pad_bottom
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // First column (in elements) of matrix B / dst handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A: skip the rows already handled by the work-items before this one along Y
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One 8-wide accumulator per output row computed by this work-item
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    int i = 0;
    // Main loop: consume 4 columns of A (and therefore 4 rows of B) per iteration
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Row 0 of B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Row 1 of B
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Row 2 of B
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Row 3 of B
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 columns just consumed (B was advanced row by row above)
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: handle the remaining COLS_A % 4 columns of A one at a time
    for(; i < (int)COLS_A; ++i)
    {
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Next column of A and next row of B
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc0 = acc0 * (half8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible bottom paddings
    //
    // |                  |
    // |      plane0      |
    // |                  |
    // |__________________|
    // |******************|
    // |  cross_plane_pad |
    // |******************|
    // |                  |
    // |      plane1      |
    // |                  |
    // |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. Both operands are uint4: OpenCL C only defines the integer min() overloads
    // min(gentype, gentype) and min(gentype, sgentype), so a scalar first argument is not portable.
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Add offset due to the bottom paddings: each plane boundary crossed adds pad_bottom rows
    zout *= (pad_bottom * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002545
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002546#if defined(FIXED_POINT_POSITION)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002547/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002548 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002549 * @note This OpenCL kernel works with fixed point data types QS8
2550 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002551 * @note The number matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002552 * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002553 * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002554 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2555 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002556 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002557 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8/QS16
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002558 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2559 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2560 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2561 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2562 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2563 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2564 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2565 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2566 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2567 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2568 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2569 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2570 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2571 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2572 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2573 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2574 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002575 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2576 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2577 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002578 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002579__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002580 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002581 IMAGE_DECLARATION(dst),
2582 uint src0_stride_z,
2583 uint src1_stride_z,
2584 uint dst_stride_z)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002585{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002586 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002587
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002588 // Compute starting address for matrix A and Matrix B
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002589 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002590
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002591 // Update address for the matrix A
2592 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002593
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002594 // Update address for the matrix B
2595 src_addr.s1 += idx * sizeof(char);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002596
Gian Marcoae2af742018-02-15 12:35:44 +00002597 // Add offset for batched GEMM
2598 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002599
2600#if defined(MATRIX_B_DEPTH)
2601 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2602 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2603#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002604 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002605#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002606
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002607 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));
2608
2609 short8 acc00 = 0;
2610 short8 acc01 = 0;
2611#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2612 short8 acc10 = 0;
2613 short8 acc11 = 0;
2614#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2615#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2616 short8 acc20 = 0;
2617 short8 acc21 = 0;
2618#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2619#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2620 short8 acc30 = 0;
2621 short8 acc31 = 0;
2622#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2623
2624 // This for loop performs 4 accumulations per iteration
2625 for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002626 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002627 char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2628#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2629 char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2630#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2631#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2632 char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2633#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2634#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2635 char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2636#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002637 char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2638 char16 b1 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002639
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002640 acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
2641 acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
2642 acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2643 acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2644#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2645 acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
2646 acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
2647 acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2648 acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2649#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2650#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2651 acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
2652 acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
2653 acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2654 acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2655#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2656#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2657 acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
2658 acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
2659 acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2660 acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2661#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002662 }
2663
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002664 // Left-over accumulations
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002665 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
2666 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002667 char a0 = *((__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2668#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2669 char a1 = *((__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2670#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2671#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2672 char a2 = *((__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2673#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2674#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2675 char a3 = *((__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2676#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002677 char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1));
2678
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002679 acc00 = mlal_sat_qs8x8(acc00, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
2680 acc01 = mlal_sat_qs8x8(acc01, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2681#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2682 acc10 = mlal_sat_qs8x8(acc10, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
2683 acc11 = mlal_sat_qs8x8(acc11, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
2684#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2685#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2686 acc20 = mlal_sat_qs8x8(acc20, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
2687 acc21 = mlal_sat_qs8x8(acc21, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
2688#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2689#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2690 acc30 = mlal_sat_qs8x8(acc30, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
2691 acc31 = mlal_sat_qs8x8(acc31, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
2692#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002693 }
2694
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002695 // Compute destination address
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002696 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2697
Gian Marcoae2af742018-02-15 12:35:44 +00002698 // Compute dst address
2699 __global uchar *dst_addr = offset(&dst, 0, 0);
2700
2701 // Add offset for batched GEMM
2702 dst_addr += get_global_id(2) * dst_stride_z;
2703
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002704 // Multiply by the weight of matrix product and store the result
2705 char16 acc_qs8;
2706 acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002707#if defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002708 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002709#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002710 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002711#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2712 acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002713#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002714 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002715#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002716 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002717#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2718#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2719 acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002720#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002721 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002722#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002723 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002724#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2725#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2726 acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002727#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002728 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002729#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002730 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002731#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002732}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002733
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the fixed point data type QS16
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Only up to 4 rows per work-item are handled (accumulators acc0..acc3). Each work-item loads/stores 8 columns (short8);
 *       NOTE(review): this presumably requires NUM_ELEMS_PROCESSED_PER_THREAD_X == 8 — confirm against the host code
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A
 * @note The fixed point position needs to be passed at compile time using -DFIXED_POINT_POSITION
 * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA. When ALPHA is not defined the product is not scaled
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix A in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 */
__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
                           IMAGE_DECLARATION(src1),
                           IMAGE_DECLARATION(dst),
                           uint src0_stride_z,
                           uint src1_stride_z,
                           uint dst_stride_z)
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(short);

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#if defined(MATRIX_B_DEPTH)
    // Matrix B only has MATRIX_B_DEPTH channels: wrap the batch index so that B is not read past its last channel
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte address (within src0) one past the last element of the current row of A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));

    int8 acc0 = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    int8 acc1 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    int8 acc2 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    int8 acc3 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: each iteration consumes 2 columns of A (and 2 rows of B), i.e. 2 multiply-accumulates per accumulator
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(short)); src_addr += (int2)(2 * sizeof(short), 2 * src1_stride_y))
    {
        short2 a0 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        short2 a1 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        short2 a2 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short2 a3 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
        short8 b1 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));

        acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
        acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over accumulations: the main loop steps by 2, so at most one column of A remains (odd COLS_A)
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(short), src1_stride_y))
    {
        short a0 = *((__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        short a1 = *((__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        short a2 = *((__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short a3 = *((__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1));

        acc0 = mlal_sat_qs16x8(acc0, (short8)a0, b0, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += get_global_id(2) * dst_stride_z;

    // Narrow the int accumulators back to QS16 with saturation, optionally scale by alpha, and store one row at a time
    short8 acc_qs16;
    acc_qs16 = convert_short8_sat(acc0);
#if defined(ALPHA)
    acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(acc_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc_qs16 = convert_short8_sat(acc1);
#if defined(ALPHA)
    acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(acc_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc_qs16 = convert_short8_sat(acc2);
#if defined(ALPHA)
    acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(acc_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc_qs16 = convert_short8_sat(acc3);
#if defined(ALPHA)
    acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(acc_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002904#endif // defined(FIXED_POINT_POSITION)
2905#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002906
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002907#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *  dst = dst + BETA * src
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 * @note Each work-item processes 4 consecutive F32 values (vload4/vstore4)
 *
 * @param[in]     src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]     src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]     src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]     src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in,out] dst_ptr                           Pointer to the destination matrix, holding alpha * A * B on input and the final result on output. Supported data types: same as @p src_ptr
 * @param[in]     dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]     dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]     dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (dst is read before being overwritten: the addition is in place)
    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);

    // Load values from Matrix C
    float4 c = vload4(0, (__global float *)src.ptr);

    // Computes alpha * axb + beta * c
    float4 out = alpha_ab + (float4)BETA * c;

    // Store final result in axb matrix
    vstore4(out, 0, (__global float *)dst.ptr);
}
2948
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01002949#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *  dst = dst + BETA * src
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 * @note Each work-item processes 8 consecutive F16 values (vload8/vstore8)
 *
 * @param[in]     src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]     src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]     src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]     src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in,out] dst_ptr                           Pointer to the destination matrix, holding alpha * A * B on input and the final result on output. Supported data types: same as @p src_ptr
 * @param[in]     dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]     dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]     dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (dst is read before being overwritten: the addition is in place)
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Computes alpha * axb + beta * c
    half8 out = alpha_ab + (half8)BETA * c;

    // Store final result in axb matrix
    vstore8(out, 0, (__global half *)dst.ptr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01002990#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002991
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002992#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 8 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
 *  dst = dst + BETA * src (computed with the saturating fixed point helper mla_sat_qs8x16)
 *
 * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 * @note BETA must be passed in 8 bit fixed point format
 * @note Each work-item processes 16 consecutive QS8 values (vload16/vstore16)
 *
 * @param[in]     src_ptr                           Pointer to the source matrix (matrix C). Supported data types: QS8
 * @param[in]     src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]     src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]     src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in,out] dst_ptr                           Pointer to the destination matrix, holding alpha * A * B on input and the final result on output. Supported data types: same as @p src_ptr
 * @param[in]     dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]     dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]     dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs8(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (dst is read before being overwritten: the addition is in place)
    char16 alpha_ab = vload16(0, (__global char *)dst.ptr);

    // Load values from Matrix C
    char16 c = vload16(0, (__global char *)src.ptr);

    // Computes alpha * axb + beta * c
    char16 out = mla_sat_qs8x16(alpha_ab, (char16)BETA, c, FIXED_POINT_POSITION);

    // Store final result in axb matrix
    vstore16(out, 0, (__global char *)dst.ptr);
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01003035
3036/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 16 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
3037 *
Gian Marco19835e52018-01-30 13:35:54 +00003038 * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
Gian Marco Iodice8a383692017-07-03 17:41:47 +01003039 *
3040 * @note: BETA must be passed in 16 bit fixed point format
3041 *
3042 * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS16
3043 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
3044 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3045 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
3046 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003047 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
3048 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01003049 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
3050 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
3051 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3052 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3053 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3054 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003055 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3056 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01003057 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
3058 */
__kernel void gemm_ma_qs16(TENSOR3D_DECLARATION(src),
                           TENSOR3D_DECLARATION(dst))
{
    // dst already holds alpha * (A x B); it is read and then overwritten in place
    Tensor3D ab_tensor = CONVERT_TO_TENSOR3D_STRUCT(dst);
    // src holds matrix C
    Tensor3D c_tensor = CONVERT_TO_TENSOR3D_STRUCT(src);

    __global short       *ab_elems = (__global short *)ab_tensor.ptr;
    __global const short *c_elems  = (__global const short *)c_tensor.ptr;

    // alpha * axb + beta * c on 8 QS16 lanes via mla_sat_qs16x8
    // (the _sat_ variant presumably saturates on overflow), written back over axb
    vstore8(mla_sat_qs16x8(vload8(0, ab_elems), (short8)BETA, vload8(0, c_elems), FIXED_POINT_POSITION),
            0, ab_elems);
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003078#endif // defined(FIXED_POINT_POSITION)
3079#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003080
#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
 *
 * Work-item (x, y) multiplies row y of A by the 4 columns of B starting at column 4 * x, taken from
 * the y-th Z slice of B. It writes the 4 resulting floats at its position in dst.
 *
 * @note The width of A needs to be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    const int col = get_global_id(0) * 4; // first of the 4 output columns handled here
    const int row = get_global_id(1);     // row of A, also selects the Z slice of B

    // Byte offset of the current element of row `row` of A
    int a_off = src0_offset_first_element_in_bytes + src0_stride_y * row;
    // Byte offset of the current row of B's Z slice `row`, starting at column `col`
    int b_off = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_off += col * sizeof(float);

    const int a_end       = a_off + (WIDTH_VECTOR_A * sizeof(float));
    const int a_last_pair = a_end - 2 * (int)sizeof(float);

    float4 acc = 0.0f;

    // Main loop: two elements of A (and the two matching rows of B) per iteration
    while(a_off <= a_last_pair)
    {
        const float2 a_pair = vload2(0, (__global float *)(src0_ptr + a_off));

        acc += vload4(0, (__global float *)(src1_ptr + b_off)) * (float4)a_pair.s0;
        acc += vload4(0, (__global float *)(src1_ptr + b_off + src1_stride_y)) * (float4)a_pair.s1;

        a_off += 2 * sizeof(float);
        b_off += 2 * src1_stride_y;
    }

    // Pairs are consumed above, so at most one element of A is left over (odd WIDTH_VECTOR_A)
    if(a_off < a_end)
    {
        const float a_single = *((__global float *)(src0_ptr + a_off));

        acc += vload4(0, (__global float *)(src1_ptr + b_off)) * (float4)a_single;
    }

    Image out = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&out, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
3148
/** This kernel accumulates each row with the biases vector.
 *
 * Each work-item loads VECTOR_SIZE elements from the accumulate tensor and VECTOR_SIZE elements from the
 * biases vector. It adds them and writes the sum back to the accumulate tensor in place.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 * @note If -DFIXED_POINT_POSITION is defined, the addition uses ADD_SAT_OP_EXPAND
 *       (presumably a saturating add) instead of the plain + operator.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/QS8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *acc_elems  = (__global DATA_TYPE *)accum_img.ptr;
    __global DATA_TYPE *bias_elems = (__global DATA_TYPE *)biases_vec.ptr;

    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    acc_lanes = VLOAD(VECTOR_SIZE)(0, acc_elems);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    bias_lanes = VLOAD(VECTOR_SIZE)(0, bias_elems);

#if defined(FIXED_POINT_POSITION)
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    sum_lanes = ADD_SAT_OP_EXPAND(bias_lanes, acc_lanes, DATA_TYPE, VECTOR_SIZE);
#else  // defined(FIXED_POINT_POSITION)
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    sum_lanes = bias_lanes + acc_lanes;
#endif // defined(FIXED_POINT_POSITION)

    // Write the biased values back over the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (sum_lanes, 0, acc_elems);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)