blob: 00b130f5a9a0057f9002611be5a92208ca990101 [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "helpers.h"

// The fixed-point helpers are only pulled in when the build defines FIXED_POINT_POSITION
#ifdef FIXED_POINT_POSITION
#include "fixed_point.h"
#endif // FIXED_POINT_POSITION

#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// The kernel below only loads and stores elements without doing arithmetic on them,
// so the element size in bytes is mapped onto an unsigned integer type of the same width.
#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * Each work-item loads one vector of TRANSPOSE_W elements from the source and stores it, unchanged, at a
 * destination address derived from its global ids (see the address computation in the kernel body).
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The size in bytes of one element must be passed at compile time using -DELEMENT_SIZE (e.g. -DELEMENT_SIZE=4). Only 1, 2 and 4 are accepted; any other value triggers a compile-time error
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for Matrix B - source
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix B transposed - destination. X and Y are swapped:
    //  - (x / MULT_TRANSPOSE1XW_WIDTH) selects the destination row (scaled by dst_stride_y)
    //  - y selects a group of MULT_TRANSPOSE1XW_WIDTH vectors along that row
    //  - (x % MULT_TRANSPOSE1XW_WIDTH) selects the vector slot inside the group
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
                             (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    // Copy one vector of TRANSPOSE_W elements from the source to the computed destination
    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
 *
 * Each work-item reads four 4-element vectors from four consecutive source rows (src_stride_y apart) and
 * writes the four columns of that 4x4 block as 4-element vectors to the destination.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for source tensor
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for the interleaved destination:
    //  - (y / MULT_INTERLEAVE4X4_HEIGHT) selects the destination row (scaled by dst_stride_y)
    //  - x selects a group of MULT_INTERLEAVE4X4_HEIGHT 4x4 blocks (16 elements each) along that row
    //  - (y % MULT_INTERLEAVE4X4_HEIGHT) selects the 4-element slot inside the group
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    __global uchar *input_ptr = src.ptr;

    // Load a 4x4 block: one 4-element vector from each of four consecutive source rows
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));

    // Store the block column by column. Consecutive columns are 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart,
    // which leaves room for the columns written by the other work-items of the same group.
    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 0 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 4 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 8 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 12 * MULT_INTERLEAVE4X4_HEIGHT));
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *  (the kernel reads both matrices four floats at a time, so a transposition width of 4 is presumably expected - confirm against the host-side configuration)
 *
 * Each work-item accumulates and stores one 4x4 block of the destination matrix.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of matrix A (src0) in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of matrix B (src1) in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination matrix in Z dimension (in bytes)
 */
__kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z)
{
    // Row of reshaped matrix B (x) and of reshaped matrix A (y) read by this work-item, plus the batch index (z)
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-float slot inside the reshaped rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken from the row start, i.e. before the slot offset is applied)
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cN0 holds row N of the 4x4 output block
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // Main loop: two accumulation steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Left-over loop: one accumulation step per iteration until the end of the row of matrix B
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
}

/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *  (the kernel reads both matrices four floats at a time, so a transposition width of 4 is presumably expected - confirm against the host-side configuration)
 *
 * Each work-item accumulates and stores one 4x4 block of the destination matrix, keeping the 16 results in scalar
 * accumulators updated with fma().
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of matrix A (src0) in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of matrix B (src1) in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination matrix in Z dimension (in bytes)
 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z)
{
    // Row of reshaped matrix B (x) and of reshaped matrix A (y) read by this work-item, plus the batch index (z)
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-float slot inside the reshaped rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken from the row start, i.e. before the slot offset is applied)
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cRC holds element (row R, column C) of the 4x4 output block
    float c00 = 0.0f;
    float c01 = 0.0f;
    float c02 = 0.0f;
    float c03 = 0.0f;
    float c10 = 0.0f;
    float c11 = 0.0f;
    float c12 = 0.0f;
    float c13 = 0.0f;
    float c20 = 0.0f;
    float c21 = 0.0f;
    float c22 = 0.0f;
    float c23 = 0.0f;
    float c30 = 0.0f;
    float c31 = 0.0f;
    float c32 = 0.0f;
    float c33 = 0.0f;

    // Main loop: four accumulation steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += (16 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (16 * MULT_TRANSPOSE1XW_WIDTH))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 8 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 12 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 12 * MULT_TRANSPOSE1XW_WIDTH);

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);
    }

    // Left-over loop: one accumulation step per iteration until the end of the row of matrix B
    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * MULT_TRANSPOSE1XW_WIDTH))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * ALPHA;
    c01 = c01 * ALPHA;
    c02 = c02 * ALPHA;
    c03 = c03 * ALPHA;
    c10 = c10 * ALPHA;
    c11 = c11 * ALPHA;
    c12 = c12 * ALPHA;
    c13 = c13 * ALPHA;
    c20 = c20 * ALPHA;
    c21 = c21 * ALPHA;
    c22 = c22 * ALPHA;
    c23 = c23 * ALPHA;
    c30 = c30 * ALPHA;
    c31 = c31 * ALPHA;
    c32 = c32 * ALPHA;
    c33 = c33 * ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
}

Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100531#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100532/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100533 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100534 *
Gian Marco19835e52018-01-30 13:35:54 +0000535 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
536 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
537 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000538 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
539 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100540 *
541 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
542 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
543 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
544 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
545 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
546 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100547 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100548 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
549 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
550 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
551 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
552 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100553 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100554 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000555 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100556 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000557 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100558 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
559 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100560__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
561 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000562 IMAGE_DECLARATION(dst),
563 uint src0_stride_z,
564 uint src1_stride_z,
565 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100566{
Gian Marco36a0a462018-01-12 10:21:40 +0000567 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
568 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000569 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100570
Gian Marco36a0a462018-01-12 10:21:40 +0000571 // Offset
572 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
573 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100574
Gian Marco36a0a462018-01-12 10:21:40 +0000575 // src_addr_a = address of matrix A
576 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000577 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
578 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
579
580#if defined(MATRIX_B_DEPTH)
581 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
582 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
583#else // defined(MATRIX_B_DEPTH)
584 src1_addr_in_bytes += z * src1_stride_z;
585#endif // defined(MATRIX_B_DEPTH)
586
587 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
588 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100589
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000590 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000591 __global half *src_end_addr_b = src_addr_b + COLS_B;
592
593 src_addr_a += offset_row_a;
594 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100595
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000596 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100597 half8 c00 = 0.0f;
598 half8 c10 = 0.0f;
599 half8 c20 = 0.0f;
600 half8 c30 = 0.0f;
601
Gian Marco36a0a462018-01-12 10:21:40 +0000602 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100603 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000604 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000605 half4 a0 = vload4(0, src_addr_a);
606 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100607
608 c00 += (half8)a0.s0 * b0;
609 c10 += (half8)a0.s1 * b0;
610 c20 += (half8)a0.s2 * b0;
611 c30 += (half8)a0.s3 * b0;
612
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000613 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000614 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
615 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100616
617 c00 += (half8)a0.s0 * b0;
618 c10 += (half8)a0.s1 * b0;
619 c20 += (half8)a0.s2 * b0;
620 c30 += (half8)a0.s3 * b0;
621 }
622
Gian Marco36a0a462018-01-12 10:21:40 +0000623 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100624 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000625 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000626 half4 a0 = vload4(0, src_addr_a);
627 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100628
629 c00 += (half8)a0.s0 * b0;
630 c10 += (half8)a0.s1 * b0;
631 c20 += (half8)a0.s2 * b0;
632 c30 += (half8)a0.s3 * b0;
633 }
634
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000635 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100636 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
637
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000638#if defined(ALPHA)
639 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100640 c00 = c00 * (half8)ALPHA;
641 c10 = c10 * (half8)ALPHA;
642 c20 = c20 * (half8)ALPHA;
643 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000644#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100645
Gian Marcoae2af742018-02-15 12:35:44 +0000646 // Compute dst address
647 __global uchar *dst_addr = offset(&dst, 0, 0);
648
649 // Add offset for batched GEMM
650 dst_addr += z * dst_stride_z;
651
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000652 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +0000653 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
654 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
655 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
656 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100657}
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100658#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100659
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000660#if defined(FIXED_POINT_POSITION)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100661/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
662 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_8bit and @ref gemm_transpose1x16 before running the matrix multiplication
663 *
Gian Marco19835e52018-01-30 13:35:54 +0000664 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
665 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
666 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000667 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
668 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
669 * @note:ALPHA must be passed in 8 bit fixed point format
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100670 *
671 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8
672 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
673 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
674 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
675 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
676 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
677 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
678 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
679 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
680 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
681 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
682 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
683 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
684 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000685 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100686 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000687 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100688 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
689 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100690__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
691 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000692 IMAGE_DECLARATION(dst),
693 uint src0_stride_z,
694 uint src1_stride_z,
695 uint dst_stride_z)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100696{
Gian Marco36a0a462018-01-12 10:21:40 +0000697 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
698 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000699 int z = get_global_id(2);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100700
Gian Marco36a0a462018-01-12 10:21:40 +0000701 // Offset
702 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
703 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100704
Gian Marco36a0a462018-01-12 10:21:40 +0000705 // src_addr_a = address of matrix A
706 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000707 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
708 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
709
710#if defined(MATRIX_B_DEPTH)
711 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
712 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
713#else // defined(MATRIX_B_DEPTH)
714 src1_addr_in_bytes += z * src1_stride_z;
715#endif // defined(MATRIX_B_DEPTH)
716
717 __global char *src_addr_a = (__global char *)(src0_ptr + src0_addr_in_bytes);
718 __global char *src_addr_b = (__global char *)(src1_ptr + src1_addr_in_bytes);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100719
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000720 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000721 __global char *src_end_addr_b = src_addr_b + COLS_B;
722
723 src_addr_a += offset_row_a;
724 src_addr_b += offset_row_b;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100725
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000726 // Reset accumulators
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100727 short8 c00 = 0.0f;
728 short8 c10 = 0.0f;
729 short8 c20 = 0.0f;
730 short8 c30 = 0.0f;
731 short8 c01 = 0.0f;
732 short8 c11 = 0.0f;
733 short8 c21 = 0.0f;
734 short8 c31 = 0.0f;
735
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000736 // This for loop performs 1 accumulation for each iteration
Gian Marco36a0a462018-01-12 10:21:40 +0000737 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100738 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000739 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000740 char4 a0 = vload4(0, src_addr_a);
741 char16 b0 = vload16(0, src_addr_b);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100742
743 c00 = mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
744 c10 = mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
745 c20 = mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
746 c30 = mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);
747
748 c01 = mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
749 c11 = mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
750 c21 = mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
751 c31 = mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
752 }
753
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000754 // Compute destination address
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100755 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
756
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000757 // Multiply by the weight of matrix product
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100758 char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
759 char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
760 char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
761 char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));
762
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000763#if defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100764 c00_qs8 = mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
765 c10_qs8 = mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
766 c20_qs8 = mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
767 c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000768#endif // defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100769
Gian Marcoae2af742018-02-15 12:35:44 +0000770 // Compute dst address
771 __global uchar *dst_addr = offset(&dst, 0, 0);
772
773 // Add offset for batched GEMM
774 dst_addr += z * dst_stride_z;
775
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000776 // Store 16x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000777 vstore16(c00_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
778 vstore16(c10_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
779 vstore16(c20_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
780 vstore16(c30_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100781}
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100782
783/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
784 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
785 *
Gian Marco19835e52018-01-30 13:35:54 +0000786 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
787 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
788 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000789 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
790 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
791 * @note:ALPHA must be passed in 16 bit fixed point format
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100792 *
793 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS16
794 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
795 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
796 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
797 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
798 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
799 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
800 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
801 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
802 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
803 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
804 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
805 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
806 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000807 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100808 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000809 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100810 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
811 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100812__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
813 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000814 IMAGE_DECLARATION(dst),
815 uint src0_stride_z,
816 uint src1_stride_z,
817 uint dst_stride_z)
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100818{
Gian Marco36a0a462018-01-12 10:21:40 +0000819 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
820 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000821 int z = get_global_id(2);
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100822
Gian Marco36a0a462018-01-12 10:21:40 +0000823 // Offset
824 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
825 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100826
Gian Marco36a0a462018-01-12 10:21:40 +0000827 // src_addr_a = address of matrix A
828 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000829 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
830 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
831
832#if defined(MATRIX_B_DEPTH)
833 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
834 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
835#else // defined(MATRIX_B_DEPTH)
836 src1_addr_in_bytes += z * src1_stride_z;
837#endif // defined(MATRIX_B_DEPTH)
838
839 __global short *src_addr_a = (__global short *)(src0_ptr + src0_addr_in_bytes);
840 __global short *src_addr_b = (__global short *)(src1_ptr + src1_addr_in_bytes);
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100841
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000842 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000843 __global short *src_end_addr_b = src_addr_b + COLS_B;
844
845 src_addr_a += offset_row_a;
846 src_addr_b += offset_row_b;
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100847
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000848 // Reset accumulators
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100849 int8 c00 = 0.0f;
850 int8 c10 = 0.0f;
851 int8 c20 = 0.0f;
852 int8 c30 = 0.0f;
853
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000854 // This for loop performs 1 accumulation for each iteration
Gian Marco36a0a462018-01-12 10:21:40 +0000855 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100856 {
857 /* Load values from matrix A (interleaved) and matrix B (transposed) */
Gian Marco36a0a462018-01-12 10:21:40 +0000858 short4 a0 = vload4(0, src_addr_a);
859 short8 b0 = vload8(0, src_addr_b);
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100860
861 c00 = mlal_sat_qs16x8(c00, (short8)a0.s0, b0, FIXED_POINT_POSITION);
862 c10 = mlal_sat_qs16x8(c10, (short8)a0.s1, b0, FIXED_POINT_POSITION);
863 c20 = mlal_sat_qs16x8(c20, (short8)a0.s2, b0, FIXED_POINT_POSITION);
864 c30 = mlal_sat_qs16x8(c30, (short8)a0.s3, b0, FIXED_POINT_POSITION);
865 }
866
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000867 // Compute destination address
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100868 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
869
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000870 // Multiply by the weight of matrix product
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100871 short8 c00_qs16 = convert_short8_sat(c00);
872 short8 c10_qs16 = convert_short8_sat(c10);
873 short8 c20_qs16 = convert_short8_sat(c20);
874 short8 c30_qs16 = convert_short8_sat(c30);
875
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000876#if defined(ALPHA)
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100877 c00_qs16 = mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
878 c10_qs16 = mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
879 c20_qs16 = mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
880 c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000881#endif // defined(ALPHA)
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100882
Gian Marcoae2af742018-02-15 12:35:44 +0000883 // Compute dst address
884 __global uchar *dst_addr = offset(&dst, 0, 0);
885
886 // Add offset for batched GEMM
887 dst_addr += z * dst_stride_z;
888
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000889 // Store 8x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000890 vstore8(c00_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
891 vstore8(c10_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
892 vstore8(c20_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
893 vstore8(c30_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100894}
895#endif // defined(FIXED_POINT_POSITION)
Gian Marco36a0a462018-01-12 10:21:40 +0000896#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100897
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100898#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
899#if defined(DATA_TYPE)
900#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100902 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100903 * @note This OpenCL kernel works with floating point data types (F16/F32)
904 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
905 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000906 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000907 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
908 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100909 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100910 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100911 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
912 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
913 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
914 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
915 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100916 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100917 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
918 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
919 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
920 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
921 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100922 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100923 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
925 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
926 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
927 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
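 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] dst_stride_z Stride of the destination matrix in Z dimension (in bytes)
 */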
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100929__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
930 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000931 IMAGE_DECLARATION(dst),
932 uint src0_stride_z,
933 uint src1_stride_z,
934 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100935{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100936 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100937
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100938 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100939 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100940
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100941 // Update address for the matrix A
942 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100943
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100944 // Update address for the matrix B
945 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100946
Gian Marcoae2af742018-02-15 12:35:44 +0000947 // Add offset for batched GEMM
948 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000949
950#if defined(MATRIX_B_DEPTH)
951 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
952 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
953#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +0000954 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000955#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +0000956
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100957 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
958
959 VECTOR_TYPE acc0 = 0.0f;
960#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
961 VECTOR_TYPE acc1 = 0.0f;
962#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
963#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
964 VECTOR_TYPE acc2 = 0.0f;
965#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
966#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
967 VECTOR_TYPE acc3 = 0.0f;
968#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
969
Georgios Pinitas96880cf2017-10-20 18:52:20 +0100970 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100971 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100972 // Load values from matrix A
973 VEC_DATA_TYPE(DATA_TYPE, 2)
974 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
975#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
976 VEC_DATA_TYPE(DATA_TYPE, 2)
977 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
978#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
979#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
980 VEC_DATA_TYPE(DATA_TYPE, 2)
981 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
982#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
983#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
984 VEC_DATA_TYPE(DATA_TYPE, 2)
985 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
986#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
987 // Load values from matrix B
988 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
989 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100990
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100991 // Accumulate
992 acc0 += b0 * (VECTOR_TYPE)a0.s0;
993 acc0 += b1 * (VECTOR_TYPE)a0.s1;
994#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
995 acc1 += b0 * (VECTOR_TYPE)a1.s0;
996 acc1 += b1 * (VECTOR_TYPE)a1.s1;
997#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
998#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
999 acc2 += b0 * (VECTOR_TYPE)a2.s0;
1000 acc2 += b1 * (VECTOR_TYPE)a2.s1;
1001#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1002#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1003 acc3 += b0 * (VECTOR_TYPE)a3.s0;
1004 acc3 += b1 * (VECTOR_TYPE)a3.s1;
1005#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001006 }
1007
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001008 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001009 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001010 // Load values from matrix A
1011 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1012#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1013 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1014#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1015#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1016 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1017#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1018#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1019 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1020#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1021 // Load values from matrix B
1022 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001023
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001024 // Accumulate
1025 acc0 += b0 * (VECTOR_TYPE)a0;
1026#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1027 acc1 += b0 * (VECTOR_TYPE)a1;
1028#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1029#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1030 acc2 += b0 * (VECTOR_TYPE)a2;
1031#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1032#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1033 acc3 += b0 * (VECTOR_TYPE)a3;
1034#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001035 }
1036
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001037 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001038 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1039
Gian Marcoae2af742018-02-15 12:35:44 +00001040 // Compute dst address
1041 __global uchar *dst_addr = offset(&dst, 0, 0);
1042
1043 // Add offset for batched GEMM
1044 dst_addr += get_global_id(2) * dst_stride_z;
1045
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001046 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001047#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001048 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001049#endif // defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001050 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001051 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001052#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001053#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001054 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001055#endif // defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001056 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001057 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001058#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1059#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001060#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001061 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001062#endif // defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001063 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001064 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001065#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1066#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001067#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001068 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001069#endif // defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001070 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001071 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001072#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001073}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001074#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001075
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001076/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
1077 *
1078 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1079 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1080 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1081 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1082 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001083 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1084 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001085 *
1086 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
1087 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1088 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1089 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1090 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1091 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1092 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1093 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1094 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1095 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1096 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1097 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1098 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1099 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1100 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1101 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1102 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1103 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1104 */
1105__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1106 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001107 IMAGE_DECLARATION(dst),
1108 uint src0_stride_z,
1109 uint src1_stride_z,
1110 uint dst_stride_z)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001111{
1112 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1113
1114 // Compute starting address for matrix A and matrix B
1115 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1116
1117 // Update address for matrix A
1118 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1119
1120 // Update address for matrix B
1121 src_addr.s1 += idx * sizeof(float);
1122
Gian Marcoae2af742018-02-15 12:35:44 +00001123 // Add offset for batched GEMM
1124 src_addr.s0 += get_global_id(2) * src0_stride_z;
1125
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001126#if defined(MATRIX_B_DEPTH)
1127 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1128 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1129#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001130 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001131#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001132
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001133 // Address boundary for matrix A
1134 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
1135
1136 // Initialize accumulators
1137 float acc00 = 0.0f;
1138 float acc01 = 0.0f;
1139 float acc02 = 0.0f;
1140 float acc03 = 0.0f;
1141
1142#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1143 float acc10 = 0.0f;
1144 float acc11 = 0.0f;
1145 float acc12 = 0.0f;
1146 float acc13 = 0.0f;
1147#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1148
1149#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1150 float acc20 = 0.0f;
1151 float acc21 = 0.0f;
1152 float acc22 = 0.0f;
1153 float acc23 = 0.0f;
1154#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1155
1156#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1157 float acc30 = 0.0f;
1158 float acc31 = 0.0f;
1159 float acc32 = 0.0f;
1160 float acc33 = 0.0f;
1161#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1162
1163 // A and B src indices get incremented at the same time.
1164 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
1165 {
1166 // Load values from matrix A
1167 float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1168#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1169 float2 a1 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1170#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1171#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1172 float2 a2 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1173#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1174#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1175 float2 a3 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1176#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1177 // Load values from matrix B
1178 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
1179 float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
1180
1181 // Multiply and accumulate
1182 acc00 = fma(a0.s0, b0.s0, acc00);
1183 acc00 = fma(a0.s1, b1.s0, acc00);
1184 acc01 = fma(a0.s0, b0.s1, acc01);
1185 acc01 = fma(a0.s1, b1.s1, acc01);
1186 acc02 = fma(a0.s0, b0.s2, acc02);
1187 acc02 = fma(a0.s1, b1.s2, acc02);
1188 acc03 = fma(a0.s1, b1.s3, acc03);
1189 acc03 = fma(a0.s0, b0.s3, acc03);
1190
1191#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1192 acc10 = fma(a1.s0, b0.s0, acc10);
1193 acc11 = fma(a1.s0, b0.s1, acc11);
1194 acc12 = fma(a1.s0, b0.s2, acc12);
1195 acc13 = fma(a1.s0, b0.s3, acc13);
1196
1197 acc10 = fma(a1.s1, b1.s0, acc10);
1198 acc11 = fma(a1.s1, b1.s1, acc11);
1199 acc12 = fma(a1.s1, b1.s2, acc12);
1200 acc13 = fma(a1.s1, b1.s3, acc13);
1201#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1202#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1203 acc20 = fma(a2.s0, b0.s0, acc20);
1204 acc21 = fma(a2.s0, b0.s1, acc21);
1205 acc22 = fma(a2.s0, b0.s2, acc22);
1206 acc23 = fma(a2.s0, b0.s3, acc23);
1207
1208 acc20 = fma(a2.s1, b1.s0, acc20);
1209 acc21 = fma(a2.s1, b1.s1, acc21);
1210 acc22 = fma(a2.s1, b1.s2, acc22);
1211 acc23 = fma(a2.s1, b1.s3, acc23);
1212#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1213#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1214 acc30 = fma(a3.s0, b0.s0, acc30);
1215 acc31 = fma(a3.s0, b0.s1, acc31);
1216 acc32 = fma(a3.s0, b0.s2, acc32);
1217 acc33 = fma(a3.s0, b0.s3, acc33);
1218
1219 acc30 = fma(a3.s1, b1.s0, acc30);
1220 acc31 = fma(a3.s1, b1.s1, acc31);
1221 acc32 = fma(a3.s1, b1.s2, acc32);
1222 acc33 = fma(a3.s1, b1.s3, acc33);
1223#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1224 }
1225
1226 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
1227 {
1228 // Load values from matrix A
1229 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1230#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1231 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1232#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1233#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1234 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1235#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1236#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1237 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1238#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1239 // Load values from matrix B
1240 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1241
1242 // Multiply and accumulate
1243 acc00 = fma(a0, b0.s0, acc00);
1244 acc01 = fma(a0, b0.s1, acc01);
1245 acc02 = fma(a0, b0.s2, acc02);
1246 acc03 = fma(a0, b0.s3, acc03);
1247#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1248 acc10 = fma(a1, b0.s0, acc10);
1249 acc11 = fma(a1, b0.s1, acc11);
1250 acc12 = fma(a1, b0.s2, acc12);
1251 acc13 = fma(a1, b0.s3, acc13);
1252#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1253#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1254 acc20 = fma(a2, b0.s0, acc20);
1255 acc21 = fma(a2, b0.s1, acc21);
1256 acc22 = fma(a2, b0.s2, acc22);
1257 acc23 = fma(a2, b0.s3, acc23);
1258#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1259#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1260 acc30 = fma(a3, b0.s0, acc30);
1261 acc31 = fma(a3, b0.s1, acc31);
1262 acc32 = fma(a3, b0.s2, acc32);
1263 acc33 = fma(a3, b0.s3, acc33);
1264#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1265 }
1266
1267 // Compute destination address
1268 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1269
1270 // Multiply by the weight of matrix-matrix product and store the result
1271#if defined(ALPHA)
1272 acc00 = acc00 * ALPHA;
1273 acc01 = acc01 * ALPHA;
1274 acc02 = acc02 * ALPHA;
1275 acc03 = acc03 * ALPHA;
1276#endif // defined(ALPHA)
1277
Gian Marcoae2af742018-02-15 12:35:44 +00001278 // Compute dst address
1279 __global uchar *dst_addr = offset(&dst, 0, 0);
1280
1281 // Add offset for batched GEMM
1282 dst_addr += get_global_id(2) * dst_stride_z;
1283
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001284 float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
Gian Marcoae2af742018-02-15 12:35:44 +00001285 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001286
1287#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1288#if defined(ALPHA)
1289 acc10 = acc10 * ALPHA;
1290 acc11 = acc11 * ALPHA;
1291 acc12 = acc12 * ALPHA;
1292 acc13 = acc13 * ALPHA;
1293#endif // defined(ALPHA)
1294 float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
Gian Marcoae2af742018-02-15 12:35:44 +00001295 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001296#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1297#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1298#if defined(ALPHA)
1299 acc20 = acc20 * ALPHA;
1300 acc21 = acc21 * ALPHA;
1301 acc22 = acc22 * ALPHA;
1302 acc23 = acc23 * ALPHA;
1303#endif // defined(ALPHA)
1304 float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
Gian Marcoae2af742018-02-15 12:35:44 +00001305 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001306#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1307#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1308#if defined(ALPHA)
1309 acc30 = acc30 * ALPHA;
1310 acc31 = acc31 * ALPHA;
1311 acc32 = acc32 * ALPHA;
1312 acc33 = acc33 * ALPHA;
1313#endif // defined(ALPHA)
1314 float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
Gian Marcoae2af742018-02-15 12:35:44 +00001315 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001316#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1317}
1318
1319/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1320 *
1321 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1322 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
1323 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1324 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
1325 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1326 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001327 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1328 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001329 *
1330 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
1331 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1332 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1333 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1334 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1335 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1336 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1337 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1338 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1339 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1340 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1341 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1342 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1343 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1344 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1345 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1346 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1347 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1348 */
1349__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
1350 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001351 IMAGE_DECLARATION(dst),
1352 uint src0_stride_z,
1353 uint src1_stride_z,
1354 uint dst_stride_z)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001355{
1356 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1357 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1358
1359 // Compute starting address for matrix A and Matrix B
1360 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1361
1362 // Update address for the matrix A
1363 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1364
1365 // Update address for the matrix B
1366 src_addr.s1 += idx * sizeof(float);
1367
Gian Marcoae2af742018-02-15 12:35:44 +00001368 // Add offset for batched GEMM
1369 src_addr.s0 += get_global_id(2) * src0_stride_z;
1370
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001371#if defined(MATRIX_B_DEPTH)
1372 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1373 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1374#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001375 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001376#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001377
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001378 // Address boundary for the matrix A
1379 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
1380
1381 // Initialize accumulators
1382 float acc00 = 0.0f;
1383 float acc01 = 0.0f;
1384
1385#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1386 float acc10 = 0.0f;
1387 float acc11 = 0.0f;
1388#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1389#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1390 float acc20 = 0.0f;
1391 float acc21 = 0.0f;
1392#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1393#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1394 float acc30 = 0.0f;
1395 float acc31 = 0.0f;
1396#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1397
1398 // A and B src indices get incremented at the same time.
1399 for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y))
1400 {
1401 // Load values from matrix A
1402 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1403
1404 // Load values from matrix B
1405 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
1406 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
1407 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y));
1408 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y));
1409
1410 // Multiply and accumulate
1411 acc00 = fma(a0.s0, b0.s0, acc00);
1412 acc00 = fma(a0.s1, b1.s0, acc00);
1413 acc00 = fma(a0.s2, b2.s0, acc00);
1414 acc00 = fma(a0.s3, b3.s0, acc00);
1415
1416 acc01 = fma(a0.s0, b0.s1, acc01);
1417 acc01 = fma(a0.s1, b1.s1, acc01);
1418 acc01 = fma(a0.s2, b2.s1, acc01);
1419 acc01 = fma(a0.s3, b3.s1, acc01);
1420
1421#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1422 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1423 acc10 = fma(a0.s0, b0.s0, acc10);
1424 acc10 = fma(a0.s1, b1.s0, acc10);
1425 acc10 = fma(a0.s2, b2.s0, acc10);
1426 acc10 = fma(a0.s3, b3.s0, acc10);
1427
1428 acc11 = fma(a0.s0, b0.s1, acc11);
1429 acc11 = fma(a0.s1, b1.s1, acc11);
1430 acc11 = fma(a0.s2, b2.s1, acc11);
1431 acc11 = fma(a0.s3, b3.s1, acc11);
1432#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1433#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1434 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1435 acc20 = fma(a0.s0, b0.s0, acc20);
1436 acc20 = fma(a0.s1, b1.s0, acc20);
1437 acc20 = fma(a0.s2, b2.s0, acc20);
1438 acc20 = fma(a0.s3, b3.s0, acc20);
1439
1440 acc21 = fma(a0.s0, b0.s1, acc21);
1441 acc21 = fma(a0.s1, b1.s1, acc21);
1442 acc21 = fma(a0.s2, b2.s1, acc21);
1443 acc21 = fma(a0.s3, b3.s1, acc21);
1444#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1445#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1446 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1447 acc30 = fma(a0.s0, b0.s0, acc30);
1448 acc30 = fma(a0.s1, b1.s0, acc30);
1449 acc30 = fma(a0.s2, b2.s0, acc30);
1450 acc30 = fma(a0.s3, b3.s0, acc30);
1451
1452 acc31 = fma(a0.s0, b0.s1, acc31);
1453 acc31 = fma(a0.s1, b1.s1, acc31);
1454 acc31 = fma(a0.s2, b2.s1, acc31);
1455 acc31 = fma(a0.s3, b3.s1, acc31);
1456#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1457 }
1458 // float size increment
1459 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(4, src1_stride_y))
1460 {
1461 // Load values from matrix A
1462 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1463#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1464 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1465#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1466#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1467 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1468#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1469#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1470 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1471#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1472 // Load values from matrix B
1473 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1474
1475 // Multiply and accumulate
1476 acc00 = fma(a0, b0.s0, acc00);
1477 acc01 = fma(a0, b0.s1, acc01);
1478#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1479 acc10 = fma(a1, b0.s0, acc10);
1480 acc11 = fma(a1, b0.s1, acc11);
1481#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1482#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1483 acc20 = fma(a2, b0.s0, acc20);
1484 acc21 = fma(a2, b0.s1, acc21);
1485#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1486#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1487 acc30 = fma(a3, b0.s0, acc30);
1488 acc31 = fma(a3, b0.s1, acc31);
1489#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1490 }
1491
1492 // Compute destination address
1493 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1494
Gian Marcoae2af742018-02-15 12:35:44 +00001495 // Compute dst address
1496 __global uchar *dst_addr = offset(&dst, 0, 0);
1497
1498 // Add offset for batched GEMM
1499 dst_addr += get_global_id(2) * dst_stride_z;
1500
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001501 // Multiply by the weight of matrix-matrix product and store the result
1502#if defined(ALPHA)
1503 acc00 = acc00 * ALPHA;
1504 acc01 = acc01 * ALPHA;
1505#endif // defined(ALPHA)
1506 float2 acc0 = ((float2)(acc00, acc01));
Gian Marcoae2af742018-02-15 12:35:44 +00001507 vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001508#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1509#if defined(ALPHA)
1510 acc10 = acc10 * ALPHA;
1511 acc11 = acc11 * ALPHA;
1512#endif // defined(ALPHA)
1513 float2 acc1 = ((float2)(acc10, acc11));
Gian Marcoae2af742018-02-15 12:35:44 +00001514 vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001515#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1516#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1517#if defined(ALPHA)
1518 acc20 = acc20 * ALPHA;
1519 acc21 = acc21 * ALPHA;
1520#endif // defined(ALPHA)
1521 float2 acc2 = ((float2)(acc20, acc21));
Gian Marcoae2af742018-02-15 12:35:44 +00001522 vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001523#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1524#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1525#if defined(ALPHA)
1526 acc30 = acc30 * ALPHA;
1527 acc31 = acc31 * ALPHA;
1528#endif // defined(ALPHA)
1529 float2 acc3 = (float2)(acc30, acc31);
Gian Marcoae2af742018-02-15 12:35:44 +00001530 vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001531#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1532}
1533
1534#if defined(FIXED_POINT_POSITION)
/** Multiplies matrix A (src0) by matrix B (src1) for QS8 fixed point data, reading both matrices in their original (non-reshaped) layout.
 *
 * Each work-item produces NUM_ELEMS_PROCESSED_PER_THREAD_Y output rows (up to 4 are handled) of 16 consecutive QS8 values.
 *
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note NOTE(review): the kernel always loads and stores 16 values per row, so NUM_ELEMS_PROCESSED_PER_THREAD_X is presumably 16 - confirm against the host-side configuration
 * @note The number of columns of matrix A must be passed at compile time using -DCOLS_A
 * @note The fixed point position must be passed at compile time using -DFIXED_POINT_POSITION
 * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA. When defined, every output row is multiplied by it after being converted to QS8
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: QS8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Byte offset added to the matrix A address per unit of get_global_id(2) (batched GEMM)
 * @param[in]  src1_stride_z                      Byte offset added to the matrix B address per unit of get_global_id(2) (batched GEMM)
 * @param[in]  dst_stride_z                       Byte offset added to the destination address per unit of get_global_id(2) (batched GEMM)
 */
__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
                          IMAGE_DECLARATION(src1),
                          IMAGE_DECLARATION(dst),
                          uint src0_stride_z,
                          uint src1_stride_z,
                          uint dst_stride_z)
{
    // Byte offsets of the current read positions inside matrix A and matrix B
    int a_offset = src0_offset_first_element_in_bytes;
    int b_offset = src1_offset_first_element_in_bytes;

    // Matrix A: move to the first row handled by this work-item, then to the batch slice
    a_offset += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
    a_offset += get_global_id(2) * src0_stride_z;

    // Matrix B: move to the first column handled by this work-item, then to the batch slice
    int first_col = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
    b_offset += first_col * sizeof(char);
#if defined(MATRIX_B_DEPTH)
    // Matrix B has fewer slices than matrix A: wrap the slice index instead of sliding past the end of B
    b_offset += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Offset one past the last element of the first A row handled by this work-item
    int a_row_end = a_offset + (COLS_A * sizeof(char));

    // 16-bit accumulators: "lo" holds output columns 0-7, "hi" holds output columns 8-15 of each row
    short8 row0_lo = 0;
    short8 row0_hi = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    short8 row1_lo = 0;
    short8 row1_hi = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    short8 row2_lo = 0;
    short8 row2_hi = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    short8 row3_lo = 0;
    short8 row3_hi = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: every iteration consumes two consecutive elements of each A row and two rows of B
    while(a_offset <= (a_row_end - 2))
    {
        // Two consecutive values from each A row
        char2 a0 = vload2(0, (__global char *)(src0_ptr + a_offset + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        char2 a1 = vload2(0, (__global char *)(src0_ptr + a_offset + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        char2 a2 = vload2(0, (__global char *)(src0_ptr + a_offset + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        char2 a3 = vload2(0, (__global char *)(src0_ptr + a_offset + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Sixteen values from each of the two matching B rows
        char16 b0 = vload16(0, (__global char *)(src1_ptr + b_offset + 0 * src1_stride_y));
        char16 b1 = vload16(0, (__global char *)(src1_ptr + b_offset + 1 * src1_stride_y));

        // Accumulate; the per-accumulator order (b0 first, then b1) is kept because the operations saturate
        row0_lo = mlal_sat_qs8x8(row0_lo, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
        row0_lo = mlal_sat_qs8x8(row0_lo, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
        row0_hi = mlal_sat_qs8x8(row0_hi, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        row0_hi = mlal_sat_qs8x8(row0_hi, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        row1_lo = mlal_sat_qs8x8(row1_lo, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
        row1_lo = mlal_sat_qs8x8(row1_lo, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
        row1_hi = mlal_sat_qs8x8(row1_hi, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        row1_hi = mlal_sat_qs8x8(row1_hi, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        row2_lo = mlal_sat_qs8x8(row2_lo, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
        row2_lo = mlal_sat_qs8x8(row2_lo, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
        row2_hi = mlal_sat_qs8x8(row2_hi, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        row2_hi = mlal_sat_qs8x8(row2_hi, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        row3_lo = mlal_sat_qs8x8(row3_lo, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
        row3_lo = mlal_sat_qs8x8(row3_lo, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
        row3_hi = mlal_sat_qs8x8(row3_hi, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        row3_hi = mlal_sat_qs8x8(row3_hi, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance two elements along the A rows and two rows down B
        a_offset += 2;
        b_offset += 2 * src1_stride_y;
    }

    // Left-over loop: handles the last element of the A rows when COLS_A is odd
    while(a_offset < a_row_end)
    {
        char a0 = *((__global char *)(src0_ptr + a_offset + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        char a1 = *((__global char *)(src0_ptr + a_offset + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        char a2 = *((__global char *)(src0_ptr + a_offset + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        char a3 = *((__global char *)(src0_ptr + a_offset + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        char16 b0 = vload16(0, (__global char *)(src1_ptr + b_offset));

        row0_lo = mlal_sat_qs8x8(row0_lo, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
        row0_hi = mlal_sat_qs8x8(row0_hi, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        row1_lo = mlal_sat_qs8x8(row1_lo, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
        row1_hi = mlal_sat_qs8x8(row1_hi, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        row2_lo = mlal_sat_qs8x8(row2_lo, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
        row2_hi = mlal_sat_qs8x8(row2_hi, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        row3_lo = mlal_sat_qs8x8(row3_lo, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
        row3_hi = mlal_sat_qs8x8(row3_hi, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance one element along the A rows and one row down B
        a_offset += 1;
        b_offset += src1_stride_y;
    }

    // Destination address of the first output row, including the batch slice offset
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0) + get_global_id(2) * dst_stride_z;

    // Narrow each row to QS8 with saturation, optionally scale it by ALPHA, then store its 16 values
    char16 out_row = convert_char16_sat((short16)(row0_lo, row0_hi));
#if defined(ALPHA)
    out_row = mul_sat_qs8x16(out_row, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(out_row, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    out_row = convert_char16_sat((short16)(row1_lo, row1_hi));
#if defined(ALPHA)
    out_row = mul_sat_qs8x16(out_row, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(out_row, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    out_row = convert_char16_sat((short16)(row2_lo, row2_hi));
#if defined(ALPHA)
    out_row = mul_sat_qs8x16(out_row, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(out_row, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    out_row = convert_char16_sat((short16)(row3_lo, row3_hi));
#if defined(ALPHA)
    out_row = mul_sat_qs8x16(out_row, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(out_row, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001718
/** Multiplies matrix A (src0) by matrix B (src1) for QS16 fixed point data, reading both matrices in their original (non-reshaped) layout.
 *
 * Each work-item produces NUM_ELEMS_PROCESSED_PER_THREAD_Y output rows (up to 4 are handled) of 8 consecutive QS16 values.
 *
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note NOTE(review): the kernel always loads and stores 8 values per row, so NUM_ELEMS_PROCESSED_PER_THREAD_X is presumably 8 - confirm against the host-side configuration
 * @note The number of columns of matrix A must be passed at compile time using -DCOLS_A
 * @note The fixed point position must be passed at compile time using -DFIXED_POINT_POSITION
 * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA. When defined, every output row is multiplied by it after being converted to QS16
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Byte offset added to the matrix A address per unit of get_global_id(2) (batched GEMM)
 * @param[in]  src1_stride_z                      Byte offset added to the matrix B address per unit of get_global_id(2) (batched GEMM)
 * @param[in]  dst_stride_z                       Byte offset added to the destination address per unit of get_global_id(2) (batched GEMM)
 */
__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
                           IMAGE_DECLARATION(src1),
                           IMAGE_DECLARATION(dst),
                           uint src0_stride_z,
                           uint src1_stride_z,
                           uint dst_stride_z)
{
    // Byte offsets of the current read positions inside matrix A and matrix B
    int a_offset = src0_offset_first_element_in_bytes;
    int b_offset = src1_offset_first_element_in_bytes;

    // Matrix A: move to the first row handled by this work-item, then to the batch slice
    a_offset += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
    a_offset += get_global_id(2) * src0_stride_z;

    // Matrix B: move to the first column handled by this work-item, then to the batch slice
    int first_col = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
    b_offset += first_col * sizeof(short);
#if defined(MATRIX_B_DEPTH)
    // Matrix B has fewer slices than matrix A: wrap the slice index instead of sliding past the end of B
    b_offset += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Offset one past the last element of the first A row handled by this work-item
    int a_row_end = a_offset + (COLS_A * sizeof(short));

    // 32-bit accumulators, one vector of 8 output columns per row
    int8 row0_acc = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    int8 row1_acc = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    int8 row2_acc = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    int8 row3_acc = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: every iteration consumes two consecutive elements of each A row and two rows of B.
    // sizeof is cast to int so the bound is computed and compared as a signed value.
    while(a_offset <= (a_row_end - 2 * (int)sizeof(short)))
    {
        // Two consecutive values from each A row
        short2 a0 = vload2(0, (__global short *)(src0_ptr + a_offset + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        short2 a1 = vload2(0, (__global short *)(src0_ptr + a_offset + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        short2 a2 = vload2(0, (__global short *)(src0_ptr + a_offset + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short2 a3 = vload2(0, (__global short *)(src0_ptr + a_offset + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Eight values from each of the two matching B rows
        short8 b0 = vload8(0, (__global short *)(src1_ptr + b_offset + 0 * src1_stride_y));
        short8 b1 = vload8(0, (__global short *)(src1_ptr + b_offset + 1 * src1_stride_y));

        // Accumulate; the per-accumulator order (b0 first, then b1) is kept because the operations saturate
        row0_acc = mlal_sat_qs16x8(row0_acc, (short8)a0.s0, b0, FIXED_POINT_POSITION);
        row0_acc = mlal_sat_qs16x8(row0_acc, (short8)a0.s1, b1, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        row1_acc = mlal_sat_qs16x8(row1_acc, (short8)a1.s0, b0, FIXED_POINT_POSITION);
        row1_acc = mlal_sat_qs16x8(row1_acc, (short8)a1.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        row2_acc = mlal_sat_qs16x8(row2_acc, (short8)a2.s0, b0, FIXED_POINT_POSITION);
        row2_acc = mlal_sat_qs16x8(row2_acc, (short8)a2.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        row3_acc = mlal_sat_qs16x8(row3_acc, (short8)a3.s0, b0, FIXED_POINT_POSITION);
        row3_acc = mlal_sat_qs16x8(row3_acc, (short8)a3.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance two elements along the A rows and two rows down B
        a_offset += 2 * sizeof(short);
        b_offset += 2 * src1_stride_y;
    }

    // Left-over loop: handles the last element of the A rows when COLS_A is odd
    while(a_offset < a_row_end)
    {
        short a0 = *((__global short *)(src0_ptr + a_offset + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        short a1 = *((__global short *)(src0_ptr + a_offset + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        short a2 = *((__global short *)(src0_ptr + a_offset + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short a3 = *((__global short *)(src0_ptr + a_offset + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        short8 b0 = vload8(0, (__global short *)(src1_ptr + b_offset));

        row0_acc = mlal_sat_qs16x8(row0_acc, (short8)a0, b0, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        row1_acc = mlal_sat_qs16x8(row1_acc, (short8)a1, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        row2_acc = mlal_sat_qs16x8(row2_acc, (short8)a2, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        row3_acc = mlal_sat_qs16x8(row3_acc, (short8)a3, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance one element along the A rows and one row down B
        a_offset += sizeof(short);
        b_offset += src1_stride_y;
    }

    // Destination address of the first output row, including the batch slice offset
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0) + get_global_id(2) * dst_stride_z;

    // Narrow each row to QS16 with saturation, optionally scale it by ALPHA, then store its 8 values
    short8 out_row = convert_short8_sat(row0_acc);
#if defined(ALPHA)
    out_row = mul_sat_qs16x8(out_row, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(out_row, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    out_row = convert_short8_sat(row1_acc);
#if defined(ALPHA)
    out_row = mul_sat_qs16x8(out_row, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(out_row, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    out_row = convert_short8_sat(row2_acc);
#if defined(ALPHA)
    out_row = mul_sat_qs16x8(out_row, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(out_row, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    out_row = convert_short8_sat(row3_acc);
#if defined(ALPHA)
    out_row = mul_sat_qs16x8(out_row, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(out_row, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001886#endif // defined(FIXED_POINT_POSITION)
1887#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001888
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001889#if defined(BETA)
/** Adds the source matrix, weighted by the scalar beta, to the destination matrix in place: dst = dst + BETA * src
 *
 * Each work-item reads 4 floats from both matrices and writes 4 floats back to the destination.
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 *
 * @param[in]     src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]     src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]     src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]     src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in,out] dst_ptr                           Pointer to the destination matrix, which holds the A x B result on entry. Supported data types: same as @p src_ptr
 * @param[in]     dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]     dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]     dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Resolve the addresses this work-item operates on
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // The destination already holds the A x B result; the source holds matrix C
    __global float *ab_ptr = (__global float *)dst.ptr;
    __global float *c_ptr  = (__global float *)src.ptr;

    // axb + beta * c, kept as a single expression exactly like the arithmetic it replaces
    const float4 updated = vload4(0, ab_ptr) + (float4)BETA * vload4(0, c_ptr);

    // Overwrite the A x B values with the final result
    vstore4(updated, 0, ab_ptr);
}
1926
/** Adds the source matrix, weighted by the scalar beta, to the destination matrix in place: dst = dst + BETA * src
 *
 * Each work-item reads 8 half values from both matrices and writes 8 half values back to the destination.
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 *
 * @param[in]     src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]     src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]     src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]     src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in,out] dst_ptr                           Pointer to the destination matrix, which holds the A x B result on entry. Supported data types: same as @p src_ptr
 * @param[in]     dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]     dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]     dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Resolve the addresses this work-item operates on
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // The destination already holds the A x B result; the source holds matrix C
    __global half *ab_ptr = (__global half *)dst.ptr;
    __global half *c_ptr  = (__global half *)src.ptr;

    // axb + beta * c, kept as a single expression exactly like the arithmetic it replaces
    const half8 updated = vload8(0, ab_ptr) + (half8)BETA * vload8(0, c_ptr);

    // Overwrite the A x B values with the final result
    vstore8(updated, 0, ab_ptr);
}
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001963
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001964#if defined(FIXED_POINT_POSITION)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001965/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 8 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
1966 *
Gian Marco19835e52018-01-30 13:35:54 +00001967 * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001968 *
1969 * @note: BETA must be passed in 8 bit fixed point format
1970 *
1971 * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS8
1972 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
1973 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1974 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
1975 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1976 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
1977 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
1978 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1979 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1980 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1981 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1982 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1983 */
1984__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
1985 IMAGE_DECLARATION(dst))
1986{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001987 // Compute source and destination addresses
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001988 Image src = CONVERT_TO_IMAGE_STRUCT(src);
1989 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1990
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001991 // Load values from A x B
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001992 char16 alpha_ab = vload16(0, (__global char *)dst.ptr);
1993
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001994 // Load values from Matrix C
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001995 char16 c = vload16(0, (__global char *)src.ptr);
1996
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001997 // Computes alpha * axb + beta * c
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001998 char16 out = mla_sat_qs8x16(alpha_ab, (char16)BETA, c, FIXED_POINT_POSITION);
1999
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002000 // Store final result in axb matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002001 vstore16(out, 0, (__global char *)dst.ptr);
2002}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002003
2004/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 16 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
2005 *
Gian Marco19835e52018-01-30 13:35:54 +00002006 * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002007 *
2008 * @note: BETA must be passed in 16 bit fixed point format
2009 *
2010 * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS16
2011 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2012 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2013 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2014 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2015 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
2016 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
2017 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2018 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2019 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2020 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2021 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2022 */
2023__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
2024 IMAGE_DECLARATION(dst))
2025{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002026 // Compute source and destination addresses
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002027 Image src = CONVERT_TO_IMAGE_STRUCT(src);
2028 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2029
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002030 // Load values from A x B
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002031 short8 alpha_ab = vload8(0, (__global short *)dst.ptr);
2032
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002033 // Load values from Matrix C
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002034 short8 c = vload8(0, (__global short *)src.ptr);
2035
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002036 // Computes alpha * axb + beta * c
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002037 short8 out = mla_sat_qs16x8(alpha_ab, (short8)BETA, c, FIXED_POINT_POSITION);
2038
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002039 // Store final result in axb matrix
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002040 vstore8(out, 0, (__global short *)dst.ptr);
2041}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002042#endif // defined(FIXED_POINT_POSITION)
2043#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002044
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002045#if defined(WIDTH_VECTOR_A)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002046/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
2047 *
Gian Marco19835e52018-01-30 13:35:54 +00002048 * @note The width of A need to be passed at compile time using -DWIDTH_VECTOR_A
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002049 *
Gian Marco19835e52018-01-30 13:35:54 +00002050 * @note The input A and matrix B must not be reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002051 *
2052 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2053 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2054 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2055 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2056 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2057 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002058 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002059 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2060 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2061 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2062 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2063 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2064 * @param[in] src1_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
2065 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002066 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002067 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2068 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2069 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2070 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2071 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2072 */
2073__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
2074 TENSOR3D_DECLARATION(src1),
2075 IMAGE_DECLARATION(dst))
2076{
2077 int idx = get_global_id(0) * 4;
2078 int idy = get_global_id(1);
2079
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002080 // Compute the address for the vector A and matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002081 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
2082 src_addr.s1 += idx * sizeof(float);
2083
2084 int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));
2085
2086 float4 acc = 0.0f;
2087
Georgios Pinitas96880cf2017-10-20 18:52:20 +01002088 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002089 {
2090 float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
2091 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
2092 float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));
2093
2094 acc += b0 * (float4)a0.s0;
2095 acc += b1 * (float4)a0.s1;
2096 }
2097
2098 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
2099 {
2100 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
2101 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
2102
2103 acc += b0 * (float4)a0;
2104 }
2105
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002106 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002107 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2108
2109 vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
2110}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002111#endif // defined(WIDTH_VECTOR_A)
2112
/** This kernel adds the biases vector to each row of the accumulate tensor, in place.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size, i.e. the number of elements processed per work-item, must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 * @note When -DFIXED_POINT_POSITION is defined the addition is performed through ADD_SAT_OP_EXPAND, otherwise with the plain + operator.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/QS8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Per-work-item views of the accumulate tensor and of the biases vector
    Image  accum_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *accum_addr  = (__global DATA_TYPE *)accum_img.ptr;
    __global DATA_TYPE *biases_addr = (__global DATA_TYPE *)biases_vec.ptr;

    // Load VECTOR_SIZE elements from each operand
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    acc_vec = VLOAD(VECTOR_SIZE)(0, accum_addr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    bias_vec = VLOAD(VECTOR_SIZE)(0, biases_addr);

#ifdef FIXED_POINT_POSITION
    // Fixed point build: addition goes through the ADD_SAT_OP_EXPAND macro
    acc_vec = ADD_SAT_OP_EXPAND(bias_vec, acc_vec, DATA_TYPE, VECTOR_SIZE);
#else  // FIXED_POINT_POSITION
    acc_vec = bias_vec + acc_vec;
#endif // FIXED_POINT_POSITION

    // Write the sum back over the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (acc_vec, 0, accum_addr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)