blob: 7215f5811fdff6a26a41fc6c1fdac316732b4ae3 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco Iodice368da832017-07-03 12:33:49 +010026#ifdef FIXED_POINT_POSITION
27#include "fixed_point.h"
28#endif // FIXED_POINT_POSITION
29
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// Pick an unsigned integer type whose width matches the element size.
// The kernel below only loads and stores the elements, so the values are
// handled as opaque bit patterns of ELEMENT_SIZE bytes.
#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE is neither 1, 2 nor 4
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of the input matrix.
 *
 * Each work-item loads one vector of TRANSPOSE_W elements from the source and stores it, unchanged, in the destination.
 * For a work-item with global id (x, y, z) the destination address in bytes is:
 *   dst_offset_first_element_in_bytes
 *   + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y
 *   + (y * MULT_TRANSPOSE1XW_WIDTH + x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE)
 *   + z * dst_stride_z
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The element size in bytes must be passed at compile time using -DELEMENT_SIZE. Only 1, 2 and 4 are accepted (e.g. -DELEMENT_SIZE=4)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    const uint gid_x = get_global_id(0);
    const uint gid_y = get_global_id(1);
    const uint gid_z = get_global_id(2);

    // Source address of this work-item, as resolved by the tensor helper macro
    Tensor3D src_tensor = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Destination row: MULT_TRANSPOSE1XW_WIDTH consecutive x ids share one output row
    const uint row_offset = (gid_x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y;

    // Position inside the output row: the source y id selects a group of
    // MULT_TRANSPOSE1XW_WIDTH vectors, the remainder of x selects the vector in that group
    const uint group_offset  = gid_y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH;
    const uint vector_offset = (gid_x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Plane offset for batched GEMM
    const uint batch_offset = gid_z * dst_stride_z;

    const uint dst_offset_in_bytes = dst_offset_first_element_in_bytes + group_offset + row_offset + vector_offset + batch_offset;

    // Copy one vector of TRANSPOSE_W elements from source to destination
    const VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W) values = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src_tensor.ptr);

    VSTORE(TRANSPOSE_W)
    (values, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010088
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
 *
 * Each work-item reads four vectors of 4 elements from four consecutive source rows (src_stride_y apart) and writes
 * the four columns of that 4x4 block as vectors of 4 elements. Relative to the destination base address of the
 * work-item, column k is stored at element offset k * 4 * MULT_INTERLEAVE4X4_HEIGHT.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst))
{
    const uint gid_x = get_global_id(0);
    const uint gid_y = get_global_id(1);
    const uint gid_z = get_global_id(2);

    // Source address of this work-item, as resolved by the tensor helper macro
    Tensor3D src_tensor = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Destination address: every x id owns 16 * MULT_INTERLEAVE4X4_HEIGHT elements of an output row,
    // MULT_INTERLEAVE4X4_HEIGHT consecutive y ids share one output row and are placed 4 elements apart
    const uint block_offset = gid_x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT;
    const uint row_offset   = (gid_y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y;
    const uint lane_offset  = (gid_y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Plane offset for batched GEMM
    const uint batch_offset = gid_z * dst_stride_z;

    const uint dst_offset_in_bytes = dst_offset_first_element_in_bytes + block_offset + row_offset + lane_offset + batch_offset;

    __global uchar     *in_base  = src_tensor.ptr;
    __global DATA_TYPE *out_base = (__global DATA_TYPE *)(dst_ptr + dst_offset_in_bytes);

    // Load the four rows of the 4x4 block
    const VEC_DATA_TYPE(DATA_TYPE, 4) row0 = vload4(0, (__global DATA_TYPE *)(in_base + 0 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 4) row1 = vload4(0, (__global DATA_TYPE *)(in_base + 1 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 4) row2 = vload4(0, (__global DATA_TYPE *)(in_base + 2 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 4) row3 = vload4(0, (__global DATA_TYPE *)(in_base + 3 * src_stride_y));

    // Gather the four columns of the block (i.e. the rows of its transpose)
    const VEC_DATA_TYPE(DATA_TYPE, 4) col0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(row0.x, row1.x, row2.x, row3.x);
    const VEC_DATA_TYPE(DATA_TYPE, 4) col1 = (VEC_DATA_TYPE(DATA_TYPE, 4))(row0.y, row1.y, row2.y, row3.y);
    const VEC_DATA_TYPE(DATA_TYPE, 4) col2 = (VEC_DATA_TYPE(DATA_TYPE, 4))(row0.z, row1.z, row2.z, row3.z);
    const VEC_DATA_TYPE(DATA_TYPE, 4) col3 = (VEC_DATA_TYPE(DATA_TYPE, 4))(row0.w, row1.w, row2.w, row3.w);

    // Store the columns, each 4 * MULT_INTERLEAVE4X4_HEIGHT elements after the previous one
    vstore4(col0, 0, (out_base + 0 * MULT_INTERLEAVE4X4_HEIGHT));
    vstore4(col1, 0, (out_base + 4 * MULT_INTERLEAVE4X4_HEIGHT));
    vstore4(col2, 0, (out_base + 8 * MULT_INTERLEAVE4X4_HEIGHT));
    vstore4(col3, 0, (out_base + 12 * MULT_INTERLEAVE4X4_HEIGHT));
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100157
Gian Marco36a0a462018-01-12 10:21:40 +0000158#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100159/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100160 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100161 *
Gian Marco19835e52018-01-30 13:35:54 +0000162 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
163 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
164 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000165 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
166 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100167 *
168 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
169 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
170 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
171 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
172 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
173 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100174 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100175 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
176 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
177 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
178 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
179 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100180 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100181 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000182 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100183 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000184 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100185 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
186 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100187__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
188 IMAGE_DECLARATION(src1),
189 IMAGE_DECLARATION(dst),
190 uint src0_stride_z,
191 uint src1_stride_z,
192 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100193{
Gian Marco36a0a462018-01-12 10:21:40 +0000194 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
195 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000196 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100197
Gian Marco36a0a462018-01-12 10:21:40 +0000198 // Offset
199 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
200 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100201
Gian Marco36a0a462018-01-12 10:21:40 +0000202 // src_addr_a = address of matrix A
203 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000204 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
205 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
206
207#if defined(MATRIX_B_DEPTH)
208 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
209 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
210#else // defined(MATRIX_B_DEPTH)
211 src1_addr_in_bytes += z * src1_stride_z;
212#endif // defined(MATRIX_B_DEPTH)
213
214 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
215 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100216
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000217 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000218 __global float *src_end_addr_b = src_addr_b + COLS_B;
219
220 src_addr_a += offset_row_a;
221 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100222
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000223 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100224 float4 c00 = 0.0f;
225 float4 c10 = 0.0f;
226 float4 c20 = 0.0f;
227 float4 c30 = 0.0f;
228
Gian Marco36a0a462018-01-12 10:21:40 +0000229 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100230 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000231 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000232 float4 a0 = vload4(0, src_addr_a);
233 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100234
235 c00 += (float4)a0.s0 * b0;
236 c10 += (float4)a0.s1 * b0;
237 c20 += (float4)a0.s2 * b0;
238 c30 += (float4)a0.s3 * b0;
239
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000240 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000241 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
242 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100243
244 c00 += (float4)a0.s0 * b0;
245 c10 += (float4)a0.s1 * b0;
246 c20 += (float4)a0.s2 * b0;
247 c30 += (float4)a0.s3 * b0;
248 }
249
Gian Marco36a0a462018-01-12 10:21:40 +0000250 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100251 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000252 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000253 float4 a0 = vload4(0, src_addr_a);
254 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100255
256 c00 += (float4)a0.s0 * b0;
257 c10 += (float4)a0.s1 * b0;
258 c20 += (float4)a0.s2 * b0;
259 c30 += (float4)a0.s3 * b0;
260 }
261
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000262 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100263 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
264
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000265#if defined(ALPHA)
266 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100267 c00 = c00 * (float4)ALPHA;
268 c10 = c10 * (float4)ALPHA;
269 c20 = c20 * (float4)ALPHA;
270 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000271#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100272
Gian Marcoae2af742018-02-15 12:35:44 +0000273 // Compute dst address
274 __global uchar *dst_addr = offset(&dst, 0, 0);
275
276 // Add offset for batched GEMM
277 dst_addr += z * dst_stride_z;
278
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000279 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000280 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
281 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
282 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
283 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100284}
285
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000286/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100287 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100288 *
Gian Marco19835e52018-01-30 13:35:54 +0000289 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
290 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
291 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000292 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
293 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
294 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100295 *
296 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
297 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
298 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
299 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
300 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
301 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100302 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100303 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
304 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
305 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
306 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
307 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100308 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100309 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000310 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100311 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000312 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100313 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
314 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100315__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
316 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000317 IMAGE_DECLARATION(dst),
318 uint src0_stride_z,
319 uint src1_stride_z,
320 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100321{
Gian Marco36a0a462018-01-12 10:21:40 +0000322 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
323 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000324 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +0000325
326 // Offset
327 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
328 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
329
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100330 // src_addr_a = address of matrix A
331 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000332 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
333 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
334
335#if defined(MATRIX_B_DEPTH)
336 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
337 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
338#else // defined(MATRIX_B_DEPTH)
339 src1_addr_in_bytes += z * src1_stride_z;
340#endif // defined(MATRIX_B_DEPTH)
341
342 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
343 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100344
Gian Marco36a0a462018-01-12 10:21:40 +0000345 src_addr_a += offset_row_a;
346 src_addr_b += offset_row_b;
347
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100348 // Reset accumulators
349 float c00 = 0.0f;
350 float c01 = 0.0f;
351 float c02 = 0.0f;
352 float c03 = 0.0f;
353 float c10 = 0.0f;
354 float c11 = 0.0f;
355 float c12 = 0.0f;
356 float c13 = 0.0f;
357 float c20 = 0.0f;
358 float c21 = 0.0f;
359 float c22 = 0.0f;
360 float c23 = 0.0f;
361 float c30 = 0.0f;
362 float c31 = 0.0f;
363 float c32 = 0.0f;
364 float c33 = 0.0f;
365
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100366#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
367
368 int i = 0;
369 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100370 {
371 // Load values from matrix A (interleaved) and matrix B (transposed)
372 float4 a0 = vload4(0, src_addr_a);
373 float4 b0 = vload4(0, src_addr_b);
374
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100375 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
376 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100377
378 c00 = fma(a0.s0, b0.s0, c00);
379 c01 = fma(a0.s0, b0.s1, c01);
380 c02 = fma(a0.s0, b0.s2, c02);
381 c03 = fma(a0.s0, b0.s3, c03);
382
383 c10 = fma(a0.s1, b0.s0, c10);
384 c11 = fma(a0.s1, b0.s1, c11);
385 c12 = fma(a0.s1, b0.s2, c12);
386 c13 = fma(a0.s1, b0.s3, c13);
387
388 c20 = fma(a0.s2, b0.s0, c20);
389 c21 = fma(a0.s2, b0.s1, c21);
390 c22 = fma(a0.s2, b0.s2, c22);
391 c23 = fma(a0.s2, b0.s3, c23);
392
393 c30 = fma(a0.s3, b0.s0, c30);
394 c31 = fma(a0.s3, b0.s1, c31);
395 c32 = fma(a0.s3, b0.s2, c32);
396 c33 = fma(a0.s3, b0.s3, c33);
397
398 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100399 a0 = vload4(0, src_addr_a);
400 b0 = vload4(0, src_addr_b);
401
402 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
403 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100404
405 c00 = fma(a0.s0, b0.s0, c00);
406 c01 = fma(a0.s0, b0.s1, c01);
407 c02 = fma(a0.s0, b0.s2, c02);
408 c03 = fma(a0.s0, b0.s3, c03);
409
410 c10 = fma(a0.s1, b0.s0, c10);
411 c11 = fma(a0.s1, b0.s1, c11);
412 c12 = fma(a0.s1, b0.s2, c12);
413 c13 = fma(a0.s1, b0.s3, c13);
414
415 c20 = fma(a0.s2, b0.s0, c20);
416 c21 = fma(a0.s2, b0.s1, c21);
417 c22 = fma(a0.s2, b0.s2, c22);
418 c23 = fma(a0.s2, b0.s3, c23);
419
420 c30 = fma(a0.s3, b0.s0, c30);
421 c31 = fma(a0.s3, b0.s1, c31);
422 c32 = fma(a0.s3, b0.s2, c32);
423 c33 = fma(a0.s3, b0.s3, c33);
424
425 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100426 a0 = vload4(0, src_addr_a);
427 b0 = vload4(0, src_addr_b);
428
429 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
430 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
431
432 c00 = fma(a0.s0, b0.s0, c00);
433 c01 = fma(a0.s0, b0.s1, c01);
434 c02 = fma(a0.s0, b0.s2, c02);
435 c03 = fma(a0.s0, b0.s3, c03);
436
437 c10 = fma(a0.s1, b0.s0, c10);
438 c11 = fma(a0.s1, b0.s1, c11);
439 c12 = fma(a0.s1, b0.s2, c12);
440 c13 = fma(a0.s1, b0.s3, c13);
441
442 c20 = fma(a0.s2, b0.s0, c20);
443 c21 = fma(a0.s2, b0.s1, c21);
444 c22 = fma(a0.s2, b0.s2, c22);
445 c23 = fma(a0.s2, b0.s3, c23);
446
447 c30 = fma(a0.s3, b0.s0, c30);
448 c31 = fma(a0.s3, b0.s1, c31);
449 c32 = fma(a0.s3, b0.s2, c32);
450 c33 = fma(a0.s3, b0.s3, c33);
451
452 // Load values from matrix A (interleaved) and matrix B (transposed)
453 a0 = vload4(0, src_addr_a);
454 b0 = vload4(0, src_addr_b);
455
456 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
457 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100458
459 c00 = fma(a0.s0, b0.s0, c00);
460 c01 = fma(a0.s0, b0.s1, c01);
461 c02 = fma(a0.s0, b0.s2, c02);
462 c03 = fma(a0.s0, b0.s3, c03);
463
464 c10 = fma(a0.s1, b0.s0, c10);
465 c11 = fma(a0.s1, b0.s1, c11);
466 c12 = fma(a0.s1, b0.s2, c12);
467 c13 = fma(a0.s1, b0.s3, c13);
468
469 c20 = fma(a0.s2, b0.s0, c20);
470 c21 = fma(a0.s2, b0.s1, c21);
471 c22 = fma(a0.s2, b0.s2, c22);
472 c23 = fma(a0.s2, b0.s3, c23);
473
474 c30 = fma(a0.s3, b0.s0, c30);
475 c31 = fma(a0.s3, b0.s1, c31);
476 c32 = fma(a0.s3, b0.s2, c32);
477 c33 = fma(a0.s3, b0.s3, c33);
478 }
479
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100480 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100481 {
482 // Load values from matrix A (interleaved) and matrix B (transposed)
483 float4 a0 = vload4(0, src_addr_a);
484 float4 b0 = vload4(0, src_addr_b);
485
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100486 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
487 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
488
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100489 c00 = fma(a0.s0, b0.s0, c00);
490 c01 = fma(a0.s0, b0.s1, c01);
491 c02 = fma(a0.s0, b0.s2, c02);
492 c03 = fma(a0.s0, b0.s3, c03);
493
494 c10 = fma(a0.s1, b0.s0, c10);
495 c11 = fma(a0.s1, b0.s1, c11);
496 c12 = fma(a0.s1, b0.s2, c12);
497 c13 = fma(a0.s1, b0.s3, c13);
498
499 c20 = fma(a0.s2, b0.s0, c20);
500 c21 = fma(a0.s2, b0.s1, c21);
501 c22 = fma(a0.s2, b0.s2, c22);
502 c23 = fma(a0.s2, b0.s3, c23);
503
504 c30 = fma(a0.s3, b0.s0, c30);
505 c31 = fma(a0.s3, b0.s1, c31);
506 c32 = fma(a0.s3, b0.s2, c32);
507 c33 = fma(a0.s3, b0.s3, c33);
508 }
509
510 // Compute destination address
511 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
512
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000513#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100514 // Multiply by the weight of matrix product
515 c00 = c00 * ALPHA;
516 c01 = c01 * ALPHA;
517 c02 = c02 * ALPHA;
518 c03 = c03 * ALPHA;
519 c10 = c10 * ALPHA;
520 c11 = c11 * ALPHA;
521 c12 = c12 * ALPHA;
522 c13 = c13 * ALPHA;
523 c20 = c20 * ALPHA;
524 c21 = c21 * ALPHA;
525 c22 = c22 * ALPHA;
526 c23 = c23 * ALPHA;
527 c30 = c30 * ALPHA;
528 c31 = c31 * ALPHA;
529 c32 = c32 * ALPHA;
530 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000531#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100532
Gian Marcoae2af742018-02-15 12:35:44 +0000533 // Compute dst address
534 __global uchar *dst_addr = offset(&dst, 0, 0);
535
536 // Add offset for batched GEMM
537 dst_addr += z * dst_stride_z;
538
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100539 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000540 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
541 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
542 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
543 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100544}
545
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100546#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100547/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100548 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100549 *
Gian Marco19835e52018-01-30 13:35:54 +0000550 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
551 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
552 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000553 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
554 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100555 *
556 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
557 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
558 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
559 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
560 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
561 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100562 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100563 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
564 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
565 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
566 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
567 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100568 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100569 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000570 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100571 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000572 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100573 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
574 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100575__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
576 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000577 IMAGE_DECLARATION(dst),
578 uint src0_stride_z,
579 uint src1_stride_z,
580 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100581{
Gian Marco36a0a462018-01-12 10:21:40 +0000582 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
583 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000584 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100585
Gian Marco36a0a462018-01-12 10:21:40 +0000586 // Offset
587 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
588 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100589
Gian Marco36a0a462018-01-12 10:21:40 +0000590 // src_addr_a = address of matrix A
591 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000592 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
593 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
594
595#if defined(MATRIX_B_DEPTH)
596 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
597 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
598#else // defined(MATRIX_B_DEPTH)
599 src1_addr_in_bytes += z * src1_stride_z;
600#endif // defined(MATRIX_B_DEPTH)
601
602 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
603 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100604
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000605 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000606 __global half *src_end_addr_b = src_addr_b + COLS_B;
607
608 src_addr_a += offset_row_a;
609 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100610
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000611 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100612 half8 c00 = 0.0f;
613 half8 c10 = 0.0f;
614 half8 c20 = 0.0f;
615 half8 c30 = 0.0f;
616
Gian Marco36a0a462018-01-12 10:21:40 +0000617 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100618 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000619 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000620 half4 a0 = vload4(0, src_addr_a);
621 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100622
623 c00 += (half8)a0.s0 * b0;
624 c10 += (half8)a0.s1 * b0;
625 c20 += (half8)a0.s2 * b0;
626 c30 += (half8)a0.s3 * b0;
627
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000628 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000629 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
630 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100631
632 c00 += (half8)a0.s0 * b0;
633 c10 += (half8)a0.s1 * b0;
634 c20 += (half8)a0.s2 * b0;
635 c30 += (half8)a0.s3 * b0;
636 }
637
Gian Marco36a0a462018-01-12 10:21:40 +0000638 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100639 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000640 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000641 half4 a0 = vload4(0, src_addr_a);
642 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100643
644 c00 += (half8)a0.s0 * b0;
645 c10 += (half8)a0.s1 * b0;
646 c20 += (half8)a0.s2 * b0;
647 c30 += (half8)a0.s3 * b0;
648 }
649
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000650 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100651 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
652
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000653#if defined(ALPHA)
654 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100655 c00 = c00 * (half8)ALPHA;
656 c10 = c10 * (half8)ALPHA;
657 c20 = c20 * (half8)ALPHA;
658 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000659#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100660
Gian Marcoae2af742018-02-15 12:35:44 +0000661 // Compute dst address
662 __global uchar *dst_addr = offset(&dst, 0, 0);
663
664 // Add offset for batched GEMM
665 dst_addr += z * dst_stride_z;
666
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000667 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +0000668 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
669 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
670 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
671 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100672}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100673
674/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
675 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
676 *
677 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
678 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
679 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
680 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
681 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
682 *
683 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
684 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
685 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
686 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
687 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
688 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
689 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
690 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
691 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
692 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
693 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
694 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
695 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
696 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
697 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
698 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
699 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
700 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
701 */
702__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
703 IMAGE_DECLARATION(src1),
704 IMAGE_DECLARATION(dst),
705 uint src0_stride_z,
706 uint src1_stride_z,
707 uint dst_stride_z)
708{
709 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
710 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
711 int z = get_global_id(2);
712
713 // Offset
714 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
715 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
716
717 // src_addr_a = address of matrix A
718 // src_addr_b = address of matrix B
719 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
720 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
721
722#if defined(MATRIX_B_DEPTH)
723 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
724 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
725#else // defined(MATRIX_B_DEPTH)
726 src1_addr_in_bytes += z * src1_stride_z;
727#endif // defined(MATRIX_B_DEPTH)
728
729 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
730 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
731
732 // Compute end row address for matrix B
733 __global half *src_end_addr_b = src_addr_b + COLS_B;
734
735 src_addr_a += offset_row_a;
736 src_addr_b += offset_row_b;
737
738 // Reset accumulators
739 half8 c00 = 0.0f;
740 half8 c10 = 0.0f;
741 half8 c20 = 0.0f;
742 half8 c30 = 0.0f;
743
744#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
745
746 int i = 0;
747 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
748 {
749#if MULT_INTERLEAVE4X4_HEIGHT == 1
750 // Load values from matrix A (interleaved) and matrix B (transposed)
751 half8 a0 = vload8(0, src_addr_a);
752 half8 b0 = vload8(0, src_addr_b);
753
754 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
755 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
756
757 c00 = fma((half8)a0.s0, b0, c00);
758 c10 = fma((half8)a0.s1, b0, c10);
759 c20 = fma((half8)a0.s2, b0, c20);
760 c30 = fma((half8)a0.s3, b0, c30);
761
762 // Load values from matrix B (transposed)
763 b0 = vload8(0, src_addr_b);
764
765 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
766
767 c00 = fma((half8)a0.s4, b0, c00);
768 c10 = fma((half8)a0.s5, b0, c10);
769 c20 = fma((half8)a0.s6, b0, c20);
770 c30 = fma((half8)a0.s7, b0, c30);
771
772 // Load values from matrix A (interleaved) and matrix B (transposed)
773 a0 = vload8(0, src_addr_a);
774 b0 = vload8(0, src_addr_b);
775
776 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
777 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
778
779 c00 = fma((half8)a0.s0, b0, c00);
780 c10 = fma((half8)a0.s1, b0, c10);
781 c20 = fma((half8)a0.s2, b0, c20);
782 c30 = fma((half8)a0.s3, b0, c30);
783
784 // Load values from matrix B (transposed)
785 b0 = vload8(0, src_addr_b);
786
787 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
788
789 c00 = fma((half8)a0.s4, b0, c00);
790 c10 = fma((half8)a0.s5, b0, c10);
791 c20 = fma((half8)a0.s6, b0, c20);
792 c30 = fma((half8)a0.s7, b0, c30);
793#else // MULT_INTERLEAVE4X4_HEIGHT == 1
794 // Load values from matrix A (interleaved) and matrix B (transposed)
795 half4 a0 = vload4(0, src_addr_a);
796 half8 b0 = vload8(0, src_addr_b);
797
798 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
799 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
800
801 c00 = fma((half8)a0.s0, b0, c00);
802 c10 = fma((half8)a0.s1, b0, c10);
803 c20 = fma((half8)a0.s2, b0, c20);
804 c30 = fma((half8)a0.s3, b0, c30);
805
806 // Load values from matrix A (interleaved) and matrix B (transposed)
807 a0 = vload4(0, src_addr_a);
808 b0 = vload8(0, src_addr_b);
809
810 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
811 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
812
813 c00 = fma((half8)a0.s0, b0, c00);
814 c10 = fma((half8)a0.s1, b0, c10);
815 c20 = fma((half8)a0.s2, b0, c20);
816 c30 = fma((half8)a0.s3, b0, c30);
817
818 // Load values from matrix A (interleaved) and matrix B (transposed)
819 a0 = vload4(0, src_addr_a);
820 b0 = vload8(0, src_addr_b);
821
822 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
823 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
824
825 c00 = fma((half8)a0.s0, b0, c00);
826 c10 = fma((half8)a0.s1, b0, c10);
827 c20 = fma((half8)a0.s2, b0, c20);
828 c30 = fma((half8)a0.s3, b0, c30);
829
830 // Load values from matrix A (interleaved) and matrix B (transposed)
831 a0 = vload4(0, src_addr_a);
832 b0 = vload8(0, src_addr_b);
833
834 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
835 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
836
837 c00 = fma((half8)a0.s0, b0, c00);
838 c10 = fma((half8)a0.s1, b0, c10);
839 c20 = fma((half8)a0.s2, b0, c20);
840 c30 = fma((half8)a0.s3, b0, c30);
841#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
842 }
843
844 for(; i < (int)(COLS_MTX_B); ++i)
845 {
846 // Load values from matrix A (interleaved) and matrix B (transposed)
847 half4 a0 = vload4(0, src_addr_a);
848 half8 b0 = vload8(0, src_addr_b);
849
850 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
851 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
852
853 c00 = fma((half8)a0.s0, b0, c00);
854 c10 = fma((half8)a0.s1, b0, c10);
855 c20 = fma((half8)a0.s2, b0, c20);
856 c30 = fma((half8)a0.s3, b0, c30);
857 }
858
859 // Compute destination address
860 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
861
862#if defined(ALPHA)
863 // Multiply by the weight of matrix product
864 c00 = c00 * (half8)ALPHA;
865 c10 = c10 * (half8)ALPHA;
866 c20 = c20 * (half8)ALPHA;
867 c30 = c30 * (half8)ALPHA;
868#endif // defined(ALPHA)
869
870 // Compute dst address
871 __global uchar *dst_addr = offset(&dst, 0, 0);
872
873 // Add offset for batched GEMM
874 dst_addr += z * dst_stride_z;
875
876 // Store 4x8 block
877 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
878 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
879 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
880 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
881}
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100882#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100883
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000884#if defined(FIXED_POINT_POSITION)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100885/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
886 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_8bit and @ref gemm_transpose1x16 before running the matrix multiplication
887 *
Gian Marco19835e52018-01-30 13:35:54 +0000888 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
889 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
890 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000891 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
892 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
893 * @note:ALPHA must be passed in 8 bit fixed point format
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100894 *
895 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8
896 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
897 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
898 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
899 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
900 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
901 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
902 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
903 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
904 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
905 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
906 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
907 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
908 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000909 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100910 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000911 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100912 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
913 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100914__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
915 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000916 IMAGE_DECLARATION(dst),
917 uint src0_stride_z,
918 uint src1_stride_z,
919 uint dst_stride_z)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100920{
Gian Marco36a0a462018-01-12 10:21:40 +0000921 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
922 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000923 int z = get_global_id(2);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100924
Gian Marco36a0a462018-01-12 10:21:40 +0000925 // Offset
926 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
927 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100928
Gian Marco36a0a462018-01-12 10:21:40 +0000929 // src_addr_a = address of matrix A
930 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000931 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
932 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
933
934#if defined(MATRIX_B_DEPTH)
935 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
936 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
937#else // defined(MATRIX_B_DEPTH)
938 src1_addr_in_bytes += z * src1_stride_z;
939#endif // defined(MATRIX_B_DEPTH)
940
941 __global char *src_addr_a = (__global char *)(src0_ptr + src0_addr_in_bytes);
942 __global char *src_addr_b = (__global char *)(src1_ptr + src1_addr_in_bytes);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100943
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000944 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000945 __global char *src_end_addr_b = src_addr_b + COLS_B;
946
947 src_addr_a += offset_row_a;
948 src_addr_b += offset_row_b;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100949
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000950 // Reset accumulators
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100951 short8 c00 = 0.0f;
952 short8 c10 = 0.0f;
953 short8 c20 = 0.0f;
954 short8 c30 = 0.0f;
955 short8 c01 = 0.0f;
956 short8 c11 = 0.0f;
957 short8 c21 = 0.0f;
958 short8 c31 = 0.0f;
959
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000960 // This for loop performs 1 accumulation for each iteration
Gian Marco36a0a462018-01-12 10:21:40 +0000961 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100962 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000963 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000964 char4 a0 = vload4(0, src_addr_a);
965 char16 b0 = vload16(0, src_addr_b);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100966
967 c00 = mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
968 c10 = mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
969 c20 = mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
970 c30 = mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);
971
972 c01 = mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
973 c11 = mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
974 c21 = mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
975 c31 = mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
976 }
977
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000978 // Compute destination address
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100979 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
980
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000981 // Multiply by the weight of matrix product
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100982 char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
983 char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
984 char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
985 char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));
986
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000987#if defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100988 c00_qs8 = mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
989 c10_qs8 = mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
990 c20_qs8 = mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
991 c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000992#endif // defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100993
Gian Marcoae2af742018-02-15 12:35:44 +0000994 // Compute dst address
995 __global uchar *dst_addr = offset(&dst, 0, 0);
996
997 // Add offset for batched GEMM
998 dst_addr += z * dst_stride_z;
999
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001000 // Store 16x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00001001 vstore16(c00_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
1002 vstore16(c10_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
1003 vstore16(c20_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
1004 vstore16(c30_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001005}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001006
1007/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
1008 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
1009 *
Gian Marco19835e52018-01-30 13:35:54 +00001010 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
1011 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
1012 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001013 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1014 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1015 * @note:ALPHA must be passed in 16 bit fixed point format
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001016 *
1017 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS16
1018 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1019 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1020 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1021 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1022 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1023 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1024 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1025 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1026 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1027 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1028 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1029 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1030 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00001031 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001032 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00001033 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001034 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1035 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001036__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
1037 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001038 IMAGE_DECLARATION(dst),
1039 uint src0_stride_z,
1040 uint src1_stride_z,
1041 uint dst_stride_z)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001042{
Gian Marco36a0a462018-01-12 10:21:40 +00001043 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
1044 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00001045 int z = get_global_id(2);
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001046
Gian Marco36a0a462018-01-12 10:21:40 +00001047 // Offset
1048 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
1049 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001050
Gian Marco36a0a462018-01-12 10:21:40 +00001051 // src_addr_a = address of matrix A
1052 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001053 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
1054 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
1055
1056#if defined(MATRIX_B_DEPTH)
1057 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1058 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
1059#else // defined(MATRIX_B_DEPTH)
1060 src1_addr_in_bytes += z * src1_stride_z;
1061#endif // defined(MATRIX_B_DEPTH)
1062
1063 __global short *src_addr_a = (__global short *)(src0_ptr + src0_addr_in_bytes);
1064 __global short *src_addr_b = (__global short *)(src1_ptr + src1_addr_in_bytes);
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001065
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001066 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00001067 __global short *src_end_addr_b = src_addr_b + COLS_B;
1068
1069 src_addr_a += offset_row_a;
1070 src_addr_b += offset_row_b;
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001071
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001072 // Reset accumulators
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001073 int8 c00 = 0.0f;
1074 int8 c10 = 0.0f;
1075 int8 c20 = 0.0f;
1076 int8 c30 = 0.0f;
1077
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001078 // This for loop performs 1 accumulation for each iteration
Gian Marco36a0a462018-01-12 10:21:40 +00001079 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001080 {
1081 /* Load values from matrix A (interleaved) and matrix B (transposed) */
Gian Marco36a0a462018-01-12 10:21:40 +00001082 short4 a0 = vload4(0, src_addr_a);
1083 short8 b0 = vload8(0, src_addr_b);
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001084
1085 c00 = mlal_sat_qs16x8(c00, (short8)a0.s0, b0, FIXED_POINT_POSITION);
1086 c10 = mlal_sat_qs16x8(c10, (short8)a0.s1, b0, FIXED_POINT_POSITION);
1087 c20 = mlal_sat_qs16x8(c20, (short8)a0.s2, b0, FIXED_POINT_POSITION);
1088 c30 = mlal_sat_qs16x8(c30, (short8)a0.s3, b0, FIXED_POINT_POSITION);
1089 }
1090
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001091 // Compute destination address
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001092 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1093
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001094 // Multiply by the weight of matrix product
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001095 short8 c00_qs16 = convert_short8_sat(c00);
1096 short8 c10_qs16 = convert_short8_sat(c10);
1097 short8 c20_qs16 = convert_short8_sat(c20);
1098 short8 c30_qs16 = convert_short8_sat(c30);
1099
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001100#if defined(ALPHA)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001101 c00_qs16 = mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1102 c10_qs16 = mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1103 c20_qs16 = mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1104 c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001105#endif // defined(ALPHA)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001106
Gian Marcoae2af742018-02-15 12:35:44 +00001107 // Compute dst address
1108 __global uchar *dst_addr = offset(&dst, 0, 0);
1109
1110 // Add offset for batched GEMM
1111 dst_addr += z * dst_stride_z;
1112
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001113 // Store 8x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00001114 vstore8(c00_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
1115 vstore8(c10_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
1116 vstore8(c20_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
1117 vstore8(c30_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001118}
1119#endif // defined(FIXED_POINT_POSITION)
Gian Marco36a0a462018-01-12 10:21:40 +00001120#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001121
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001122#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
1123#if defined(DATA_TYPE)
1124#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
1125/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001126 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001127 * @note This OpenCL kernel works with floating point data types (F16/F32)
1128 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
1129 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001130 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001131 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1132 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001133 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001134 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001135 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1136 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1137 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1138 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1139 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001140 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001141 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1142 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1143 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1144 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1145 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001146 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001147 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1148 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1149 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1150 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1151 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1152 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001153__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
1154 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001155 IMAGE_DECLARATION(dst),
1156 uint src0_stride_z,
1157 uint src1_stride_z,
1158 uint dst_stride_z)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001159{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001160 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001161
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001162 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001163 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001164
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001165 // Update address for the matrix A
1166 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001167
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001168 // Update address for the matrix B
1169 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001170
Gian Marcoae2af742018-02-15 12:35:44 +00001171 // Add offset for batched GEMM
1172 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001173
1174#if defined(MATRIX_B_DEPTH)
1175 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1176 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1177#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001178 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001179#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001180
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001181 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
1182
1183 VECTOR_TYPE acc0 = 0.0f;
1184#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1185 VECTOR_TYPE acc1 = 0.0f;
1186#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1187#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1188 VECTOR_TYPE acc2 = 0.0f;
1189#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1190#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1191 VECTOR_TYPE acc3 = 0.0f;
1192#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1193
Georgios Pinitas96880cf2017-10-20 18:52:20 +01001194 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001195 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001196 // Load values from matrix A
1197 VEC_DATA_TYPE(DATA_TYPE, 2)
1198 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1199#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1200 VEC_DATA_TYPE(DATA_TYPE, 2)
1201 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1202#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1203#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1204 VEC_DATA_TYPE(DATA_TYPE, 2)
1205 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1206#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1207#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1208 VEC_DATA_TYPE(DATA_TYPE, 2)
1209 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1210#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1211 // Load values from matrix B
1212 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
1213 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001214
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001215 // Accumulate
1216 acc0 += b0 * (VECTOR_TYPE)a0.s0;
1217 acc0 += b1 * (VECTOR_TYPE)a0.s1;
1218#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1219 acc1 += b0 * (VECTOR_TYPE)a1.s0;
1220 acc1 += b1 * (VECTOR_TYPE)a1.s1;
1221#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1222#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1223 acc2 += b0 * (VECTOR_TYPE)a2.s0;
1224 acc2 += b1 * (VECTOR_TYPE)a2.s1;
1225#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1226#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1227 acc3 += b0 * (VECTOR_TYPE)a3.s0;
1228 acc3 += b1 * (VECTOR_TYPE)a3.s1;
1229#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001230 }
1231
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001232 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001233 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001234 // Load values from matrix A
1235 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1236#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1237 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1238#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1239#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1240 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1241#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1242#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1243 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1244#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1245 // Load values from matrix B
1246 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001247
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001248 // Accumulate
1249 acc0 += b0 * (VECTOR_TYPE)a0;
1250#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1251 acc1 += b0 * (VECTOR_TYPE)a1;
1252#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1253#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1254 acc2 += b0 * (VECTOR_TYPE)a2;
1255#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1256#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1257 acc3 += b0 * (VECTOR_TYPE)a3;
1258#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001259 }
1260
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001261 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001262 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1263
Gian Marcoae2af742018-02-15 12:35:44 +00001264 // Compute dst address
1265 __global uchar *dst_addr = offset(&dst, 0, 0);
1266
1267 // Add offset for batched GEMM
1268 dst_addr += get_global_id(2) * dst_stride_z;
1269
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001270 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001271#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001272 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001273#endif // defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001274 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001275 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001276#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001277#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001278 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001279#endif // defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001280 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001281 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001282#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1283#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001284#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001285 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001286#endif // defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001287 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001288 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001289#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1290#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001291#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001292 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001293#endif // defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001294 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001295 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001296#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001297}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001298#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001299
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1301 *
1302 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1303 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1304 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1305 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1306 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001307 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1308 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001309 *
 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
1311 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1312 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1313 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1314 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1315 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1316 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1317 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1318 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1319 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1320 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1321 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1322 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1323 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1327 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1328 */
1329__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1330 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001331 IMAGE_DECLARATION(dst),
1332 uint src0_stride_z,
1333 uint src1_stride_z,
1334 uint dst_stride_z)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001335{
1336 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1337
1338 // Compute starting address for matrix A and matrix B
1339 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1340
1341 // Update address for matrix A
1342 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1343
1344 // Update address for matrix B
1345 src_addr.s1 += idx * sizeof(float);
1346
Gian Marcoae2af742018-02-15 12:35:44 +00001347 // Add offset for batched GEMM
1348 src_addr.s0 += get_global_id(2) * src0_stride_z;
1349
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001350#if defined(MATRIX_B_DEPTH)
1351 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1352 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1353#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001354 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001355#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001356
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001357 // Initialize accumulators
1358 float acc00 = 0.0f;
1359 float acc01 = 0.0f;
1360 float acc02 = 0.0f;
1361 float acc03 = 0.0f;
1362
1363#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1364 float acc10 = 0.0f;
1365 float acc11 = 0.0f;
1366 float acc12 = 0.0f;
1367 float acc13 = 0.0f;
1368#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1369
1370#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1371 float acc20 = 0.0f;
1372 float acc21 = 0.0f;
1373 float acc22 = 0.0f;
1374 float acc23 = 0.0f;
1375#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1376
1377#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1378 float acc30 = 0.0f;
1379 float acc31 = 0.0f;
1380 float acc32 = 0.0f;
1381 float acc33 = 0.0f;
1382#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1383
1384 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001385 int i = 0;
1386 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001387 {
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001388 // Load values from matrix A and matrix B
1389 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001390#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001391 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001392#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1393#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001394 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001395#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1396#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001397 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001398#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001399 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1400 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001401
1402 // Multiply and accumulate
1403 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001404 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001405 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001406 acc03 = fma(a0.s0, b0.s3, acc03);
1407
1408#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001409
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001410 acc10 = fma(a1.s0, b0.s0, acc10);
1411 acc11 = fma(a1.s0, b0.s1, acc11);
1412 acc12 = fma(a1.s0, b0.s2, acc12);
1413 acc13 = fma(a1.s0, b0.s3, acc13);
1414
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001415#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1416#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001417
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001418 acc20 = fma(a2.s0, b0.s0, acc20);
1419 acc21 = fma(a2.s0, b0.s1, acc21);
1420 acc22 = fma(a2.s0, b0.s2, acc22);
1421 acc23 = fma(a2.s0, b0.s3, acc23);
1422
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001423#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1424#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001425
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001426 acc30 = fma(a3.s0, b0.s0, acc30);
1427 acc31 = fma(a3.s0, b0.s1, acc31);
1428 acc32 = fma(a3.s0, b0.s2, acc32);
1429 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001430#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001431
1432 // Load values from matrix A and matrix B
1433 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1434 src_addr.s1 += src1_stride_y;
1435
1436 // Multiply and accumulate
1437 acc00 = fma(a0.s1, b0.s0, acc00);
1438 acc01 = fma(a0.s1, b0.s1, acc01);
1439 acc02 = fma(a0.s1, b0.s2, acc02);
1440 acc03 = fma(a0.s1, b0.s3, acc03);
1441
1442#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1443
1444 acc10 = fma(a1.s1, b0.s0, acc10);
1445 acc11 = fma(a1.s1, b0.s1, acc11);
1446 acc12 = fma(a1.s1, b0.s2, acc12);
1447 acc13 = fma(a1.s1, b0.s3, acc13);
1448
1449#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1450#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1451
1452 acc20 = fma(a2.s1, b0.s0, acc20);
1453 acc21 = fma(a2.s1, b0.s1, acc21);
1454 acc22 = fma(a2.s1, b0.s2, acc22);
1455 acc23 = fma(a2.s1, b0.s3, acc23);
1456
1457#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1458#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1459
1460 acc30 = fma(a3.s1, b0.s0, acc30);
1461 acc31 = fma(a3.s1, b0.s1, acc31);
1462 acc32 = fma(a3.s1, b0.s2, acc32);
1463 acc33 = fma(a3.s1, b0.s3, acc33);
1464#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1465
1466 // Load values from matrix A and matrix B
1467 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1468 src_addr.s1 += src1_stride_y;
1469
1470 // Multiply and accumulate
1471 acc00 = fma(a0.s2, b0.s0, acc00);
1472 acc01 = fma(a0.s2, b0.s1, acc01);
1473 acc02 = fma(a0.s2, b0.s2, acc02);
1474 acc03 = fma(a0.s2, b0.s3, acc03);
1475
1476#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1477
1478 acc10 = fma(a1.s2, b0.s0, acc10);
1479 acc11 = fma(a1.s2, b0.s1, acc11);
1480 acc12 = fma(a1.s2, b0.s2, acc12);
1481 acc13 = fma(a1.s2, b0.s3, acc13);
1482
1483#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1484#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1485
1486 acc20 = fma(a2.s2, b0.s0, acc20);
1487 acc21 = fma(a2.s2, b0.s1, acc21);
1488 acc22 = fma(a2.s2, b0.s2, acc22);
1489 acc23 = fma(a2.s2, b0.s3, acc23);
1490
1491#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1492#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1493
1494 acc30 = fma(a3.s2, b0.s0, acc30);
1495 acc31 = fma(a3.s2, b0.s1, acc31);
1496 acc32 = fma(a3.s2, b0.s2, acc32);
1497 acc33 = fma(a3.s2, b0.s3, acc33);
1498#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1499
1500 // Load values from matrix A and matrix B
1501 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1502 src_addr.s1 += src1_stride_y;
1503
1504 // Multiply and accumulate
1505 acc00 = fma(a0.s3, b0.s0, acc00);
1506 acc01 = fma(a0.s3, b0.s1, acc01);
1507 acc02 = fma(a0.s3, b0.s2, acc02);
1508 acc03 = fma(a0.s3, b0.s3, acc03);
1509
1510#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1511
1512 acc10 = fma(a1.s3, b0.s0, acc10);
1513 acc11 = fma(a1.s3, b0.s1, acc11);
1514 acc12 = fma(a1.s3, b0.s2, acc12);
1515 acc13 = fma(a1.s3, b0.s3, acc13);
1516
1517#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1518#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1519
1520 acc20 = fma(a2.s3, b0.s0, acc20);
1521 acc21 = fma(a2.s3, b0.s1, acc21);
1522 acc22 = fma(a2.s3, b0.s2, acc22);
1523 acc23 = fma(a2.s3, b0.s3, acc23);
1524
1525#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1526#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1527
1528 acc30 = fma(a3.s3, b0.s0, acc30);
1529 acc31 = fma(a3.s3, b0.s1, acc31);
1530 acc32 = fma(a3.s3, b0.s2, acc32);
1531 acc33 = fma(a3.s3, b0.s3, acc33);
1532#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1533
1534 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001535 }
1536
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001537 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001538 {
1539 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001540 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001541#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1542 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1543#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1544#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1545 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1546#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1547#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1548 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1549#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1550 // Load values from matrix B
1551 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001552 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001553
1554 // Multiply and accumulate
1555 acc00 = fma(a0, b0.s0, acc00);
1556 acc01 = fma(a0, b0.s1, acc01);
1557 acc02 = fma(a0, b0.s2, acc02);
1558 acc03 = fma(a0, b0.s3, acc03);
1559#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1560 acc10 = fma(a1, b0.s0, acc10);
1561 acc11 = fma(a1, b0.s1, acc11);
1562 acc12 = fma(a1, b0.s2, acc12);
1563 acc13 = fma(a1, b0.s3, acc13);
1564#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1565#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1566 acc20 = fma(a2, b0.s0, acc20);
1567 acc21 = fma(a2, b0.s1, acc21);
1568 acc22 = fma(a2, b0.s2, acc22);
1569 acc23 = fma(a2, b0.s3, acc23);
1570#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1571#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1572 acc30 = fma(a3, b0.s0, acc30);
1573 acc31 = fma(a3, b0.s1, acc31);
1574 acc32 = fma(a3, b0.s2, acc32);
1575 acc33 = fma(a3, b0.s3, acc33);
1576#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001577
1578 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001579 }
1580
1581 // Compute destination address
1582 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1583
1584 // Multiply by the weight of matrix-matrix product and store the result
1585#if defined(ALPHA)
1586 acc00 = acc00 * ALPHA;
1587 acc01 = acc01 * ALPHA;
1588 acc02 = acc02 * ALPHA;
1589 acc03 = acc03 * ALPHA;
1590#endif // defined(ALPHA)
1591
Gian Marcoae2af742018-02-15 12:35:44 +00001592 // Compute dst address
1593 __global uchar *dst_addr = offset(&dst, 0, 0);
1594
1595 // Add offset for batched GEMM
1596 dst_addr += get_global_id(2) * dst_stride_z;
1597
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001598 float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
Gian Marcoae2af742018-02-15 12:35:44 +00001599 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001600
1601#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1602#if defined(ALPHA)
1603 acc10 = acc10 * ALPHA;
1604 acc11 = acc11 * ALPHA;
1605 acc12 = acc12 * ALPHA;
1606 acc13 = acc13 * ALPHA;
1607#endif // defined(ALPHA)
1608 float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
Gian Marcoae2af742018-02-15 12:35:44 +00001609 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001610#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1611#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1612#if defined(ALPHA)
1613 acc20 = acc20 * ALPHA;
1614 acc21 = acc21 * ALPHA;
1615 acc22 = acc22 * ALPHA;
1616 acc23 = acc23 * ALPHA;
1617#endif // defined(ALPHA)
1618 float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
Gian Marcoae2af742018-02-15 12:35:44 +00001619 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001620#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1621#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1622#if defined(ALPHA)
1623 acc30 = acc30 * ALPHA;
1624 acc31 = acc31 * ALPHA;
1625 acc32 = acc32 * ALPHA;
1626 acc33 = acc33 * ALPHA;
1627#endif // defined(ALPHA)
1628 float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
Gian Marcoae2af742018-02-15 12:35:44 +00001629 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001630#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1631}
1632
1633/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1634 *
1635 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1636 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
1637 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1638 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
1639 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1640 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001641 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1642 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001643 *
1644 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
1645 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1646 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1647 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1648 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1649 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1650 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1651 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1652 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1653 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1654 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1655 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1656 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1657 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1658 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1659 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1660 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
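1661 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in] src0_stride_z Stride of the first source matrix (matrix A) in Z dimension (in bytes), used for batched GEMM
 * @param[in] src1_stride_z Stride of the second source matrix (matrix B) in Z dimension (in bytes), used for batched GEMM
 * @param[in] dst_stride_z Stride of the destination matrix in Z dimension (in bytes), used for batched GEMM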
1662 */
1663__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
1664 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001665 IMAGE_DECLARATION(dst),
1666 uint src0_stride_z,
1667 uint src1_stride_z,
1668 uint dst_stride_z)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001669{
1670 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X: each work-item produces NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x 2 columns of C (float2 stores).
    // In the main loop, A is read as one float8 per row (vload8) and B as eight consecutive float2 rows (vload2).
    // NOTE(review): the previous comment carried a "to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1" remark; rows 1..3 are
    // handled by the #if sections below - confirm that remark is obsolete.
    // idx is the first output column (in elements) handled by this work-item.
1671 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1672
1673 // Compute starting address for matrix A and Matrix B
    // src_addr holds byte offsets: .s0 is applied to src0_ptr (matrix A), .s1 to src1_ptr (matrix B)
1674 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1675
1676 // Update address for the matrix A
    // Move to the first of the NUM_ELEMS_PROCESSED_PER_THREAD_Y rows of A owned by this work-item
1677 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1678
1679 // Update address for the matrix B
    // Move to column idx of B (float elements)
1680 src_addr.s1 += idx * sizeof(float);
1681
Gian Marcoae2af742018-02-15 12:35:44 +00001682 // Add offset for batched GEMM
1683 src_addr.s0 += get_global_id(2) * src0_stride_z;
1684
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001685#if defined(MATRIX_B_DEPTH)
1686 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    // The batch index wraps modulo MATRIX_B_DEPTH, so the same slices of B are revisited for higher batch indices
1687 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1688#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001689 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001690#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001691
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001692 // Initialize accumulators
    // accRC accumulates output row R (0..3) and column C (0..1) of this work-item's tile
1693 float acc00 = 0.0f;
1694 float acc01 = 0.0f;
1695
1696#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1697 float acc10 = 0.0f;
1698 float acc11 = 0.0f;
1699#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1700#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1701 float acc20 = 0.0f;
1702 float acc21 = 0.0f;
1703#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1704#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1705 float acc30 = 0.0f;
1706 float acc31 = 0.0f;
1707#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1708
1709 // A and B src indices get incremented at the same time.
    // Main loop: each iteration consumes 8 columns of A and the matching 8 rows of B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001710 int i = 0;
1711 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001712 {
1713 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001714 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001715
1716 // Load values from matrix B
    // b0..b7 are eight consecutive rows of B (2 columns each); src_addr.s1 advances one row after every load
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001717 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1718 src_addr.s1 += src1_stride_y;
1719 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1720 src_addr.s1 += src1_stride_y;
1721 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1722 src_addr.s1 += src1_stride_y;
1723 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1724 src_addr.s1 += src1_stride_y;
1725 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1726 src_addr.s1 += src1_stride_y;
1727 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1728 src_addr.s1 += src1_stride_y;
1729 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1730 src_addr.s1 += src1_stride_y;
1731 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1732 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001733
1734 // Multiply and accumulate
    // Row 0: each accumulator is updated with the 8 products in increasing order (a0.s0*b0 ... a0.s7*b7)
1735 acc00 = fma(a0.s0, b0.s0, acc00);
1736 acc00 = fma(a0.s1, b1.s0, acc00);
1737 acc00 = fma(a0.s2, b2.s0, acc00);
1738 acc00 = fma(a0.s3, b3.s0, acc00);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001739 acc00 = fma(a0.s4, b4.s0, acc00);
1740 acc00 = fma(a0.s5, b5.s0, acc00);
1741 acc00 = fma(a0.s6, b6.s0, acc00);
1742 acc00 = fma(a0.s7, b7.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001743
1744 acc01 = fma(a0.s0, b0.s1, acc01);
1745 acc01 = fma(a0.s1, b1.s1, acc01);
1746 acc01 = fma(a0.s2, b2.s1, acc01);
1747 acc01 = fma(a0.s3, b3.s1, acc01);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001748 acc01 = fma(a0.s4, b4.s1, acc01);
1749 acc01 = fma(a0.s5, b5.s1, acc01);
1750 acc01 = fma(a0.s6, b6.s1, acc01);
1751 acc01 = fma(a0.s7, b7.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001752
    // Rows 1..3: a0 is overwritten with the next row of A while b0..b7 are reused unchanged
1753#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001754 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001755 acc10 = fma(a0.s0, b0.s0, acc10);
1756 acc10 = fma(a0.s1, b1.s0, acc10);
1757 acc10 = fma(a0.s2, b2.s0, acc10);
1758 acc10 = fma(a0.s3, b3.s0, acc10);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001759 acc10 = fma(a0.s4, b4.s0, acc10);
1760 acc10 = fma(a0.s5, b5.s0, acc10);
1761 acc10 = fma(a0.s6, b6.s0, acc10);
1762 acc10 = fma(a0.s7, b7.s0, acc10);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001763
1764 acc11 = fma(a0.s0, b0.s1, acc11);
1765 acc11 = fma(a0.s1, b1.s1, acc11);
1766 acc11 = fma(a0.s2, b2.s1, acc11);
1767 acc11 = fma(a0.s3, b3.s1, acc11);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001768 acc11 = fma(a0.s4, b4.s1, acc11);
1769 acc11 = fma(a0.s5, b5.s1, acc11);
1770 acc11 = fma(a0.s6, b6.s1, acc11);
1771 acc11 = fma(a0.s7, b7.s1, acc11);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001772#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1773#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001774 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001775 acc20 = fma(a0.s0, b0.s0, acc20);
1776 acc20 = fma(a0.s1, b1.s0, acc20);
1777 acc20 = fma(a0.s2, b2.s0, acc20);
1778 acc20 = fma(a0.s3, b3.s0, acc20);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001779 acc20 = fma(a0.s4, b4.s0, acc20);
1780 acc20 = fma(a0.s5, b5.s0, acc20);
1781 acc20 = fma(a0.s6, b6.s0, acc20);
1782 acc20 = fma(a0.s7, b7.s0, acc20);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001783
1784 acc21 = fma(a0.s0, b0.s1, acc21);
1785 acc21 = fma(a0.s1, b1.s1, acc21);
1786 acc21 = fma(a0.s2, b2.s1, acc21);
1787 acc21 = fma(a0.s3, b3.s1, acc21);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001788 acc21 = fma(a0.s4, b4.s1, acc21);
1789 acc21 = fma(a0.s5, b5.s1, acc21);
1790 acc21 = fma(a0.s6, b6.s1, acc21);
1791 acc21 = fma(a0.s7, b7.s1, acc21);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001792#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1793#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001794 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001795 acc30 = fma(a0.s0, b0.s0, acc30);
1796 acc30 = fma(a0.s1, b1.s0, acc30);
1797 acc30 = fma(a0.s2, b2.s0, acc30);
1798 acc30 = fma(a0.s3, b3.s0, acc30);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001799 acc30 = fma(a0.s4, b4.s0, acc30);
1800 acc30 = fma(a0.s5, b5.s0, acc30);
1801 acc30 = fma(a0.s6, b6.s0, acc30);
1802 acc30 = fma(a0.s7, b7.s0, acc30);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001803
1804 acc31 = fma(a0.s0, b0.s1, acc31);
1805 acc31 = fma(a0.s1, b1.s1, acc31);
1806 acc31 = fma(a0.s2, b2.s1, acc31);
1807 acc31 = fma(a0.s3, b3.s1, acc31);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001808 acc31 = fma(a0.s4, b4.s1, acc31);
1809 acc31 = fma(a0.s5, b5.s1, acc31);
1810 acc31 = fma(a0.s6, b6.s1, acc31);
1811 acc31 = fma(a0.s7, b7.s1, acc31);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001812#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001813
    // Advance A by the 8 columns just consumed (B was already advanced row by row above)
1814 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001815 }
1816 // Leftover loop: process the columns of A not covered by the 8-wide loop, one float (and one row of B) at a time
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001817 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001818 {
1819 // Load values from matrix A
1820 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1821#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1822 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1823#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1824#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1825 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1826#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1827#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1828 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1829#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1830 // Load values from matrix B
1831 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001832 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001833
1834 // Multiply and accumulate
1835 acc00 = fma(a0, b0.s0, acc00);
1836 acc01 = fma(a0, b0.s1, acc01);
1837#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1838 acc10 = fma(a1, b0.s0, acc10);
1839 acc11 = fma(a1, b0.s1, acc11);
1840#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1841#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1842 acc20 = fma(a2, b0.s0, acc20);
1843 acc21 = fma(a2, b0.s1, acc21);
1844#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1845#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1846 acc30 = fma(a3, b0.s0, acc30);
1847 acc31 = fma(a3, b0.s1, acc31);
1848#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001849
1850 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001851 }
1852
1853 // Compute destination address
1854 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1855
Gian Marcoae2af742018-02-15 12:35:44 +00001856 // Compute dst address
1857 __global uchar *dst_addr = offset(&dst, 0, 0);
1858
1859 // Add offset for batched GEMM
1860 dst_addr += get_global_id(2) * dst_stride_z;
1861
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001862 // Multiply by the weight of matrix-matrix product and store the result
    // The ALPHA scaling is compiled in only when -DALPHA is given; each row is then written as one float2 at dst_addr + row * dst_stride_y
1863#if defined(ALPHA)
1864 acc00 = acc00 * ALPHA;
1865 acc01 = acc01 * ALPHA;
1866#endif // defined(ALPHA)
1867 float2 acc0 = ((float2)(acc00, acc01));
Gian Marcoae2af742018-02-15 12:35:44 +00001868 vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001869#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1870#if defined(ALPHA)
1871 acc10 = acc10 * ALPHA;
1872 acc11 = acc11 * ALPHA;
1873#endif // defined(ALPHA)
1874 float2 acc1 = ((float2)(acc10, acc11));
Gian Marcoae2af742018-02-15 12:35:44 +00001875 vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001876#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1877#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1878#if defined(ALPHA)
1879 acc20 = acc20 * ALPHA;
1880 acc21 = acc21 * ALPHA;
1881#endif // defined(ALPHA)
1882 float2 acc2 = ((float2)(acc20, acc21));
Gian Marcoae2af742018-02-15 12:35:44 +00001883 vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001884#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1885#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1886#if defined(ALPHA)
1887 acc30 = acc30 * ALPHA;
1888 acc31 = acc31 * ALPHA;
1889#endif // defined(ALPHA)
1890 float2 acc3 = (float2)(acc30, acc31);
Gian Marcoae2af742018-02-15 12:35:44 +00001891 vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001892#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1893}
1894
Gian Marco Iodicefd683112018-04-17 09:52:44 +01001895/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1896 *
1897 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
1898 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1899 * Each work-item loads and stores 8 half values per row (vload8/vstore8), which corresponds to -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8.
1900 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1901 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
1902 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1903 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1904 *
1905 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
1906 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1907 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1908 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1909 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1910 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1911 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1912 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1913 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1914 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1915 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1916 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1917 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1918 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1919 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1920 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1921 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
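1922 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in] src0_stride_z Stride of the first source matrix (matrix A) in Z dimension (in bytes), used for batched GEMM
 * @param[in] src1_stride_z Stride of the second source matrix (matrix B) in Z dimension (in bytes), used for batched GEMM
 * @param[in] dst_stride_z Stride of the destination matrix in Z dimension (in bytes), used for batched GEMM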
1923 */
1924__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
1925 IMAGE_DECLARATION(src1),
1926 IMAGE_DECLARATION(dst),
1927 uint src0_stride_z,
1928 uint src1_stride_z,
1929 uint dst_stride_z)
1930{
    // idx is the first output column (in elements) handled by this work-item
1931 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1932
1933 // Compute starting address for matrix A and Matrix B
    // src_addr holds byte offsets: .s0 is applied to src0_ptr (matrix A), .s1 to src1_ptr (matrix B)
1934 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1935
1936 // Update address for the matrix A
    // Move to the first of the NUM_ELEMS_PROCESSED_PER_THREAD_Y rows of A owned by this work-item
1937 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1938
1939 // Update address for the matrix B
    // Move to column idx of B (half elements)
1940 src_addr.s1 += idx * sizeof(half);
1941
1942 // Add offset for batched GEMM
1943 src_addr.s0 += get_global_id(2) * src0_stride_z;
1944
1945#if defined(MATRIX_B_DEPTH)
1946 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    // The batch index wraps modulo MATRIX_B_DEPTH, so the same slices of B are revisited for higher batch indices
1947 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1948#else // defined(MATRIX_B_DEPTH)
1949 src_addr.s1 += get_global_id(2) * src1_stride_z;
1950#endif // defined(MATRIX_B_DEPTH)
1951
    // One half8 accumulator per output row: the tile is NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x 8 columns
1952 half8 acc0 = 0.0h;
1953#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1954 half8 acc1 = 0.0h;
1955#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1956#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1957 half8 acc2 = 0.0h;
1958#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1959#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1960 half8 acc3 = 0.0h;
1961#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1962
    // Main loop: each iteration consumes 4 columns of A and the matching 4 rows of B
1963 int i = 0;
1964 for(; i <= ((int)COLS_A - 4); i += 4)
1965 {
1966 // Load values from matrix A
1967 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1968#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1969 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1970#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1971#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1972 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1973#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1974#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1975 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1976#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1977 // Load values from matrix B
    // b0 is reloaded with the next row of B before each of the four accumulation steps below
1978 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
1979 src_addr.s1 += src1_stride_y;
1980
1981 // Accumulate
    // Step 0: the scalar aR.s0 is broadcast to a half8 and multiplied with the current row of B
1982 acc0 = fma(b0, (half8)a0.s0, acc0);
1983#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1984 acc1 = fma(b0, (half8)a1.s0, acc1);
1985#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1986#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1987 acc2 = fma(b0, (half8)a2.s0, acc2);
1988#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1989#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1990 acc3 = fma(b0, (half8)a3.s0, acc3);
1991#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1992
    // Step 1: next row of B with element .s1 of each A row
1993 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
1994 src_addr.s1 += src1_stride_y;
1995 acc0 = fma(b0, (half8)a0.s1, acc0);
1996#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1997 acc1 = fma(b0, (half8)a1.s1, acc1);
1998#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1999#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2000 acc2 = fma(b0, (half8)a2.s1, acc2);
2001#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2002#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2003 acc3 = fma(b0, (half8)a3.s1, acc3);
2004#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2005
    // Step 2: next row of B with element .s2 of each A row
2006 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2007 src_addr.s1 += src1_stride_y;
2008 acc0 = fma(b0, (half8)a0.s2, acc0);
2009#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2010 acc1 = fma(b0, (half8)a1.s2, acc1);
2011#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2012#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2013 acc2 = fma(b0, (half8)a2.s2, acc2);
2014#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2015#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2016 acc3 = fma(b0, (half8)a3.s2, acc3);
2017#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2018
    // Step 3: next row of B with element .s3 of each A row
2019 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2020 src_addr.s1 += src1_stride_y;
2021 acc0 = fma(b0, (half8)a0.s3, acc0);
2022#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2023 acc1 = fma(b0, (half8)a1.s3, acc1);
2024#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2025#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2026 acc2 = fma(b0, (half8)a2.s3, acc2);
2027#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2028#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2029 acc3 = fma(b0, (half8)a3.s3, acc3);
2030#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2031
    // Advance A by the 4 columns just consumed (B was already advanced row by row above)
2032 src_addr.s0 += 4 * sizeof(half);
2033 }
2034
    // Leftover loop: process the columns of A not covered by the 4-wide loop, one half (and one row of B) at a time
2035 for(; i < (int)COLS_A; ++i)
2036 {
2037 // Load values from matrix A
2038 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2039#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2040 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2041#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2042#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2043 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2044#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2045#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2046 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2047#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2048 // Load values from matrix B
2049 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2050
    // Advance both offsets with one vector add: A by one element, B by one row
2051 src_addr += (int2)(sizeof(half), src1_stride_y);
2052
2053 // Accumulate
2054 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
2055#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2056 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
2057#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2058#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2059 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
2060#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2061#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2062 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
2063#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2064 }
2065
2066 // Compute destination address
2067 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2068
2069 // Compute dst address
2070 __global uchar *dst_addr = offset(&dst, 0, 0);
2071
2072 // Add offset for batched GEMM
2073 dst_addr += get_global_id(2) * dst_stride_z;
2074
2075 // Multiply by the weight of matrix-matrix product and store the result
    // The ALPHA scaling is compiled in only when -DALPHA is given; each row is then written as one half8 at dst_addr + row * dst_stride_y
2076#if defined(ALPHA)
2077 acc0 = acc0 * (half8)ALPHA;
2078#endif // defined(ALPHA)
2079 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
2080#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2081#if defined(ALPHA)
2082 acc1 = acc1 * (half8)ALPHA;
2083#endif // defined(ALPHA)
2084 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
2085#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2086#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2087#if defined(ALPHA)
2088 acc2 = acc2 * (half8)ALPHA;
2089#endif // defined(ALPHA)
2090 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
2091#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2092#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2093#if defined(ALPHA)
2094 acc3 = acc3 * (half8)ALPHA;
2095#endif // defined(ALPHA)
2096 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
2097#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2098}
2099
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002100#if defined(FIXED_POINT_POSITION)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002101/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002102 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002103 * @note This OpenCL kernel works with fixed point data types QS8
2104 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002105 * @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002106 * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002107 * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002108 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2109 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002110 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002111 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002112 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2113 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2114 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2115 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2116 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2117 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2118 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2119 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2120 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2121 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2122 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2123 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2124 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2125 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2126 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2127 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2128 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2129 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002130__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002131 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002132 IMAGE_DECLARATION(dst),
2133 uint src0_stride_z,
2134 uint src1_stride_z,
2135 uint dst_stride_z)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002136{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002137 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002138
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002139 // Compute starting address for matrix A and Matrix B
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002140 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002141
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002142 // Update address for the matrix A
2143 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002144
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002145 // Update address for the matrix B
2146 src_addr.s1 += idx * sizeof(char);
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002147
Gian Marcoae2af742018-02-15 12:35:44 +00002148 // Add offset for batched GEMM
2149 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002150
2151#if defined(MATRIX_B_DEPTH)
2152 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2153 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2154#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002155 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002156#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002157
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002158 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));
2159
2160 short8 acc00 = 0;
2161 short8 acc01 = 0;
2162#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2163 short8 acc10 = 0;
2164 short8 acc11 = 0;
2165#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2166#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2167 short8 acc20 = 0;
2168 short8 acc21 = 0;
2169#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2170#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2171 short8 acc30 = 0;
2172 short8 acc31 = 0;
2173#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2174
2175 // This for loop performs 4 accumulations per iteration
2176 for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002177 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002178 char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2179#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2180 char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2181#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2182#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2183 char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2184#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2185#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2186 char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2187#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002188 char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2189 char16 b1 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002190
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002191 acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
2192 acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
2193 acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2194 acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2195#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2196 acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
2197 acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
2198 acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2199 acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2200#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2201#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2202 acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
2203 acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
2204 acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2205 acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2206#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2207#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2208 acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
2209 acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
2210 acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2211 acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2212#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002213 }
2214
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002215 // Left-over accumulations
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002216 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
2217 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002218 char a0 = *((__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2219#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2220 char a1 = *((__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2221#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2222#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2223 char a2 = *((__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2224#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2225#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2226 char a3 = *((__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2227#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002228 char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1));
2229
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002230 acc00 = mlal_sat_qs8x8(acc00, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
2231 acc01 = mlal_sat_qs8x8(acc01, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2232#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2233 acc10 = mlal_sat_qs8x8(acc10, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
2234 acc11 = mlal_sat_qs8x8(acc11, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
2235#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2236#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2237 acc20 = mlal_sat_qs8x8(acc20, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
2238 acc21 = mlal_sat_qs8x8(acc21, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
2239#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2240#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2241 acc30 = mlal_sat_qs8x8(acc30, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
2242 acc31 = mlal_sat_qs8x8(acc31, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
2243#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002244 }
2245
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002246 // Compute destination address
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002247 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2248
Gian Marcoae2af742018-02-15 12:35:44 +00002249 // Compute dst address
2250 __global uchar *dst_addr = offset(&dst, 0, 0);
2251
2252 // Add offset for batched GEMM
2253 dst_addr += get_global_id(2) * dst_stride_z;
2254
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002255 // Multiply by the weight of matrix product and store the result
2256 char16 acc_qs8;
2257 acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002258#if defined(ALPHA)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002259 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002260#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002261 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002262#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2263 acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002264#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002265 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002266#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002267 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002268#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2269#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2270 acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002271#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002272 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002273#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002274 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002275#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2276#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2277 acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002278#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002279 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002280#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002281 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002282#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002283}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002284
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002285/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002286 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002287 * @note This OpenCL kernel works with fixed point data types QS16
2288 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002289 * @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002290 * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002291 * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002292 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2293 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002294 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002295 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8/QS16
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002296 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2297 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2298 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2299 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2300 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2301 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2302 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2303 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2304 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2305 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2306 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2307 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2308 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2309 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2310 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2311 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2312 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2313 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002314__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002315 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002316 IMAGE_DECLARATION(dst),
2317 uint src0_stride_z,
2318 uint src1_stride_z,
2319 uint dst_stride_z)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002320{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002321 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002322
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002323 // Compute starting address for matrix A and Matrix B
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002324 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002325
2326 // Update address for the matrix A
2327 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2328
2329 // Update address for the matrix B
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002330 src_addr.s1 += idx * sizeof(short);
2331
Gian Marcoae2af742018-02-15 12:35:44 +00002332 // Add offset for batched GEMM
2333 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002334
2335#if defined(MATRIX_B_DEPTH)
2336 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2337 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2338#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002339 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002340#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002341
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002342 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002343
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002344 int8 acc0 = 0;
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002345#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2346 int8 acc1 = 0;
2347#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2348#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2349 int8 acc2 = 0;
2350#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2351#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2352 int8 acc3 = 0;
2353#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002354
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002355 // This for loop performs 4 accumulations per iteration
Georgios Pinitas96880cf2017-10-20 18:52:20 +01002356 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(short)); src_addr += (int2)(2 * sizeof(short), 2 * src1_stride_y))
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002357 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002358 short2 a0 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2359#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2360 short2 a1 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2361#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2362#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2363 short2 a2 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2364#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2365#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2366 short2 a3 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2367#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002368 short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2369 short8 b1 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002370
2371 acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
2372 acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002373#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2374 acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
2375 acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
2376#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2377#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2378 acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
2379 acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
2380#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2381#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2382 acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
2383 acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
2384#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002385 }
2386
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002387 // Left-over accumulations
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002388 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(short), src1_stride_y))
2389 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002390 short a0 = *((__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2391#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2392 short a1 = *((__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2393#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2394#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2395 short a2 = *((__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2396#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2397#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2398 short a3 = *((__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2399#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002400 short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1));
2401
2402 acc0 = mlal_sat_qs16x8(acc0, (short8)a0, b0, FIXED_POINT_POSITION);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002403#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2404 acc1 = mlal_sat_qs16x8(acc1, (short8)a1, b0, FIXED_POINT_POSITION);
2405#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2406#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2407 acc2 = mlal_sat_qs16x8(acc2, (short8)a2, b0, FIXED_POINT_POSITION);
2408#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2409#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2410 acc3 = mlal_sat_qs16x8(acc3, (short8)a3, b0, FIXED_POINT_POSITION);
2411#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002412 }
2413
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002414 // Compute destination address
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002415 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2416
Gian Marcoae2af742018-02-15 12:35:44 +00002417 // Compute dst address
2418 __global uchar *dst_addr = offset(&dst, 0, 0);
2419
Gian Marco Iodice81b28c42018-03-29 10:29:36 +01002420 // Add offset for batched GEMM
2421 dst_addr += get_global_id(2) * dst_stride_z;
2422
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002423 // Multiply by the weight of matrix product and store the result
2424 short8 acc_qs16;
2425 acc_qs16 = convert_short8_sat(acc0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002426#if defined(ALPHA)
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002427 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002428#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002429 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002430#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2431 acc_qs16 = convert_short8_sat(acc1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002432#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002433 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002434#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002435 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002436#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2437#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2438 acc_qs16 = convert_short8_sat(acc2);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002439#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002440 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002441#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002442 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002443#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2444#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2445 acc_qs16 = convert_short8_sat(acc3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002446#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002447 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002448#endif // defined(ALPHA)
Gian Marcoae2af742018-02-15 12:35:44 +00002449 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002450#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002451}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002452#endif // defined(FIXED_POINT_POSITION)
2453#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002454
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002455#if defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002456/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
2457 *
Gian Marco19835e52018-01-30 13:35:54 +00002458 * @note The beta's value need to be passed at compile time using -DBETA
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002459 *
2460 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
2461 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2462 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2463 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2464 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2465 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002466 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002467 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2468 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2469 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2470 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2471 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2472 */
2473__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
2474 IMAGE_DECLARATION(dst))
2475{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002476 // Compute source and destination addresses
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002477 Image src = CONVERT_TO_IMAGE_STRUCT(src);
2478 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2479
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002480 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002481 float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
2482
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002483 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002484 float4 c = vload4(0, (__global float *)src.ptr);
2485
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002486 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002487 float4 out = alpha_ab + (float4)BETA * c;
2488
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002489 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002490 vstore4(out, 0, (__global float *)dst.ptr);
2491}
2492
2493/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
2494 *
Gian Marco19835e52018-01-30 13:35:54 +00002495 * @note The beta's value need to be passed at compile time using -DBETA
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002496 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002497 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
2498 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2499 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2500 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2501 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2502 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002503 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002504 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2505 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2506 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2507 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2508 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2509 */
2510__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
2511 IMAGE_DECLARATION(dst))
2512{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002513 // Compute source and destination addresses
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002514 Image src = CONVERT_TO_IMAGE_STRUCT(src);
2515 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2516
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002517 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002518 half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
2519
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002520 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002521 half8 c = vload8(0, (__global half *)src.ptr);
2522
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002523 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002524 half8 out = alpha_ab + (half8)BETA * c;
2525
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002526 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002527 vstore8(out, 0, (__global half *)dst.ptr);
2528}
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002529
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002530#if defined(FIXED_POINT_POSITION)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002531/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 8 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
2532 *
Gian Marco19835e52018-01-30 13:35:54 +00002533 * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002534 *
2535 * @note: BETA must be passed in 8 bit fixed point format
2536 *
2537 * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS8
2538 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2539 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2540 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2541 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2542 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
2543 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
2544 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2545 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2546 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2547 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2548 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2549 */
__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Per-work-item views: "src" holds matrix C, "dst" holds the A x B values and is updated in place
    Image c_mat   = CONVERT_TO_IMAGE_STRUCT(src);
    Image axb_mat = CONVERT_TO_IMAGE_STRUCT(dst);

    // Fetch 16 values of matrix C and the 16 matching values of A x B
    const char16 c_vals   = vload16(0, (__global char *)c_mat.ptr);
    const char16 axb_vals = vload16(0, (__global char *)axb_mat.ptr);

    // alpha * axb + beta * c, with BETA broadcast to all 16 lanes
    const char16 result = mla_sat_qs8x16(axb_vals, (char16)BETA, c_vals, FIXED_POINT_POSITION);

    // Write the result back over the A x B values
    vstore16(result, 0, (__global char *)axb_mat.ptr);
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01002569
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 16 bit fixed point.
 *  The values already stored in the destination matrix (A x B) are accumulated with the source matrix (C) weighted by the scalar value beta.
 *  Each work-item processes 8 consecutive elements.
 *
 * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 *
 * @note BETA must be passed in 16 bit fixed point format
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: QS16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It is read (A x B values) and overwritten with the result. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
                           IMAGE_DECLARATION(dst))
{
    // Per-work-item views: "src" holds matrix C, "dst" holds the A x B values and is updated in place
    Image c_mat   = CONVERT_TO_IMAGE_STRUCT(src);
    Image axb_mat = CONVERT_TO_IMAGE_STRUCT(dst);

    // Fetch 8 values of matrix C and the 8 matching values of A x B
    const short8 c_vals   = vload8(0, (__global short *)c_mat.ptr);
    const short8 axb_vals = vload8(0, (__global short *)axb_mat.ptr);

    // alpha * axb + beta * c, with BETA broadcast to all 8 lanes
    const short8 result = mla_sat_qs16x8(axb_vals, (short8)BETA, c_vals, FIXED_POINT_POSITION);

    // Write the result back over the A x B values
    vstore8(result, 0, (__global short *)axb_mat.ptr);
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002608#endif // defined(FIXED_POINT_POSITION)
2609#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002610
#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
 *
 * Row y of A is multiplied by the plane of B selected with src1_stride_z * y. Each work-item accumulates 4 consecutive float results.
 *
 * @note The width of A (in elements) needs to be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    // First of the 4 columns of B read by this work-item, and the row of A it works on
    const int col = get_global_id(0) * 4;
    const int row = get_global_id(1);

    // Byte offsets: .s0 walks along the selected row of A, .s1 walks down the rows of the selected plane of B
    int2 offs = (int2)(src0_offset_first_element_in_bytes + src0_stride_y * row,
                       src1_offset_first_element_in_bytes + src1_stride_z * row);
    offs.s1 += col * sizeof(float);

    // Byte offset one past the last element of the current row of A
    const int a_row_end = offs.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 sum = 0.0f;

    // Main loop: consume two elements of A and two rows of B per iteration
    while(offs.s0 <= (a_row_end - 2 * (int)sizeof(float)))
    {
        float2 a_pair = vload2(0, (__global float *)(src0_ptr + offs.s0));
        float4 b_row0 = vload4(0, (__global float *)(src1_ptr + offs.s1));
        float4 b_row1 = vload4(0, (__global float *)(src1_ptr + offs.s1 + src1_stride_y));

        sum += b_row0 * (float4)a_pair.s0;
        sum += b_row1 * (float4)a_pair.s1;

        offs += (int2)(2 * sizeof(float), 2 * src1_stride_y);
    }

    // Tail loop: handles the left-over element when WIDTH_VECTOR_A is odd
    while(offs.s0 < a_row_end)
    {
        float  a_val = *((__global float *)(src0_ptr + offs.s0));
        float4 b_row = vload4(0, (__global float *)(src1_ptr + offs.s1));

        sum += b_row * (float4)a_val;

        offs += (int2)(sizeof(float), src1_stride_y);
    }

    // Resolve the destination address of this work-item and store the 4 accumulated values
    Image out_img = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(sum, 0, (__global float *)(offset(&out_img, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
2678
/** This kernel accumulates each row with the biases vector.
 *
 * Each work-item loads VECTOR_SIZE elements from the accumulate tensor and VECTOR_SIZE elements from the biases vector,
 * adds them and writes the sum back to the accumulate tensor. When FIXED_POINT_POSITION is defined the addition is
 * performed through ADD_SAT_OP_EXPAND, otherwise with the plain + operator.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/QS8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Per-work-item views of the accumulate tensor and of the biases vector
    Image  accum_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    // Load VECTOR_SIZE biases and the VECTOR_SIZE accumulated values they are added to
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    bias_vals = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases_vec.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    acc_vals = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum_img.ptr);

#if defined(FIXED_POINT_POSITION)
    // Fixed point path: addition goes through the ADD_SAT_OP_EXPAND helper
    acc_vals = ADD_SAT_OP_EXPAND(bias_vals, acc_vals, DATA_TYPE, VECTOR_SIZE);
#else  // defined(FIXED_POINT_POSITION)
    // Default path: element-wise addition
    acc_vals = bias_vals + acc_vals;
#endif // defined(FIXED_POINT_POSITION)

    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (acc_vals, 0, (__global DATA_TYPE *)accum_img.ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)