blob: 89d80367d193b47c5d49ebb7eb5ee961a26f0ac1 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco Iodice368da832017-07-03 12:33:49 +010026#ifdef FIXED_POINT_POSITION
27#include "fixed_point.h"
28#endif // FIXED_POINT_POSITION
29
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// Elements are moved as raw words of ELEMENT_SIZE bytes, so any data type of that size can be transposed.
#if !defined(ELEMENT_SIZE)
#error "ELEMENT_SIZE must be passed at compile time (e.g. -DELEMENT_SIZE=4)"
#elif ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE not in {1, 2, 4}
#error "Element size not supported"
#endif // !defined(ELEMENT_SIZE)

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * Each work-item copies TRANSPOSE_W consecutive elements starting at its source position to the destination.
 * In the destination X and Y are swapped: the work-item's X index selects the destination row (x / MULT_TRANSPOSE1XW_WIDTH)
 * and the TRANSPOSE_W-wide slot inside the current group (x % MULT_TRANSPOSE1XW_WIDTH), while its Y index selects the
 * group of TRANSPOSE_W * MULT_TRANSPOSE1XW_WIDTH elements along that destination row.
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The size in bytes of one element must be passed at compile time using -DELEMENT_SIZE (1, 2 or 4)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for the source matrix
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for the transposed destination matrix. X and Y are swapped
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
                             (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010088
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
 *
 * Each work-item loads a 4x4 tile: 4 consecutive elements from each of 4 consecutive source rows,
 * starting at its source position. It then writes the tile column by column. Column k becomes a
 * group of 4 elements stored 4 * k * MULT_INTERLEAVE4X4_HEIGHT elements after the work-item's
 * destination base address.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst))
{
    const uint gid_x = get_global_id(0);
    const uint gid_y = get_global_id(1);
    const uint gid_z = get_global_id(2);

    // Source position of this work-item's 4x4 tile
    Tensor3D src_tensor = CONVERT_TO_TENSOR3D_STRUCT(src);
    __global uchar *src_row = src_tensor.ptr;

    // Destination base of the interleaved block:
    //  - gid_y selects the output row (gid_y / MULT_INTERLEAVE4X4_HEIGHT) and the 4-element slot within each group
    //  - gid_x selects the group of 16 * MULT_INTERLEAVE4X4_HEIGHT elements along that row
    //  - gid_z selects the batch (batched GEMM)
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes
                             + (gid_y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y
                             + (gid_y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE)
                             + gid_x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT
                             + gid_z * dst_stride_z;

    __global DATA_TYPE *dst_block = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

    // Four consecutive source rows, four elements each
    VEC_DATA_TYPE(DATA_TYPE, 4)
    row0 = vload4(0, (__global DATA_TYPE *)(src_row));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    row1 = vload4(0, (__global DATA_TYPE *)(src_row + src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    row2 = vload4(0, (__global DATA_TYPE *)(src_row + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    row3 = vload4(0, (__global DATA_TYPE *)(src_row + 3 * src_stride_y));

    // vstore4(v, k, p) writes to p + 4 * k, so offset k * MULT_INTERLEAVE4X4_HEIGHT places
    // column k of the tile 4 * k * MULT_INTERLEAVE4X4_HEIGHT elements after dst_block
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(row0.s0, row1.s0, row2.s0, row3.s0), 0 * MULT_INTERLEAVE4X4_HEIGHT, dst_block);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(row0.s1, row1.s1, row2.s1, row3.s1), 1 * MULT_INTERLEAVE4X4_HEIGHT, dst_block);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(row0.s2, row1.s2, row2.s2, row3.s2), 2 * MULT_INTERLEAVE4X4_HEIGHT, dst_block);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(row0.s3, row1.s3, row2.s3, row3.s3), 3 * MULT_INTERLEAVE4X4_HEIGHT, dst_block);
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100157
Gian Marco36a0a462018-01-12 10:21:40 +0000158#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100159/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100160 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100161 *
Gian Marco19835e52018-01-30 13:35:54 +0000162 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
163 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
164 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000165 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
166 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100167 *
168 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
169 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
170 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
171 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
172 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
173 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100174 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100175 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
176 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
177 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
178 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
179 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100180 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100181 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000182 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100183 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000184 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100185 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
186 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100187__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
188 IMAGE_DECLARATION(src1),
189 IMAGE_DECLARATION(dst),
190 uint src0_stride_z,
191 uint src1_stride_z,
192 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100193{
Gian Marco36a0a462018-01-12 10:21:40 +0000194 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
195 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000196 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100197
Gian Marco36a0a462018-01-12 10:21:40 +0000198 // Offset
199 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
200 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100201
Gian Marco36a0a462018-01-12 10:21:40 +0000202 // src_addr_a = address of matrix A
203 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000204 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
205 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
206
207#if defined(MATRIX_B_DEPTH)
208 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
209 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
210#else // defined(MATRIX_B_DEPTH)
211 src1_addr_in_bytes += z * src1_stride_z;
212#endif // defined(MATRIX_B_DEPTH)
213
214 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
215 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100216
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000217 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000218 __global float *src_end_addr_b = src_addr_b + COLS_B;
219
220 src_addr_a += offset_row_a;
221 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100222
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000223 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100224 float4 c00 = 0.0f;
225 float4 c10 = 0.0f;
226 float4 c20 = 0.0f;
227 float4 c30 = 0.0f;
228
Gian Marco36a0a462018-01-12 10:21:40 +0000229 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100230 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000231 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000232 float4 a0 = vload4(0, src_addr_a);
233 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100234
235 c00 += (float4)a0.s0 * b0;
236 c10 += (float4)a0.s1 * b0;
237 c20 += (float4)a0.s2 * b0;
238 c30 += (float4)a0.s3 * b0;
239
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000240 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000241 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
242 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100243
244 c00 += (float4)a0.s0 * b0;
245 c10 += (float4)a0.s1 * b0;
246 c20 += (float4)a0.s2 * b0;
247 c30 += (float4)a0.s3 * b0;
248 }
249
Gian Marco36a0a462018-01-12 10:21:40 +0000250 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100251 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000252 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000253 float4 a0 = vload4(0, src_addr_a);
254 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100255
256 c00 += (float4)a0.s0 * b0;
257 c10 += (float4)a0.s1 * b0;
258 c20 += (float4)a0.s2 * b0;
259 c30 += (float4)a0.s3 * b0;
260 }
261
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000262 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100263 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
264
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000265#if defined(ALPHA)
266 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100267 c00 = c00 * (float4)ALPHA;
268 c10 = c10 * (float4)ALPHA;
269 c20 = c20 * (float4)ALPHA;
270 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000271#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100272
Gian Marcoae2af742018-02-15 12:35:44 +0000273 // Compute dst address
274 __global uchar *dst_addr = offset(&dst, 0, 0);
275
276 // Add offset for batched GEMM
277 dst_addr += z * dst_stride_z;
278
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000279 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000280 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
281 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
282 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
283 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100284}
285
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000286/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100287 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100288 *
Gian Marco19835e52018-01-30 13:35:54 +0000289 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
290 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
291 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000292 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
293 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
294 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100295 *
296 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
297 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
298 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
299 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
300 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
301 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100302 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100303 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
304 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
305 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
306 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
307 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100308 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100309 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000310 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100311 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000312 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100313 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
314 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100315__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
316 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000317 IMAGE_DECLARATION(dst),
318 uint src0_stride_z,
319 uint src1_stride_z,
320 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100321{
Gian Marco36a0a462018-01-12 10:21:40 +0000322 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
323 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000324 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +0000325
326 // Offset
327 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
328 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
329
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100330 // src_addr_a = address of matrix A
331 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000332 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
333 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
334
335#if defined(MATRIX_B_DEPTH)
336 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
337 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
338#else // defined(MATRIX_B_DEPTH)
339 src1_addr_in_bytes += z * src1_stride_z;
340#endif // defined(MATRIX_B_DEPTH)
341
342 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
343 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100344
Gian Marco36a0a462018-01-12 10:21:40 +0000345 src_addr_a += offset_row_a;
346 src_addr_b += offset_row_b;
347
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100348 // Reset accumulators
349 float c00 = 0.0f;
350 float c01 = 0.0f;
351 float c02 = 0.0f;
352 float c03 = 0.0f;
353 float c10 = 0.0f;
354 float c11 = 0.0f;
355 float c12 = 0.0f;
356 float c13 = 0.0f;
357 float c20 = 0.0f;
358 float c21 = 0.0f;
359 float c22 = 0.0f;
360 float c23 = 0.0f;
361 float c30 = 0.0f;
362 float c31 = 0.0f;
363 float c32 = 0.0f;
364 float c33 = 0.0f;
365
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100366#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
367
368 int i = 0;
369 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100370 {
371 // Load values from matrix A (interleaved) and matrix B (transposed)
372 float4 a0 = vload4(0, src_addr_a);
373 float4 b0 = vload4(0, src_addr_b);
374
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100375 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
376 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100377
378 c00 = fma(a0.s0, b0.s0, c00);
379 c01 = fma(a0.s0, b0.s1, c01);
380 c02 = fma(a0.s0, b0.s2, c02);
381 c03 = fma(a0.s0, b0.s3, c03);
382
383 c10 = fma(a0.s1, b0.s0, c10);
384 c11 = fma(a0.s1, b0.s1, c11);
385 c12 = fma(a0.s1, b0.s2, c12);
386 c13 = fma(a0.s1, b0.s3, c13);
387
388 c20 = fma(a0.s2, b0.s0, c20);
389 c21 = fma(a0.s2, b0.s1, c21);
390 c22 = fma(a0.s2, b0.s2, c22);
391 c23 = fma(a0.s2, b0.s3, c23);
392
393 c30 = fma(a0.s3, b0.s0, c30);
394 c31 = fma(a0.s3, b0.s1, c31);
395 c32 = fma(a0.s3, b0.s2, c32);
396 c33 = fma(a0.s3, b0.s3, c33);
397
398 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100399 a0 = vload4(0, src_addr_a);
400 b0 = vload4(0, src_addr_b);
401
402 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
403 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100404
405 c00 = fma(a0.s0, b0.s0, c00);
406 c01 = fma(a0.s0, b0.s1, c01);
407 c02 = fma(a0.s0, b0.s2, c02);
408 c03 = fma(a0.s0, b0.s3, c03);
409
410 c10 = fma(a0.s1, b0.s0, c10);
411 c11 = fma(a0.s1, b0.s1, c11);
412 c12 = fma(a0.s1, b0.s2, c12);
413 c13 = fma(a0.s1, b0.s3, c13);
414
415 c20 = fma(a0.s2, b0.s0, c20);
416 c21 = fma(a0.s2, b0.s1, c21);
417 c22 = fma(a0.s2, b0.s2, c22);
418 c23 = fma(a0.s2, b0.s3, c23);
419
420 c30 = fma(a0.s3, b0.s0, c30);
421 c31 = fma(a0.s3, b0.s1, c31);
422 c32 = fma(a0.s3, b0.s2, c32);
423 c33 = fma(a0.s3, b0.s3, c33);
424
425 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100426 a0 = vload4(0, src_addr_a);
427 b0 = vload4(0, src_addr_b);
428
429 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
430 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
431
432 c00 = fma(a0.s0, b0.s0, c00);
433 c01 = fma(a0.s0, b0.s1, c01);
434 c02 = fma(a0.s0, b0.s2, c02);
435 c03 = fma(a0.s0, b0.s3, c03);
436
437 c10 = fma(a0.s1, b0.s0, c10);
438 c11 = fma(a0.s1, b0.s1, c11);
439 c12 = fma(a0.s1, b0.s2, c12);
440 c13 = fma(a0.s1, b0.s3, c13);
441
442 c20 = fma(a0.s2, b0.s0, c20);
443 c21 = fma(a0.s2, b0.s1, c21);
444 c22 = fma(a0.s2, b0.s2, c22);
445 c23 = fma(a0.s2, b0.s3, c23);
446
447 c30 = fma(a0.s3, b0.s0, c30);
448 c31 = fma(a0.s3, b0.s1, c31);
449 c32 = fma(a0.s3, b0.s2, c32);
450 c33 = fma(a0.s3, b0.s3, c33);
451
452 // Load values from matrix A (interleaved) and matrix B (transposed)
453 a0 = vload4(0, src_addr_a);
454 b0 = vload4(0, src_addr_b);
455
456 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
457 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100458
459 c00 = fma(a0.s0, b0.s0, c00);
460 c01 = fma(a0.s0, b0.s1, c01);
461 c02 = fma(a0.s0, b0.s2, c02);
462 c03 = fma(a0.s0, b0.s3, c03);
463
464 c10 = fma(a0.s1, b0.s0, c10);
465 c11 = fma(a0.s1, b0.s1, c11);
466 c12 = fma(a0.s1, b0.s2, c12);
467 c13 = fma(a0.s1, b0.s3, c13);
468
469 c20 = fma(a0.s2, b0.s0, c20);
470 c21 = fma(a0.s2, b0.s1, c21);
471 c22 = fma(a0.s2, b0.s2, c22);
472 c23 = fma(a0.s2, b0.s3, c23);
473
474 c30 = fma(a0.s3, b0.s0, c30);
475 c31 = fma(a0.s3, b0.s1, c31);
476 c32 = fma(a0.s3, b0.s2, c32);
477 c33 = fma(a0.s3, b0.s3, c33);
478 }
479
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100480 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100481 {
482 // Load values from matrix A (interleaved) and matrix B (transposed)
483 float4 a0 = vload4(0, src_addr_a);
484 float4 b0 = vload4(0, src_addr_b);
485
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100486 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
487 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
488
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100489 c00 = fma(a0.s0, b0.s0, c00);
490 c01 = fma(a0.s0, b0.s1, c01);
491 c02 = fma(a0.s0, b0.s2, c02);
492 c03 = fma(a0.s0, b0.s3, c03);
493
494 c10 = fma(a0.s1, b0.s0, c10);
495 c11 = fma(a0.s1, b0.s1, c11);
496 c12 = fma(a0.s1, b0.s2, c12);
497 c13 = fma(a0.s1, b0.s3, c13);
498
499 c20 = fma(a0.s2, b0.s0, c20);
500 c21 = fma(a0.s2, b0.s1, c21);
501 c22 = fma(a0.s2, b0.s2, c22);
502 c23 = fma(a0.s2, b0.s3, c23);
503
504 c30 = fma(a0.s3, b0.s0, c30);
505 c31 = fma(a0.s3, b0.s1, c31);
506 c32 = fma(a0.s3, b0.s2, c32);
507 c33 = fma(a0.s3, b0.s3, c33);
508 }
509
510 // Compute destination address
511 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
512
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000513#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100514 // Multiply by the weight of matrix product
515 c00 = c00 * ALPHA;
516 c01 = c01 * ALPHA;
517 c02 = c02 * ALPHA;
518 c03 = c03 * ALPHA;
519 c10 = c10 * ALPHA;
520 c11 = c11 * ALPHA;
521 c12 = c12 * ALPHA;
522 c13 = c13 * ALPHA;
523 c20 = c20 * ALPHA;
524 c21 = c21 * ALPHA;
525 c22 = c22 * ALPHA;
526 c23 = c23 * ALPHA;
527 c30 = c30 * ALPHA;
528 c31 = c31 * ALPHA;
529 c32 = c32 * ALPHA;
530 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000531#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100532
Gian Marcoae2af742018-02-15 12:35:44 +0000533 // Compute dst address
534 __global uchar *dst_addr = offset(&dst, 0, 0);
535
536 // Add offset for batched GEMM
537 dst_addr += z * dst_stride_z;
538
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100539 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000540 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
541 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
542 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
543 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100544}
545
Georgios Pinitas84225582018-05-14 12:00:05 +0100546// Undefine local defines
547#undef COLS_MTX_B
548
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100549#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100550/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100551 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100552 *
Gian Marco19835e52018-01-30 13:35:54 +0000553 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
554 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
555 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000556 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
557 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100558 *
559 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
560 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
561 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
562 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
563 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
564 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100565 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100566 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
567 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
568 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
569 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
570 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100571 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100572 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000573 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100574 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000575 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100576 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
577 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100578__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
579 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000580 IMAGE_DECLARATION(dst),
581 uint src0_stride_z,
582 uint src1_stride_z,
583 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100584{
Gian Marco36a0a462018-01-12 10:21:40 +0000585 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
586 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000587 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100588
Gian Marco36a0a462018-01-12 10:21:40 +0000589 // Offset
590 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
591 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100592
Gian Marco36a0a462018-01-12 10:21:40 +0000593 // src_addr_a = address of matrix A
594 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000595 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
596 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
597
598#if defined(MATRIX_B_DEPTH)
599 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
600 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
601#else // defined(MATRIX_B_DEPTH)
602 src1_addr_in_bytes += z * src1_stride_z;
603#endif // defined(MATRIX_B_DEPTH)
604
605 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
606 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100607
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000608 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000609 __global half *src_end_addr_b = src_addr_b + COLS_B;
610
611 src_addr_a += offset_row_a;
612 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100613
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000614 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100615 half8 c00 = 0.0f;
616 half8 c10 = 0.0f;
617 half8 c20 = 0.0f;
618 half8 c30 = 0.0f;
619
Gian Marco36a0a462018-01-12 10:21:40 +0000620 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100621 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000622 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000623 half4 a0 = vload4(0, src_addr_a);
624 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100625
626 c00 += (half8)a0.s0 * b0;
627 c10 += (half8)a0.s1 * b0;
628 c20 += (half8)a0.s2 * b0;
629 c30 += (half8)a0.s3 * b0;
630
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000631 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000632 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
633 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100634
635 c00 += (half8)a0.s0 * b0;
636 c10 += (half8)a0.s1 * b0;
637 c20 += (half8)a0.s2 * b0;
638 c30 += (half8)a0.s3 * b0;
639 }
640
Gian Marco36a0a462018-01-12 10:21:40 +0000641 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100642 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000643 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000644 half4 a0 = vload4(0, src_addr_a);
645 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100646
647 c00 += (half8)a0.s0 * b0;
648 c10 += (half8)a0.s1 * b0;
649 c20 += (half8)a0.s2 * b0;
650 c30 += (half8)a0.s3 * b0;
651 }
652
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000653 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100654 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
655
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000656#if defined(ALPHA)
657 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100658 c00 = c00 * (half8)ALPHA;
659 c10 = c10 * (half8)ALPHA;
660 c20 = c20 * (half8)ALPHA;
661 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000662#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100663
Gian Marcoae2af742018-02-15 12:35:44 +0000664 // Compute dst address
665 __global uchar *dst_addr = offset(&dst, 0, 0);
666
667 // Add offset for batched GEMM
668 dst_addr += z * dst_stride_z;
669
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000670 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +0000671 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
672 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
673 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
674 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100675}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100676
677/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
678 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
679 *
680 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
681 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
682 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
683 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
684 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
685 *
686 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
687 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
688 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
689 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
690 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
691 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
692 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
693 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
694 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
695 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
696 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
697 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
698 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
699 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
700 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
701 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
702 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
703 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
704 */
705__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
706 IMAGE_DECLARATION(src1),
707 IMAGE_DECLARATION(dst),
708 uint src0_stride_z,
709 uint src1_stride_z,
710 uint dst_stride_z)
711{
712 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
713 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
714 int z = get_global_id(2);
715
716 // Offset
717 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
718 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
719
720 // src_addr_a = address of matrix A
721 // src_addr_b = address of matrix B
722 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
723 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
724
725#if defined(MATRIX_B_DEPTH)
726 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
727 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
728#else // defined(MATRIX_B_DEPTH)
729 src1_addr_in_bytes += z * src1_stride_z;
730#endif // defined(MATRIX_B_DEPTH)
731
732 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
733 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
734
735 // Compute end row address for matrix B
736 __global half *src_end_addr_b = src_addr_b + COLS_B;
737
738 src_addr_a += offset_row_a;
739 src_addr_b += offset_row_b;
740
741 // Reset accumulators
742 half8 c00 = 0.0f;
743 half8 c10 = 0.0f;
744 half8 c20 = 0.0f;
745 half8 c30 = 0.0f;
746
747#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
748
749 int i = 0;
750 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
751 {
752#if MULT_INTERLEAVE4X4_HEIGHT == 1
753 // Load values from matrix A (interleaved) and matrix B (transposed)
754 half8 a0 = vload8(0, src_addr_a);
755 half8 b0 = vload8(0, src_addr_b);
756
757 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
758 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
759
760 c00 = fma((half8)a0.s0, b0, c00);
761 c10 = fma((half8)a0.s1, b0, c10);
762 c20 = fma((half8)a0.s2, b0, c20);
763 c30 = fma((half8)a0.s3, b0, c30);
764
765 // Load values from matrix B (transposed)
766 b0 = vload8(0, src_addr_b);
767
768 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
769
770 c00 = fma((half8)a0.s4, b0, c00);
771 c10 = fma((half8)a0.s5, b0, c10);
772 c20 = fma((half8)a0.s6, b0, c20);
773 c30 = fma((half8)a0.s7, b0, c30);
774
775 // Load values from matrix A (interleaved) and matrix B (transposed)
776 a0 = vload8(0, src_addr_a);
777 b0 = vload8(0, src_addr_b);
778
779 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
780 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
781
782 c00 = fma((half8)a0.s0, b0, c00);
783 c10 = fma((half8)a0.s1, b0, c10);
784 c20 = fma((half8)a0.s2, b0, c20);
785 c30 = fma((half8)a0.s3, b0, c30);
786
787 // Load values from matrix B (transposed)
788 b0 = vload8(0, src_addr_b);
789
790 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
791
792 c00 = fma((half8)a0.s4, b0, c00);
793 c10 = fma((half8)a0.s5, b0, c10);
794 c20 = fma((half8)a0.s6, b0, c20);
795 c30 = fma((half8)a0.s7, b0, c30);
796#else // MULT_INTERLEAVE4X4_HEIGHT == 1
797 // Load values from matrix A (interleaved) and matrix B (transposed)
798 half4 a0 = vload4(0, src_addr_a);
799 half8 b0 = vload8(0, src_addr_b);
800
801 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
802 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
803
804 c00 = fma((half8)a0.s0, b0, c00);
805 c10 = fma((half8)a0.s1, b0, c10);
806 c20 = fma((half8)a0.s2, b0, c20);
807 c30 = fma((half8)a0.s3, b0, c30);
808
809 // Load values from matrix A (interleaved) and matrix B (transposed)
810 a0 = vload4(0, src_addr_a);
811 b0 = vload8(0, src_addr_b);
812
813 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
814 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
815
816 c00 = fma((half8)a0.s0, b0, c00);
817 c10 = fma((half8)a0.s1, b0, c10);
818 c20 = fma((half8)a0.s2, b0, c20);
819 c30 = fma((half8)a0.s3, b0, c30);
820
821 // Load values from matrix A (interleaved) and matrix B (transposed)
822 a0 = vload4(0, src_addr_a);
823 b0 = vload8(0, src_addr_b);
824
825 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
826 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
827
828 c00 = fma((half8)a0.s0, b0, c00);
829 c10 = fma((half8)a0.s1, b0, c10);
830 c20 = fma((half8)a0.s2, b0, c20);
831 c30 = fma((half8)a0.s3, b0, c30);
832
833 // Load values from matrix A (interleaved) and matrix B (transposed)
834 a0 = vload4(0, src_addr_a);
835 b0 = vload8(0, src_addr_b);
836
837 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
838 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
839
840 c00 = fma((half8)a0.s0, b0, c00);
841 c10 = fma((half8)a0.s1, b0, c10);
842 c20 = fma((half8)a0.s2, b0, c20);
843 c30 = fma((half8)a0.s3, b0, c30);
844#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
845 }
846
847 for(; i < (int)(COLS_MTX_B); ++i)
848 {
849 // Load values from matrix A (interleaved) and matrix B (transposed)
850 half4 a0 = vload4(0, src_addr_a);
851 half8 b0 = vload8(0, src_addr_b);
852
853 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
854 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
855
856 c00 = fma((half8)a0.s0, b0, c00);
857 c10 = fma((half8)a0.s1, b0, c10);
858 c20 = fma((half8)a0.s2, b0, c20);
859 c30 = fma((half8)a0.s3, b0, c30);
860 }
861
862 // Compute destination address
863 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
864
865#if defined(ALPHA)
866 // Multiply by the weight of matrix product
867 c00 = c00 * (half8)ALPHA;
868 c10 = c10 * (half8)ALPHA;
869 c20 = c20 * (half8)ALPHA;
870 c30 = c30 * (half8)ALPHA;
871#endif // defined(ALPHA)
872
873 // Compute dst address
874 __global uchar *dst_addr = offset(&dst, 0, 0);
875
876 // Add offset for batched GEMM
877 dst_addr += z * dst_stride_z;
878
879 // Store 4x8 block
880 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
881 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
882 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
883 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
884}
Georgios Pinitas84225582018-05-14 12:00:05 +0100885
886// Undefine local defines
887#undef COLS_MTX_B
888
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100889#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100890
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000891#if defined(FIXED_POINT_POSITION)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100892/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
893 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_8bit and @ref gemm_transpose1x16 before running the matrix multiplication
894 *
Gian Marco19835e52018-01-30 13:35:54 +0000895 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
896 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
897 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000898 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
899 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
900 * @note:ALPHA must be passed in 8 bit fixed point format
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100901 *
902 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8
903 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
904 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
905 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
906 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
907 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
908 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
909 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
910 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
911 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
912 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
913 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
914 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
915 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000916 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100917 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000918 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100919 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
920 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100921__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
922 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000923 IMAGE_DECLARATION(dst),
924 uint src0_stride_z,
925 uint src1_stride_z,
926 uint dst_stride_z)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100927{
Gian Marco36a0a462018-01-12 10:21:40 +0000928 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
929 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000930 int z = get_global_id(2);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100931
Gian Marco36a0a462018-01-12 10:21:40 +0000932 // Offset
933 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
934 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100935
Gian Marco36a0a462018-01-12 10:21:40 +0000936 // src_addr_a = address of matrix A
937 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000938 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
939 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
940
941#if defined(MATRIX_B_DEPTH)
942 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
943 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
944#else // defined(MATRIX_B_DEPTH)
945 src1_addr_in_bytes += z * src1_stride_z;
946#endif // defined(MATRIX_B_DEPTH)
947
948 __global char *src_addr_a = (__global char *)(src0_ptr + src0_addr_in_bytes);
949 __global char *src_addr_b = (__global char *)(src1_ptr + src1_addr_in_bytes);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100950
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000951 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000952 __global char *src_end_addr_b = src_addr_b + COLS_B;
953
954 src_addr_a += offset_row_a;
955 src_addr_b += offset_row_b;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100956
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000957 // Reset accumulators
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100958 short8 c00 = 0.0f;
959 short8 c10 = 0.0f;
960 short8 c20 = 0.0f;
961 short8 c30 = 0.0f;
962 short8 c01 = 0.0f;
963 short8 c11 = 0.0f;
964 short8 c21 = 0.0f;
965 short8 c31 = 0.0f;
966
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000967 // This for loop performs 1 accumulation for each iteration
Gian Marco36a0a462018-01-12 10:21:40 +0000968 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100969 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000970 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000971 char4 a0 = vload4(0, src_addr_a);
972 char16 b0 = vload16(0, src_addr_b);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100973
974 c00 = mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
975 c10 = mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
976 c20 = mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
977 c30 = mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);
978
979 c01 = mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
980 c11 = mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
981 c21 = mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
982 c31 = mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
983 }
984
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000985 // Compute destination address
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100986 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
987
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000988 // Multiply by the weight of matrix product
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100989 char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
990 char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
991 char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
992 char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));
993
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000994#if defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100995 c00_qs8 = mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
996 c10_qs8 = mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
997 c20_qs8 = mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
998 c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000999#endif // defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001000
Gian Marcoae2af742018-02-15 12:35:44 +00001001 // Compute dst address
1002 __global uchar *dst_addr = offset(&dst, 0, 0);
1003
1004 // Add offset for batched GEMM
1005 dst_addr += z * dst_stride_z;
1006
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001007 // Store 16x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00001008 vstore16(c00_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
1009 vstore16(c10_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
1010 vstore16(c20_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
1011 vstore16(c30_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001012}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001013
1014/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
1015 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
1016 *
Gian Marco19835e52018-01-30 13:35:54 +00001017 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
1018 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
1019 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001020 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1021 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1022 * @note:ALPHA must be passed in 16 bit fixed point format
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001023 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination matrix in Z dimension (in bytes)
1042 */
__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
                                                  uint dst_stride_z)
{
    // Row index of the reshaped A / B block read by this work-item. Several consecutive
    // work-items map to the same reshaped row; they are told apart by the offsets below.
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) inside the reshaped rows: 4 values of A and 8 values of B per step
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src0_addr_in_bytes = byte address of the reshaped row of matrix A
    // src1_addr_in_bytes = byte address of the reshaped row of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global short *src_addr_a = (__global short *)(src0_ptr + src0_addr_in_bytes);
    __global short *src_addr_b = (__global short *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before offset_row_b is applied)
    __global short *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators. Products are accumulated in 32-bit integers (int8) and
    // narrowed back to QS16 with saturation only once, after the loop.
    int8 c00 = (int8)0;
    int8 c10 = (int8)0;
    int8 c20 = (int8)0;
    int8 c30 = (int8)0;

    // Each iteration consumes 4 values of A (one per output row) and 8 values of B
    // (one per output column), i.e. one accumulation step of the 8x4 output block.
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        /* Load values from matrix A (interleaved) and matrix B (transposed) */
        short4 a0 = vload4(0, src_addr_a);
        short8 b0 = vload8(0, src_addr_b);

        c00 = mlal_sat_qs16x8(c00, (short8)a0.s0, b0, FIXED_POINT_POSITION);
        c10 = mlal_sat_qs16x8(c10, (short8)a0.s1, b0, FIXED_POINT_POSITION);
        c20 = mlal_sat_qs16x8(c20, (short8)a0.s2, b0, FIXED_POINT_POSITION);
        c30 = mlal_sat_qs16x8(c30, (short8)a0.s3, b0, FIXED_POINT_POSITION);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Narrow the 32-bit accumulators back to QS16, saturating on overflow
    short8 c00_qs16 = convert_short8_sat(c00);
    short8 c10_qs16 = convert_short8_sat(c10);
    short8 c20_qs16 = convert_short8_sat(c20);
    short8 c30_qs16 = convert_short8_sat(c30);

#if defined(ALPHA)
    // Multiply by the weight of matrix product (ALPHA is given in 16-bit fixed point format)
    c00_qs16 = mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c10_qs16 = mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c20_qs16 = mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 8x4 block
    vstore8(c00_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
}
#endif // defined(FIXED_POINT_POSITION)
#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)

// The non-reshaped kernels below need the number of A columns and both per-thread tile sizes at compile time
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       At most 4 rows are computed per work-item, so NUM_ELEMS_PROCESSED_PER_THREAD_Y must be in the range [1, 4]
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination matrix in Z dimension (in bytes)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
                                     uint dst_stride_z)
{
    // First output column computed by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte address one past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: each iteration consumes 2 columns of A and the matching 2 rows of B
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: handles the last column of A when COLS_A is odd
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += get_global_id(2) * dst_stride_z;

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001306
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01001307/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001308 *
1309 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1310 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       This kernel requires -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4: each work-item always loads and stores exactly 4 consecutive output columns.
1312 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1313 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001314 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1315 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001316 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination matrix in Z dimension (in bytes)
1335 */
1336__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1337 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001338 IMAGE_DECLARATION(dst),
1339 uint src0_stride_z,
1340 uint src1_stride_z,
1341 uint dst_stride_z)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001342{
1343 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1344
1345 // Compute starting address for matrix A and matrix B
1346 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1347
1348 // Update address for matrix A
1349 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1350
1351 // Update address for matrix B
1352 src_addr.s1 += idx * sizeof(float);
1353
Gian Marcoae2af742018-02-15 12:35:44 +00001354 // Add offset for batched GEMM
1355 src_addr.s0 += get_global_id(2) * src0_stride_z;
1356
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001357#if defined(MATRIX_B_DEPTH)
1358 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1359 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1360#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001361 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001362#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001363
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001364 // Initialize accumulators
1365 float acc00 = 0.0f;
1366 float acc01 = 0.0f;
1367 float acc02 = 0.0f;
1368 float acc03 = 0.0f;
1369
1370#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1371 float acc10 = 0.0f;
1372 float acc11 = 0.0f;
1373 float acc12 = 0.0f;
1374 float acc13 = 0.0f;
1375#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1376
1377#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1378 float acc20 = 0.0f;
1379 float acc21 = 0.0f;
1380 float acc22 = 0.0f;
1381 float acc23 = 0.0f;
1382#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1383
1384#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1385 float acc30 = 0.0f;
1386 float acc31 = 0.0f;
1387 float acc32 = 0.0f;
1388 float acc33 = 0.0f;
1389#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1390
1391 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001392 int i = 0;
1393 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001394 {
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001395 // Load values from matrix A and matrix B
1396 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001397#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001398 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001399#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1400#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001401 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001402#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1403#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001404 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001405#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001406 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1407 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001408
1409 // Multiply and accumulate
1410 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001411 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001412 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001413 acc03 = fma(a0.s0, b0.s3, acc03);
1414
1415#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001416
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001417 acc10 = fma(a1.s0, b0.s0, acc10);
1418 acc11 = fma(a1.s0, b0.s1, acc11);
1419 acc12 = fma(a1.s0, b0.s2, acc12);
1420 acc13 = fma(a1.s0, b0.s3, acc13);
1421
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001422#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1423#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001424
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001425 acc20 = fma(a2.s0, b0.s0, acc20);
1426 acc21 = fma(a2.s0, b0.s1, acc21);
1427 acc22 = fma(a2.s0, b0.s2, acc22);
1428 acc23 = fma(a2.s0, b0.s3, acc23);
1429
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001430#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1431#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001432
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001433 acc30 = fma(a3.s0, b0.s0, acc30);
1434 acc31 = fma(a3.s0, b0.s1, acc31);
1435 acc32 = fma(a3.s0, b0.s2, acc32);
1436 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001437#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001438
1439 // Load values from matrix A and matrix B
1440 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1441 src_addr.s1 += src1_stride_y;
1442
1443 // Multiply and accumulate
1444 acc00 = fma(a0.s1, b0.s0, acc00);
1445 acc01 = fma(a0.s1, b0.s1, acc01);
1446 acc02 = fma(a0.s1, b0.s2, acc02);
1447 acc03 = fma(a0.s1, b0.s3, acc03);
1448
1449#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1450
1451 acc10 = fma(a1.s1, b0.s0, acc10);
1452 acc11 = fma(a1.s1, b0.s1, acc11);
1453 acc12 = fma(a1.s1, b0.s2, acc12);
1454 acc13 = fma(a1.s1, b0.s3, acc13);
1455
1456#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1457#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1458
1459 acc20 = fma(a2.s1, b0.s0, acc20);
1460 acc21 = fma(a2.s1, b0.s1, acc21);
1461 acc22 = fma(a2.s1, b0.s2, acc22);
1462 acc23 = fma(a2.s1, b0.s3, acc23);
1463
1464#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1465#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1466
1467 acc30 = fma(a3.s1, b0.s0, acc30);
1468 acc31 = fma(a3.s1, b0.s1, acc31);
1469 acc32 = fma(a3.s1, b0.s2, acc32);
1470 acc33 = fma(a3.s1, b0.s3, acc33);
1471#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1472
1473 // Load values from matrix A and matrix B
1474 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1475 src_addr.s1 += src1_stride_y;
1476
1477 // Multiply and accumulate
1478 acc00 = fma(a0.s2, b0.s0, acc00);
1479 acc01 = fma(a0.s2, b0.s1, acc01);
1480 acc02 = fma(a0.s2, b0.s2, acc02);
1481 acc03 = fma(a0.s2, b0.s3, acc03);
1482
1483#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1484
1485 acc10 = fma(a1.s2, b0.s0, acc10);
1486 acc11 = fma(a1.s2, b0.s1, acc11);
1487 acc12 = fma(a1.s2, b0.s2, acc12);
1488 acc13 = fma(a1.s2, b0.s3, acc13);
1489
1490#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1491#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1492
1493 acc20 = fma(a2.s2, b0.s0, acc20);
1494 acc21 = fma(a2.s2, b0.s1, acc21);
1495 acc22 = fma(a2.s2, b0.s2, acc22);
1496 acc23 = fma(a2.s2, b0.s3, acc23);
1497
1498#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1499#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1500
1501 acc30 = fma(a3.s2, b0.s0, acc30);
1502 acc31 = fma(a3.s2, b0.s1, acc31);
1503 acc32 = fma(a3.s2, b0.s2, acc32);
1504 acc33 = fma(a3.s2, b0.s3, acc33);
1505#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1506
1507 // Load values from matrix A and matrix B
1508 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1509 src_addr.s1 += src1_stride_y;
1510
1511 // Multiply and accumulate
1512 acc00 = fma(a0.s3, b0.s0, acc00);
1513 acc01 = fma(a0.s3, b0.s1, acc01);
1514 acc02 = fma(a0.s3, b0.s2, acc02);
1515 acc03 = fma(a0.s3, b0.s3, acc03);
1516
1517#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1518
1519 acc10 = fma(a1.s3, b0.s0, acc10);
1520 acc11 = fma(a1.s3, b0.s1, acc11);
1521 acc12 = fma(a1.s3, b0.s2, acc12);
1522 acc13 = fma(a1.s3, b0.s3, acc13);
1523
1524#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1525#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1526
1527 acc20 = fma(a2.s3, b0.s0, acc20);
1528 acc21 = fma(a2.s3, b0.s1, acc21);
1529 acc22 = fma(a2.s3, b0.s2, acc22);
1530 acc23 = fma(a2.s3, b0.s3, acc23);
1531
1532#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1533#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1534
1535 acc30 = fma(a3.s3, b0.s0, acc30);
1536 acc31 = fma(a3.s3, b0.s1, acc31);
1537 acc32 = fma(a3.s3, b0.s2, acc32);
1538 acc33 = fma(a3.s3, b0.s3, acc33);
1539#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1540
1541 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001542 }
1543
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001544 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001545 {
1546 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001547 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001548#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1549 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1550#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1551#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1552 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1553#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1554#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1555 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1556#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1557 // Load values from matrix B
1558 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001559 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001560
1561 // Multiply and accumulate
1562 acc00 = fma(a0, b0.s0, acc00);
1563 acc01 = fma(a0, b0.s1, acc01);
1564 acc02 = fma(a0, b0.s2, acc02);
1565 acc03 = fma(a0, b0.s3, acc03);
1566#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1567 acc10 = fma(a1, b0.s0, acc10);
1568 acc11 = fma(a1, b0.s1, acc11);
1569 acc12 = fma(a1, b0.s2, acc12);
1570 acc13 = fma(a1, b0.s3, acc13);
1571#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1572#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1573 acc20 = fma(a2, b0.s0, acc20);
1574 acc21 = fma(a2, b0.s1, acc21);
1575 acc22 = fma(a2, b0.s2, acc22);
1576 acc23 = fma(a2, b0.s3, acc23);
1577#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1578#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1579 acc30 = fma(a3, b0.s0, acc30);
1580 acc31 = fma(a3, b0.s1, acc31);
1581 acc32 = fma(a3, b0.s2, acc32);
1582 acc33 = fma(a3, b0.s3, acc33);
1583#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001584
1585 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001586 }
1587
1588 // Compute destination address
1589 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1590
1591 // Multiply by the weight of matrix-matrix product and store the result
1592#if defined(ALPHA)
1593 acc00 = acc00 * ALPHA;
1594 acc01 = acc01 * ALPHA;
1595 acc02 = acc02 * ALPHA;
1596 acc03 = acc03 * ALPHA;
1597#endif // defined(ALPHA)
1598
Gian Marcoae2af742018-02-15 12:35:44 +00001599 // Compute dst address
1600 __global uchar *dst_addr = offset(&dst, 0, 0);
1601
1602 // Add offset for batched GEMM
1603 dst_addr += get_global_id(2) * dst_stride_z;
1604
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001605 float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
Gian Marcoae2af742018-02-15 12:35:44 +00001606 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001607
1608#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1609#if defined(ALPHA)
1610 acc10 = acc10 * ALPHA;
1611 acc11 = acc11 * ALPHA;
1612 acc12 = acc12 * ALPHA;
1613 acc13 = acc13 * ALPHA;
1614#endif // defined(ALPHA)
1615 float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
Gian Marcoae2af742018-02-15 12:35:44 +00001616 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001617#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1618#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1619#if defined(ALPHA)
1620 acc20 = acc20 * ALPHA;
1621 acc21 = acc21 * ALPHA;
1622 acc22 = acc22 * ALPHA;
1623 acc23 = acc23 * ALPHA;
1624#endif // defined(ALPHA)
1625 float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
Gian Marcoae2af742018-02-15 12:35:44 +00001626 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001627#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1628#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1629#if defined(ALPHA)
1630 acc30 = acc30 * ALPHA;
1631 acc31 = acc31 * ALPHA;
1632 acc32 = acc32 * ALPHA;
1633 acc33 = acc33 * ALPHA;
1634#endif // defined(ALPHA)
1635 float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
Gian Marcoae2af742018-02-15 12:35:44 +00001636 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001637#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1638}
1639
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) only and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item computes exactly 2 output columns (matrix B is read with vload2 and the result is written with vstore2),
 *       so this kernel must be compiled with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
 *       At most 4 output rows per work-item are handled (-DNUM_ELEMS_PROCESSED_PER_THREAD_Y between 1 and 4).
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix A in Z dimension (in bytes), i.e. the distance between batches
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes). Wrapped with MATRIX_B_DEPTH when that is defined
 * @param[in]  dst_stride_z                       Stride of the destination matrix in Z dimension (in bytes), i.e. the distance between batches
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
                                                      uint dst_stride_z)
{
    // Each work-item produces a 2-wide (C vect2) strip of up to 4 rows: A is read 8 floats at a time (vload8),
    // B is read as 8 consecutive rows of 2 floats (vload2). NUM_ELEMS_PROCESSED_PER_THREAD_X must therefore be 2.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators: accRC holds output row R, column C of this work-item's tile
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: consume 8 columns of A (= 8 rows of B) per iteration.
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);
        acc00 = fma(a0.s4, b4.s0, acc00);
        acc00 = fma(a0.s5, b5.s0, acc00);
        acc00 = fma(a0.s6, b6.s0, acc00);
        acc00 = fma(a0.s7, b7.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);
        acc01 = fma(a0.s4, b4.s1, acc01);
        acc01 = fma(a0.s5, b5.s1, acc01);
        acc01 = fma(a0.s6, b6.s1, acc01);
        acc01 = fma(a0.s7, b7.s1, acc01);

        // The next rows of A reuse the a0 register to keep register pressure low
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);
        acc10 = fma(a0.s4, b4.s0, acc10);
        acc10 = fma(a0.s5, b5.s0, acc10);
        acc10 = fma(a0.s6, b6.s0, acc10);
        acc10 = fma(a0.s7, b7.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
        acc11 = fma(a0.s4, b4.s1, acc11);
        acc11 = fma(a0.s5, b5.s1, acc11);
        acc11 = fma(a0.s6, b6.s1, acc11);
        acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);
        acc20 = fma(a0.s4, b4.s0, acc20);
        acc20 = fma(a0.s5, b5.s0, acc20);
        acc20 = fma(a0.s6, b6.s0, acc20);
        acc20 = fma(a0.s7, b7.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
        acc21 = fma(a0.s4, b4.s1, acc21);
        acc21 = fma(a0.s5, b5.s1, acc21);
        acc21 = fma(a0.s6, b6.s1, acc21);
        acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);
        acc30 = fma(a0.s4, b4.s0, acc30);
        acc30 = fma(a0.s5, b5.s0, acc30);
        acc30 = fma(a0.s6, b6.s0, acc30);
        acc30 = fma(a0.s7, b7.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
        acc31 = fma(a0.s4, b4.s1, acc31);
        acc31 = fma(a0.s5, b5.s1, acc31);
        acc31 = fma(a0.s6, b6.s1, acc31);
        acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float) * 8;
    }

    // Leftover loop: handle the remaining COLS_A % 8 columns one float at a time
    for(; i < (int)COLS_A; ++i)
    {
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += get_global_id(2) * dst_stride_z;

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#endif // defined(ALPHA)
    float2 acc0 = ((float2)(acc00, acc01));
    vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // defined(ALPHA)
    float2 acc1 = ((float2)(acc10, acc11));
    vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // defined(ALPHA)
    float2 acc2 = ((float2)(acc20, acc21));
    vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // defined(ALPHA)
    float2 acc3 = ((float2)(acc30, acc31));
    vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
1901
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item computes 8 output columns (matrix B is read with vload8 and the result is written with vstore8),
 *       so this kernel must be compiled with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8.
 *       At most 4 output rows per work-item are handled (-DNUM_ELEMS_PROCESSED_PER_THREAD_Y between 1 and 4).
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix A in Z dimension (in bytes), i.e. the distance between batches
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes). Wrapped with MATRIX_B_DEPTH when that is defined
 * @param[in]  dst_stride_z                       Stride of the destination matrix in Z dimension (in bytes), i.e. the distance between batches
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z)
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half8 accumulator per output row: 8 consecutive output columns each
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of A (= 4 rows of B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate: each scalar of A is broadcast against a full row of B
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    // Leftover loop: handle the remaining COLS_A % 4 columns one half at a time
    for(; i < (int)COLS_A; ++i)
    {
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance A by one element along the row and B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += get_global_id(2) * dst_stride_z;

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc0 = acc0 * (half8)ALPHA;
#endif // defined(ALPHA)
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
    acc1 = acc1 * (half8)ALPHA;
#endif // defined(ALPHA)
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
    acc2 = acc2 * (half8)ALPHA;
#endif // defined(ALPHA)
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
    acc3 = acc3 * (half8)ALPHA;
#endif // defined(ALPHA)
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002108
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002109#if defined(FIXED_POINT_POSITION)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002110/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002111 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002112 * @note This OpenCL kernel works with fixed point data types QS8
2113 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002114 * @note The number matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002115 * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002116 * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002117 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2118 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002119 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002120 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8/QS16
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002121 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2122 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2123 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2124 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2125 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2126 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2127 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2128 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2129 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2130 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2131 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2132 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2133 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2134 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2135 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2136 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2137 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2138 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002139__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002140 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002141 IMAGE_DECLARATION(dst),
2142 uint src0_stride_z,
2143 uint src1_stride_z,
2144 uint dst_stride_z)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002145{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002146 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002147
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002148 // Compute starting address for matrix A and Matrix B
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002149 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002150
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002151 // Update address for the matrix A
2152 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002153
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002154 // Update address for the matrix B
2155 src_addr.s1 += idx * sizeof(char);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002156
Gian Marcoae2af742018-02-15 12:35:44 +00002157 // Add offset for batched GEMM
2158 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002159
2160#if defined(MATRIX_B_DEPTH)
2161 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2162 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2163#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002164 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002165#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002166
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002167 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));
2168
2169 short8 acc00 = 0;
2170 short8 acc01 = 0;
2171#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2172 short8 acc10 = 0;
2173 short8 acc11 = 0;
2174#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2175#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2176 short8 acc20 = 0;
2177 short8 acc21 = 0;
2178#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2179#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2180 short8 acc30 = 0;
2181 short8 acc31 = 0;
2182#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2183
2184 // This for loop performs 4 accumulations per iteration
2185 for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002186 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002187 char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2188#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2189 char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2190#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2191#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2192 char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2193#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2194#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2195 char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2196#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002197 char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2198 char16 b1 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002199
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002200 acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
2201 acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
2202 acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2203 acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2204#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2205 acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
2206 acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
2207 acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2208 acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2209#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2210#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2211 acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
2212 acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
2213 acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2214 acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2215#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2216#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2217 acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
2218 acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
2219 acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2220 acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2221#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002222 }
2223
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002224 // Left-over accumulations
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002225 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
2226 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002227 char a0 = *((__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2228#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2229 char a1 = *((__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2230#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2231#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2232 char a2 = *((__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2233#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2234#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2235 char a3 = *((__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2236#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002237 char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1));
2238
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002239 acc00 = mlal_sat_qs8x8(acc00, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
2240 acc01 = mlal_sat_qs8x8(acc01, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2241#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2242 acc10 = mlal_sat_qs8x8(acc10, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
2243 acc11 = mlal_sat_qs8x8(acc11, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
2244#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2245#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2246 acc20 = mlal_sat_qs8x8(acc20, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
2247 acc21 = mlal_sat_qs8x8(acc21, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
2248#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2249#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2250 acc30 = mlal_sat_qs8x8(acc30, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
2251 acc31 = mlal_sat_qs8x8(acc31, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
2252#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002253 }
2254
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002255 // Compute destination address
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002256 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2257
Gian Marcoae2af742018-02-15 12:35:44 +00002258 // Compute dst address
2259 __global uchar *dst_addr = offset(&dst, 0, 0);
2260
2261 // Add offset for batched GEMM
2262 dst_addr += get_global_id(2) * dst_stride_z;
2263
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002264 // Multiply by the weight of matrix product and store the result
2265 char16 acc_qs8;
2266 acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002267#if defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002268 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002269#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002270 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002271#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2272 acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002273#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002274 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002275#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002276 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002277#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2278#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2279 acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002280#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002281 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002282#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002283 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002284#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2285#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2286 acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002287#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002288 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002289#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002290 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002291#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002292}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002293
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002294/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002295 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002296 * @note This OpenCL kernel works with fixed point data types QS16
2297 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002298 * @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002299 * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002300 * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002301 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2302 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002303 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002304 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8/QS16
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002305 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2306 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2307 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2308 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2309 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2310 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2311 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2312 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2313 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2314 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2315 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2316 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2317 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2318 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2319 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2320 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2321 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2322 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002323__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002324 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002325 IMAGE_DECLARATION(dst),
2326 uint src0_stride_z,
2327 uint src1_stride_z,
2328 uint dst_stride_z)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002329{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002330 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002331
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002332 // Compute starting address for matrix A and Matrix B
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002333 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002334
2335 // Update address for the matrix A
2336 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2337
2338 // Update address for the matrix B
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002339 src_addr.s1 += idx * sizeof(short);
2340
Gian Marcoae2af742018-02-15 12:35:44 +00002341 // Add offset for batched GEMM
2342 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002343
2344#if defined(MATRIX_B_DEPTH)
2345 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2346 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2347#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002348 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002349#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002350
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002351 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002352
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002353 int8 acc0 = 0;
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002354#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2355 int8 acc1 = 0;
2356#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2357#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2358 int8 acc2 = 0;
2359#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2360#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2361 int8 acc3 = 0;
2362#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002363
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002364 // This for loop performs 4 accumulations per iteration
Georgios Pinitas96880cf2017-10-20 18:52:20 +01002365 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(short)); src_addr += (int2)(2 * sizeof(short), 2 * src1_stride_y))
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002366 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002367 short2 a0 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2368#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2369 short2 a1 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2370#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2371#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2372 short2 a2 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2373#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2374#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2375 short2 a3 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2376#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002377 short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2378 short8 b1 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002379
2380 acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
2381 acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002382#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2383 acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
2384 acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
2385#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2386#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2387 acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
2388 acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
2389#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2390#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2391 acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
2392 acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
2393#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002394 }
2395
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002396 // Left-over accumulations
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002397 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(short), src1_stride_y))
2398 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002399 short a0 = *((__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2400#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2401 short a1 = *((__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2402#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2403#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2404 short a2 = *((__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2405#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2406#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2407 short a3 = *((__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2408#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002409 short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1));
2410
2411 acc0 = mlal_sat_qs16x8(acc0, (short8)a0, b0, FIXED_POINT_POSITION);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002412#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2413 acc1 = mlal_sat_qs16x8(acc1, (short8)a1, b0, FIXED_POINT_POSITION);
2414#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2415#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2416 acc2 = mlal_sat_qs16x8(acc2, (short8)a2, b0, FIXED_POINT_POSITION);
2417#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2418#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2419 acc3 = mlal_sat_qs16x8(acc3, (short8)a3, b0, FIXED_POINT_POSITION);
2420#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002421 }
2422
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002423 // Compute destination address
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002424 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2425
Gian Marcoae2af742018-02-15 12:35:44 +00002426 // Compute dst address
2427 __global uchar *dst_addr = offset(&dst, 0, 0);
2428
Gian Marco Iodice81b28c42018-03-29 10:29:36 +01002429 // Add offset for batched GEMM
2430 dst_addr += get_global_id(2) * dst_stride_z;
2431
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002432 // Multiply by the weight of matrix product and store the result
2433 short8 acc_qs16;
2434 acc_qs16 = convert_short8_sat(acc0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002435#if defined(ALPHA)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002436 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002437#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002438 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002439#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2440 acc_qs16 = convert_short8_sat(acc1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002441#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002442 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002443#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002444 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002445#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2446#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2447 acc_qs16 = convert_short8_sat(acc2);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002448#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002449 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002450#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002451 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002452#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2453#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2454 acc_qs16 = convert_short8_sat(acc3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002455#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002456 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002457#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002458 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002459#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002460}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002461#endif // defined(FIXED_POINT_POSITION)
2462#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002463
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002464#if defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002465/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
2466 *
Gian Marco19835e52018-01-30 13:35:54 +00002467 * @note The beta's value need to be passed at compile time using -DBETA
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002468 *
2469 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
2470 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2471 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2472 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2473 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2474 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002475 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002476 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2477 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2478 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2479 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2480 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2481 */
2482__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
2483 IMAGE_DECLARATION(dst))
2484{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002485 // Compute source and destination addresses
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002486 Image src = CONVERT_TO_IMAGE_STRUCT(src);
2487 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2488
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002489 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002490 float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
2491
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002492 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002493 float4 c = vload4(0, (__global float *)src.ptr);
2494
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002495 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002496 float4 out = alpha_ab + (float4)BETA * c;
2497
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002498 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002499 vstore4(out, 0, (__global float *)dst.ptr);
2500}
2501
2502/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
2503 *
Gian Marco19835e52018-01-30 13:35:54 +00002504 * @note The beta's value need to be passed at compile time using -DBETA
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002505 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002506 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
2507 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2508 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2509 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2510 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2511 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002512 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002513 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2514 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2515 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2516 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2517 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2518 */
2519__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
2520 IMAGE_DECLARATION(dst))
2521{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002522 // Compute source and destination addresses
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002523 Image src = CONVERT_TO_IMAGE_STRUCT(src);
2524 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2525
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002526 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002527 half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
2528
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002529 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002530 half8 c = vload8(0, (__global half *)src.ptr);
2531
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002532 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002533 half8 out = alpha_ab + (half8)BETA * c;
2534
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002535 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002536 vstore8(out, 0, (__global half *)dst.ptr);
2537}
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002538
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002539#if defined(FIXED_POINT_POSITION)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002540/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 8 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
2541 *
Gian Marco19835e52018-01-30 13:35:54 +00002542 * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002543 *
2544 * @note: BETA must be passed in 8 bit fixed point format
2545 *
2546 * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS8
2547 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2548 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2549 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2550 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2551 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
2552 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
2553 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2554 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2555 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2556 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2557 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2558 */
__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Position both images on the 16 QS8 elements owned by this work-item
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    Image src = CONVERT_TO_IMAGE_STRUCT(src);

    __global char *axb_ptr = (__global char *)dst.ptr;
    __global char *c_ptr   = (__global char *)src.ptr;

    // dst already holds the A x B product (alpha applied upstream); src holds matrix C
    char16 axb_vals = vload16(0, axb_ptr);
    char16 c_vals   = vload16(0, c_ptr);

    // axb + BETA * c in QS8 fixed point, written back over the A x B values
    vstore16(mla_sat_qs8x16(axb_vals, (char16)BETA, c_vals, FIXED_POINT_POSITION), 0, axb_ptr);
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002578
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 16 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
 *
 *   dst = dst + BETA * src
 *
 * dst is read as the A x B result (alpha already applied) and overwritten with the sum; src holds matrix C.
 * Each work-item processes 8 consecutive QS16 elements.
 *
 * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 *
 * @note BETA must be passed in 16 bit fixed point format
 *
 * @param[in]     src_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]     src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]     src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]     src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in,out] dst_ptr                           Pointer to the destination matrix (read, then overwritten in place). Supported data types: same as @p src_ptr
 * @param[in]     dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]     dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]     dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
                           IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B (already scaled by alpha, stored in dst)
    short8 alpha_ab = vload8(0, (__global short *)dst.ptr);

    // Load values from Matrix C
    short8 c = vload8(0, (__global short *)src.ptr);

    // Computes alpha * axb + beta * c
    short8 out = mla_sat_qs16x8(alpha_ab, (short8)BETA, c, FIXED_POINT_POSITION);

    // Store final result in axb matrix
    vstore8(out, 0, (__global short *)dst.ptr);
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002617#endif // defined(FIXED_POINT_POSITION)
2618#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002619
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002620#if defined(WIDTH_VECTOR_A)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002621/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
2622 *
Gian Marco19835e52018-01-30 13:35:54 +00002623 * @note The width of A need to be passed at compile time using -DWIDTH_VECTOR_A
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002624 *
Gian Marco19835e52018-01-30 13:35:54 +00002625 * @note The input A and matrix B must not be reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002626 *
2627 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2628 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2629 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2630 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2631 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2632 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002633 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002634 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2635 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2636 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2637 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2638 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2639 * @param[in] src1_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
2640 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002641 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002642 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2643 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2644 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2645 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2646 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2647 */
2648__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
2649 TENSOR3D_DECLARATION(src1),
2650 IMAGE_DECLARATION(dst))
2651{
2652 int idx = get_global_id(0) * 4;
2653 int idy = get_global_id(1);
2654
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002655 // Compute the address for the vector A and matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002656 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
2657 src_addr.s1 += idx * sizeof(float);
2658
2659 int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));
2660
2661 float4 acc = 0.0f;
2662
Georgios Pinitas96880cf2017-10-20 18:52:20 +01002663 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002664 {
2665 float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
2666 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
2667 float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));
2668
2669 acc += b0 * (float4)a0.s0;
2670 acc += b1 * (float4)a0.s1;
2671 }
2672
2673 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
2674 {
2675 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
2676 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
2677
2678 acc += b0 * (float4)a0;
2679 }
2680
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002681 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002682 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2683
2684 vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
2685}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002686#endif // defined(WIDTH_VECTOR_A)
2687
/** This kernel accumulates each row with the biases vector.
 *
 * Each work-item adds VECTOR_SIZE bias values to VECTOR_SIZE elements of one row of the accumulate tensor, in place.
 * The biases are declared with VECTOR_DECLARATION (no Y stride), so the same bias values are added to every row.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 * @note If -DFIXED_POINT_POSITION is defined, the addition uses ADD_SAT_OP_EXPAND (saturating fixed point add); otherwise a plain '+' is used.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/QS8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    // Load VECTOR_SIZE elements of the current accumulate row and the matching bias values
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
#ifdef FIXED_POINT_POSITION
    // Fixed point types: saturating add
    accum_value = ADD_SAT_OP_EXPAND(biases_value, accum_value, DATA_TYPE, VECTOR_SIZE);
#else  // FIXED_POINT_POSITION
    accum_value = biases_value + accum_value;
#endif // FIXED_POINT_POSITION
    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)