blob: 5a6efe64b98c4de34d2cb736c9c16c0112b15b9e [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco36a0a462018-01-12 10:21:40 +000026#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
27
// gemm_transpose1xW only moves bits around, so it operates on a raw unsigned integer type whose width
// matches one element (ELEMENT_SIZE bytes). This lets a single kernel serve every data type listed in its
// documentation (8-, 16- and 32-bit integer and floating-point types).
#if !defined(ELEMENT_SIZE)
#error "ELEMENT_SIZE must be defined (supported values: 1, 2 or 4)"
#elif ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE not in {1, 2, 4}
#error "Element size not supported"
#endif // ELEMENT_SIZE
Gian Marco36a0a462018-01-12 10:21:40 +000037
/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * Every work-item copies TRANSPOSE_W consecutive elements, read from its position (x, y, z) in the source
 * tensor, into one contiguous chunk of the destination. X and Y are swapped: the chunk lands in destination
 * row x / MULT_TRANSPOSE1XW_WIDTH, at element offset
 * (y * MULT_TRANSPOSE1XW_WIDTH + x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W, in plane z.
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4)
 * @note The size in bytes of one element must be passed at compile time using -DELEMENT_SIZE (supported: 1, 2 or 4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    const uint gx = get_global_id(0);
    const uint gy = get_global_id(1);
    const uint gz = get_global_id(2);

    // Source (matrix B): src.ptr is read at offset 0 below, so it is assumed to already point at this
    // work-item's block (set up by CONVERT_TO_TENSOR3D_STRUCT from helpers.h)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Destination (matrix B transposed), with X and Y swapped:
    //  - row:   gx / MULT_TRANSPOSE1XW_WIDTH
    //  - chunk: gy * MULT_TRANSPOSE1XW_WIDTH + gx % MULT_TRANSPOSE1XW_WIDTH, each chunk TRANSPOSE_W elements wide
    //  - plane: gz (batched GEMM)
    const uint dst_row   = gx / MULT_TRANSPOSE1XW_WIDTH;
    const uint dst_chunk = gy * MULT_TRANSPOSE1XW_WIDTH + (gx % MULT_TRANSPOSE1XW_WIDTH);

    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes
                             + dst_row * dst_stride_y
                             + dst_chunk * TRANSPOSE_W * sizeof(DATA_TYPE)
                             + gz * dst_stride_z;

    const VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W) block = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (block, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
Gian Marco36a0a462018-01-12 10:21:40 +000083#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010084
Gian Marco36a0a462018-01-12 10:21:40 +000085#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
86
/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
 *
 * Every work-item reads a 4x4 tile (4 consecutive elements from each of 4 consecutive rows, starting at its
 * position in the source tensor), transposes it and writes the 4 resulting column vectors to the destination.
 * The columns are written 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart, so MULT_INTERLEAVE4X4_HEIGHT
 * work-items along Y interleave their tiles inside the same destination row.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst))
{
    const uint gx = get_global_id(0);
    const uint gy = get_global_id(1);
    const uint gz = get_global_id(2);

    // Destination (matrix A interleaved):
    //  - row:  gy / MULT_INTERLEAVE4X4_HEIGHT
    //  - base: gx * 16 * MULT_INTERLEAVE4X4_HEIGHT elements (one interleaved group per gx)
    //  - slot: (gy % MULT_INTERLEAVE4X4_HEIGHT) * 4 elements inside every 4 * MULT_INTERLEAVE4X4_HEIGHT group
    //  - plane: gz (batched GEMM)
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes
                             + (gy / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y
                             + (gx * 16 * MULT_INTERLEAVE4X4_HEIGHT + (gy % MULT_INTERLEAVE4X4_HEIGHT) * 4) * sizeof(DATA_TYPE)
                             + gz * dst_stride_z;

    // Source tile (matrix A): 4 rows of 4 elements starting at this work-item's source position
    Tensor3D src      = CONVERT_TO_TENSOR3D_STRUCT(src);
    __global uchar *row0 = src.ptr;

    // Pack the 4 rows into one 16-wide vector: element (r, c) of the tile lives in component 4 * r + c
    const VEC_DATA_TYPE(DATA_TYPE, 16) tile = (VEC_DATA_TYPE(DATA_TYPE, 16))(vload4(0, (__global DATA_TYPE *)(row0 + 0 * src_stride_y)),
                                                                              vload4(0, (__global DATA_TYPE *)(row0 + 1 * src_stride_y)),
                                                                              vload4(0, (__global DATA_TYPE *)(row0 + 2 * src_stride_y)),
                                                                              vload4(0, (__global DATA_TYPE *)(row0 + 3 * src_stride_y)));

    __global DATA_TYPE *out = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

    // Column c of the tile = components c, c + 4, c + 8, c + 12
    vstore4(tile.s048C, 0, out + 0 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(tile.s159D, 0, out + 4 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(tile.s26AE, 0, out + 8 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(tile.s37BF, 0, out + 12 * MULT_INTERLEAVE4X4_HEIGHT);
}
Gian Marco36a0a462018-01-12 10:21:40 +0000152#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100153
Gian Marco36a0a462018-01-12 10:21:40 +0000154#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *  (this kernel reads B 4 floats at a time, i.e. it expects the 1x4 transposition: TRANSPOSE_W=4)
 *
 * Each work-item computes a 4x4 block of the output. It walks one interleaved row of A and one transposed row of B,
 * 4 values at a time, and accumulates the outer products into four float4 accumulators (one per output row).
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA.
 *       COLS_B is used as the number of float elements in one row of the reshaped (transposed) matrix B: it bounds the K loop.
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    Bottom padding of each output plane, in rows (it is scaled by dst_stride_y).
 *                                                Only present if REINTERPRET_OUTPUT_AS_3D is defined
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item's 4-wide slot inside the interleaved row of A / transposed row of B
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (before applying the per-work-item slot offset)
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cN0 holds row N of the 4x4 output block
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // Main loop: two K steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Left-over loop: one K step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. Both operands are uint4 so the standard min(gentype, gentype) overload applies
    // (OpenCL does not define min(scalar, vector)), and the comparison is unsigned on both sides.
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings (cross_plane_pad rows per plane boundary crossed)
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
331
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000332/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100333 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100334 *
Gian Marco19835e52018-01-30 13:35:54 +0000335 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
336 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
337 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000338 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
339 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
340 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100341 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000342 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
343 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
344 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
345 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
346 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
347 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100348 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
349 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
350 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
351 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
352 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
353 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100354 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100355 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
356 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
357 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
358 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
359 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100360 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100361 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000362 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100363 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000364 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100365 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000366 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
367 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
368 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100369 * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100370 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100371__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
372 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000373 IMAGE_DECLARATION(dst),
374 uint src0_stride_z,
375 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000376 uint dst_stride_z
377#if defined(REINTERPRET_OUTPUT_AS_3D)
378 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100379 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000380#endif // REINTERPRET_OUTPUT_AS_3D
381 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100382{
Gian Marco36a0a462018-01-12 10:21:40 +0000383 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
384 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000385 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +0000386
387 // Offset
388 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
389 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
390
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100391 // src_addr_a = address of matrix A
392 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000393 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
394 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
395
396#if defined(MATRIX_B_DEPTH)
397 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
398 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
399#else // defined(MATRIX_B_DEPTH)
400 src1_addr_in_bytes += z * src1_stride_z;
401#endif // defined(MATRIX_B_DEPTH)
402
403 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
404 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100405
Gian Marco36a0a462018-01-12 10:21:40 +0000406 src_addr_a += offset_row_a;
407 src_addr_b += offset_row_b;
408
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100409 // Reset accumulators
410 float c00 = 0.0f;
411 float c01 = 0.0f;
412 float c02 = 0.0f;
413 float c03 = 0.0f;
414 float c10 = 0.0f;
415 float c11 = 0.0f;
416 float c12 = 0.0f;
417 float c13 = 0.0f;
418 float c20 = 0.0f;
419 float c21 = 0.0f;
420 float c22 = 0.0f;
421 float c23 = 0.0f;
422 float c30 = 0.0f;
423 float c31 = 0.0f;
424 float c32 = 0.0f;
425 float c33 = 0.0f;
426
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100427#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
428
429 int i = 0;
430 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100431 {
432 // Load values from matrix A (interleaved) and matrix B (transposed)
433 float4 a0 = vload4(0, src_addr_a);
434 float4 b0 = vload4(0, src_addr_b);
435
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100436 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
437 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100438
439 c00 = fma(a0.s0, b0.s0, c00);
440 c01 = fma(a0.s0, b0.s1, c01);
441 c02 = fma(a0.s0, b0.s2, c02);
442 c03 = fma(a0.s0, b0.s3, c03);
443
444 c10 = fma(a0.s1, b0.s0, c10);
445 c11 = fma(a0.s1, b0.s1, c11);
446 c12 = fma(a0.s1, b0.s2, c12);
447 c13 = fma(a0.s1, b0.s3, c13);
448
449 c20 = fma(a0.s2, b0.s0, c20);
450 c21 = fma(a0.s2, b0.s1, c21);
451 c22 = fma(a0.s2, b0.s2, c22);
452 c23 = fma(a0.s2, b0.s3, c23);
453
454 c30 = fma(a0.s3, b0.s0, c30);
455 c31 = fma(a0.s3, b0.s1, c31);
456 c32 = fma(a0.s3, b0.s2, c32);
457 c33 = fma(a0.s3, b0.s3, c33);
458
459 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100460 a0 = vload4(0, src_addr_a);
461 b0 = vload4(0, src_addr_b);
462
463 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
464 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100465
466 c00 = fma(a0.s0, b0.s0, c00);
467 c01 = fma(a0.s0, b0.s1, c01);
468 c02 = fma(a0.s0, b0.s2, c02);
469 c03 = fma(a0.s0, b0.s3, c03);
470
471 c10 = fma(a0.s1, b0.s0, c10);
472 c11 = fma(a0.s1, b0.s1, c11);
473 c12 = fma(a0.s1, b0.s2, c12);
474 c13 = fma(a0.s1, b0.s3, c13);
475
476 c20 = fma(a0.s2, b0.s0, c20);
477 c21 = fma(a0.s2, b0.s1, c21);
478 c22 = fma(a0.s2, b0.s2, c22);
479 c23 = fma(a0.s2, b0.s3, c23);
480
481 c30 = fma(a0.s3, b0.s0, c30);
482 c31 = fma(a0.s3, b0.s1, c31);
483 c32 = fma(a0.s3, b0.s2, c32);
484 c33 = fma(a0.s3, b0.s3, c33);
485
486 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100487 a0 = vload4(0, src_addr_a);
488 b0 = vload4(0, src_addr_b);
489
490 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
491 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
492
493 c00 = fma(a0.s0, b0.s0, c00);
494 c01 = fma(a0.s0, b0.s1, c01);
495 c02 = fma(a0.s0, b0.s2, c02);
496 c03 = fma(a0.s0, b0.s3, c03);
497
498 c10 = fma(a0.s1, b0.s0, c10);
499 c11 = fma(a0.s1, b0.s1, c11);
500 c12 = fma(a0.s1, b0.s2, c12);
501 c13 = fma(a0.s1, b0.s3, c13);
502
503 c20 = fma(a0.s2, b0.s0, c20);
504 c21 = fma(a0.s2, b0.s1, c21);
505 c22 = fma(a0.s2, b0.s2, c22);
506 c23 = fma(a0.s2, b0.s3, c23);
507
508 c30 = fma(a0.s3, b0.s0, c30);
509 c31 = fma(a0.s3, b0.s1, c31);
510 c32 = fma(a0.s3, b0.s2, c32);
511 c33 = fma(a0.s3, b0.s3, c33);
512
513 // Load values from matrix A (interleaved) and matrix B (transposed)
514 a0 = vload4(0, src_addr_a);
515 b0 = vload4(0, src_addr_b);
516
517 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
518 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100519
520 c00 = fma(a0.s0, b0.s0, c00);
521 c01 = fma(a0.s0, b0.s1, c01);
522 c02 = fma(a0.s0, b0.s2, c02);
523 c03 = fma(a0.s0, b0.s3, c03);
524
525 c10 = fma(a0.s1, b0.s0, c10);
526 c11 = fma(a0.s1, b0.s1, c11);
527 c12 = fma(a0.s1, b0.s2, c12);
528 c13 = fma(a0.s1, b0.s3, c13);
529
530 c20 = fma(a0.s2, b0.s0, c20);
531 c21 = fma(a0.s2, b0.s1, c21);
532 c22 = fma(a0.s2, b0.s2, c22);
533 c23 = fma(a0.s2, b0.s3, c23);
534
535 c30 = fma(a0.s3, b0.s0, c30);
536 c31 = fma(a0.s3, b0.s1, c31);
537 c32 = fma(a0.s3, b0.s2, c32);
538 c33 = fma(a0.s3, b0.s3, c33);
539 }
540
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100541 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100542 {
543 // Load values from matrix A (interleaved) and matrix B (transposed)
544 float4 a0 = vload4(0, src_addr_a);
545 float4 b0 = vload4(0, src_addr_b);
546
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100547 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
548 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
549
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100550 c00 = fma(a0.s0, b0.s0, c00);
551 c01 = fma(a0.s0, b0.s1, c01);
552 c02 = fma(a0.s0, b0.s2, c02);
553 c03 = fma(a0.s0, b0.s3, c03);
554
555 c10 = fma(a0.s1, b0.s0, c10);
556 c11 = fma(a0.s1, b0.s1, c11);
557 c12 = fma(a0.s1, b0.s2, c12);
558 c13 = fma(a0.s1, b0.s3, c13);
559
560 c20 = fma(a0.s2, b0.s0, c20);
561 c21 = fma(a0.s2, b0.s1, c21);
562 c22 = fma(a0.s2, b0.s2, c22);
563 c23 = fma(a0.s2, b0.s3, c23);
564
565 c30 = fma(a0.s3, b0.s0, c30);
566 c31 = fma(a0.s3, b0.s1, c31);
567 c32 = fma(a0.s3, b0.s2, c32);
568 c33 = fma(a0.s3, b0.s3, c33);
569 }
570
571 // Compute destination address
572 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
573
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000574#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100575 // Multiply by the weight of matrix product
576 c00 = c00 * ALPHA;
577 c01 = c01 * ALPHA;
578 c02 = c02 * ALPHA;
579 c03 = c03 * ALPHA;
580 c10 = c10 * ALPHA;
581 c11 = c11 * ALPHA;
582 c12 = c12 * ALPHA;
583 c13 = c13 * ALPHA;
584 c20 = c20 * ALPHA;
585 c21 = c21 * ALPHA;
586 c22 = c22 * ALPHA;
587 c23 = c23 * ALPHA;
588 c30 = c30 * ALPHA;
589 c31 = c31 * ALPHA;
590 c32 = c32 * ALPHA;
591 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000592#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100593
Gian Marcoae2af742018-02-15 12:35:44 +0000594 // Compute dst address
595 __global uchar *dst_addr = offset(&dst, 0, 0);
596
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000597#if defined(REINTERPRET_OUTPUT_AS_3D)
598 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100599 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000600 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100601 // | |
602 // | plane0 |
603 // | |
604 // |__________________|
605 // |******************|
606 // | cross_plane_pad |
607 // |******************|
608 // | |
609 // | plane1 |
610 // | |
611 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000612
613 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
614 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
615 zout = min(DEPTH_GEMM3D - 1, zout);
616
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100617 // Add offset due to the cross plane paddings
618 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000619
620 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
621 // multiply dst_stride_z by DEPTH_GEMM3D
622 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
623
624 // Store 4x4 block
625 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
626 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
627 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
628 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
629
630#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000631 // Add offset for batched GEMM
632 dst_addr += z * dst_stride_z;
633
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100634 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000635 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
636 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
637 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
638 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000639#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100640}
641
Georgios Pinitas84225582018-05-14 12:00:05 +0100642// Undefine local defines
643#undef COLS_MTX_B
644
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100645#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100646/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100647 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100648 *
Gian Marco19835e52018-01-30 13:35:54 +0000649 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
650 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
651 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000652 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
653 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100654 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000655 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
656 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
657 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
658 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
659 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
660 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100661 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
662 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
663 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
664 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
665 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
666 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100667 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100668 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
669 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
670 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
671 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
672 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100673 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100674 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000675 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100676 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000677 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100678 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000679 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
680 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
681 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100682 * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100683 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100684__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
685 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000686 IMAGE_DECLARATION(dst),
687 uint src0_stride_z,
688 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000689 uint dst_stride_z
690#if defined(REINTERPRET_OUTPUT_AS_3D)
691 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100692 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000693#endif // REINTERPRET_OUTPUT_AS_3D
694 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100695{
Gian Marco36a0a462018-01-12 10:21:40 +0000696 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
697 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000698 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100699
Gian Marco36a0a462018-01-12 10:21:40 +0000700 // Offset
701 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
702 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100703
Gian Marco36a0a462018-01-12 10:21:40 +0000704 // src_addr_a = address of matrix A
705 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000706 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
707 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
708
709#if defined(MATRIX_B_DEPTH)
710 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
711 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
712#else // defined(MATRIX_B_DEPTH)
713 src1_addr_in_bytes += z * src1_stride_z;
714#endif // defined(MATRIX_B_DEPTH)
715
716 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
717 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100718
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000719 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000720 __global half *src_end_addr_b = src_addr_b + COLS_B;
721
722 src_addr_a += offset_row_a;
723 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100724
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000725 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100726 half8 c00 = 0.0f;
727 half8 c10 = 0.0f;
728 half8 c20 = 0.0f;
729 half8 c30 = 0.0f;
730
Gian Marco36a0a462018-01-12 10:21:40 +0000731 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100732 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000733 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000734 half4 a0 = vload4(0, src_addr_a);
735 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100736
737 c00 += (half8)a0.s0 * b0;
738 c10 += (half8)a0.s1 * b0;
739 c20 += (half8)a0.s2 * b0;
740 c30 += (half8)a0.s3 * b0;
741
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000742 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000743 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
744 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100745
746 c00 += (half8)a0.s0 * b0;
747 c10 += (half8)a0.s1 * b0;
748 c20 += (half8)a0.s2 * b0;
749 c30 += (half8)a0.s3 * b0;
750 }
751
Gian Marco36a0a462018-01-12 10:21:40 +0000752 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100753 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000754 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000755 half4 a0 = vload4(0, src_addr_a);
756 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100757
758 c00 += (half8)a0.s0 * b0;
759 c10 += (half8)a0.s1 * b0;
760 c20 += (half8)a0.s2 * b0;
761 c30 += (half8)a0.s3 * b0;
762 }
763
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000764 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100765 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
766
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000767#if defined(ALPHA)
768 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100769 c00 = c00 * (half8)ALPHA;
770 c10 = c10 * (half8)ALPHA;
771 c20 = c20 * (half8)ALPHA;
772 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000773#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100774
Gian Marcoae2af742018-02-15 12:35:44 +0000775 // Compute dst address
776 __global uchar *dst_addr = offset(&dst, 0, 0);
777
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000778#if defined(REINTERPRET_OUTPUT_AS_3D)
779 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100780 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000781 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100782 // | |
783 // | plane0 |
784 // | |
785 // |__________________|
786 // |******************|
787 // | cross_plane_pad |
788 // |******************|
789 // | |
790 // | plane1 |
791 // | |
792 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000793
794 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
795 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
796 zout = min(DEPTH_GEMM3D - 1, zout);
797
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100798 // Add offset due to the cross plane paddings
799 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000800
801 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
802 // multiply dst_stride_z by DEPTH_GEMM3D
803 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
804
805 // Store 4x8 block
806 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
807 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
808 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
809 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
810
811#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000812 // Add offset for batched GEMM
813 dst_addr += z * dst_stride_z;
814
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000815 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +0000816 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
817 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
818 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
819 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000820#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100821}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100822
823/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
824 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
825 *
826 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
827 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
828 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
829 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
830 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
831 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000832 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
833 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
834 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
835 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
836 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
837 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100838 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
839 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
840 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
841 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
842 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
843 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
844 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
845 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
846 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
847 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
848 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
849 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
850 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
851 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
852 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
853 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
854 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
855 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100856 * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100857 */
858__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
859 IMAGE_DECLARATION(src1),
860 IMAGE_DECLARATION(dst),
861 uint src0_stride_z,
862 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000863 uint dst_stride_z
864#if defined(REINTERPRET_OUTPUT_AS_3D)
865 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100866 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000867#endif // REINTERPRET_OUTPUT_AS_3D
868 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100869{
870 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
871 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
872 int z = get_global_id(2);
873
874 // Offset
875 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
876 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
877
878 // src_addr_a = address of matrix A
879 // src_addr_b = address of matrix B
880 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
881 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
882
883#if defined(MATRIX_B_DEPTH)
884 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
885 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
886#else // defined(MATRIX_B_DEPTH)
887 src1_addr_in_bytes += z * src1_stride_z;
888#endif // defined(MATRIX_B_DEPTH)
889
890 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
891 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
892
893 // Compute end row address for matrix B
894 __global half *src_end_addr_b = src_addr_b + COLS_B;
895
896 src_addr_a += offset_row_a;
897 src_addr_b += offset_row_b;
898
899 // Reset accumulators
900 half8 c00 = 0.0f;
901 half8 c10 = 0.0f;
902 half8 c20 = 0.0f;
903 half8 c30 = 0.0f;
904
905#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
906
907 int i = 0;
908 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
909 {
910#if MULT_INTERLEAVE4X4_HEIGHT == 1
911 // Load values from matrix A (interleaved) and matrix B (transposed)
912 half8 a0 = vload8(0, src_addr_a);
913 half8 b0 = vload8(0, src_addr_b);
914
915 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
916 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
917
918 c00 = fma((half8)a0.s0, b0, c00);
919 c10 = fma((half8)a0.s1, b0, c10);
920 c20 = fma((half8)a0.s2, b0, c20);
921 c30 = fma((half8)a0.s3, b0, c30);
922
923 // Load values from matrix B (transposed)
924 b0 = vload8(0, src_addr_b);
925
926 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
927
928 c00 = fma((half8)a0.s4, b0, c00);
929 c10 = fma((half8)a0.s5, b0, c10);
930 c20 = fma((half8)a0.s6, b0, c20);
931 c30 = fma((half8)a0.s7, b0, c30);
932
933 // Load values from matrix A (interleaved) and matrix B (transposed)
934 a0 = vload8(0, src_addr_a);
935 b0 = vload8(0, src_addr_b);
936
937 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
938 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
939
940 c00 = fma((half8)a0.s0, b0, c00);
941 c10 = fma((half8)a0.s1, b0, c10);
942 c20 = fma((half8)a0.s2, b0, c20);
943 c30 = fma((half8)a0.s3, b0, c30);
944
945 // Load values from matrix B (transposed)
946 b0 = vload8(0, src_addr_b);
947
948 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
949
950 c00 = fma((half8)a0.s4, b0, c00);
951 c10 = fma((half8)a0.s5, b0, c10);
952 c20 = fma((half8)a0.s6, b0, c20);
953 c30 = fma((half8)a0.s7, b0, c30);
954#else // MULT_INTERLEAVE4X4_HEIGHT == 1
955 // Load values from matrix A (interleaved) and matrix B (transposed)
956 half4 a0 = vload4(0, src_addr_a);
957 half8 b0 = vload8(0, src_addr_b);
958
959 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
960 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
961
962 c00 = fma((half8)a0.s0, b0, c00);
963 c10 = fma((half8)a0.s1, b0, c10);
964 c20 = fma((half8)a0.s2, b0, c20);
965 c30 = fma((half8)a0.s3, b0, c30);
966
967 // Load values from matrix A (interleaved) and matrix B (transposed)
968 a0 = vload4(0, src_addr_a);
969 b0 = vload8(0, src_addr_b);
970
971 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
972 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
973
974 c00 = fma((half8)a0.s0, b0, c00);
975 c10 = fma((half8)a0.s1, b0, c10);
976 c20 = fma((half8)a0.s2, b0, c20);
977 c30 = fma((half8)a0.s3, b0, c30);
978
979 // Load values from matrix A (interleaved) and matrix B (transposed)
980 a0 = vload4(0, src_addr_a);
981 b0 = vload8(0, src_addr_b);
982
983 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
984 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
985
986 c00 = fma((half8)a0.s0, b0, c00);
987 c10 = fma((half8)a0.s1, b0, c10);
988 c20 = fma((half8)a0.s2, b0, c20);
989 c30 = fma((half8)a0.s3, b0, c30);
990
991 // Load values from matrix A (interleaved) and matrix B (transposed)
992 a0 = vload4(0, src_addr_a);
993 b0 = vload8(0, src_addr_b);
994
995 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
996 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
997
998 c00 = fma((half8)a0.s0, b0, c00);
999 c10 = fma((half8)a0.s1, b0, c10);
1000 c20 = fma((half8)a0.s2, b0, c20);
1001 c30 = fma((half8)a0.s3, b0, c30);
1002#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
1003 }
1004
1005 for(; i < (int)(COLS_MTX_B); ++i)
1006 {
1007 // Load values from matrix A (interleaved) and matrix B (transposed)
1008 half4 a0 = vload4(0, src_addr_a);
1009 half8 b0 = vload8(0, src_addr_b);
1010
1011 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1012 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1013
1014 c00 = fma((half8)a0.s0, b0, c00);
1015 c10 = fma((half8)a0.s1, b0, c10);
1016 c20 = fma((half8)a0.s2, b0, c20);
1017 c30 = fma((half8)a0.s3, b0, c30);
1018 }
1019
1020 // Compute destination address
1021 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1022
1023#if defined(ALPHA)
1024 // Multiply by the weight of matrix product
1025 c00 = c00 * (half8)ALPHA;
1026 c10 = c10 * (half8)ALPHA;
1027 c20 = c20 * (half8)ALPHA;
1028 c30 = c30 * (half8)ALPHA;
1029#endif // defined(ALPHA)
1030
1031 // Compute dst address
1032 __global uchar *dst_addr = offset(&dst, 0, 0);
1033
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001034#if defined(REINTERPRET_OUTPUT_AS_3D)
1035 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001036 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001037 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001038 // | |
1039 // | plane0 |
1040 // | |
1041 // |__________________|
1042 // |******************|
1043 // | cross_plane_pad |
1044 // |******************|
1045 // | |
1046 // | plane1 |
1047 // | |
1048 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001049
1050 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
1051 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
1052 zout = min(DEPTH_GEMM3D - 1, zout);
1053
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001054 // Add offset due to the cross plane paddings
1055 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001056
1057 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1058 // multiply dst_stride_z by DEPTH_GEMM3D
1059 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1060
1061 // Store 4x8 block
1062 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
1063 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
1064 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
1065 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
1066
1067#else // defined(REINTERPRET_OUTPUT_AS_3D)
1068 // Add offset for batched GEMM
1069 dst_addr += z * dst_stride_z;
1070
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001071 // Store 4x8 block
1072 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
1073 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
1074 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
1075 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001076#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001077}
Georgios Pinitas84225582018-05-14 12:00:05 +01001078
1079// Undefine local defines
1080#undef COLS_MTX_B
1081
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01001082#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001083
Gian Marco36a0a462018-01-12 10:21:40 +00001084#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001085
// Compile the non-reshaped GEMM kernels only when all required compile-time parameters are provided
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
1087#if defined(DATA_TYPE)
1088#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Each work-item computes a block of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x NUM_ELEMS_PROCESSED_PER_THREAD_X columns of the output:
 * get_global_id(0) selects the column block, get_global_id(1) the row block and get_global_id(2) the batch.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       At most four output rows are accumulated and stored per work-item, so NUM_ELEMS_PROCESSED_PER_THREAD_Y values above 4 are not handled.
 * @note The number of matrix A columns and the optional alpha value must be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *       (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped: output row r is written to plane min(r / HEIGHT_GEMM3D, DEPTH_GEMM3D - 1)
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    Number of padding rows between two consecutive output planes, i.e. each plane adds cross_plane_pad * dst_stride_y bytes
 *                                                (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
                                     uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                    )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A (s0) and matrix B (s1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for matrix A: first of the NUM_ELEMS_PROCESSED_PER_THREAD_Y rows handled by this work-item
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for matrix B: first of the NUM_ELEMS_PROCESSED_PER_THREAD_X columns handled by this work-item
    src_addr.s1 += idx * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the current row of A: used as the loop bound on src_addr.s0
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume two columns of A (and two rows of B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: handles the last column of A when COLS_A is odd
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(ALPHA)
    // Multiply by the weight of matrix-matrix product
    acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    int z = get_global_id(2);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings: cross_plane_pad counts destination rows, hence the dst_stride_y factor
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Plain 2D output: no per-row plane offset
    const uint4 zout = (uint4)0;

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Store output block (zout adds the per-row cross-plane offset, zero for plain 2D output)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001326#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001327
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01001328/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001329 *
1330 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1331 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1332 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1333 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1334 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001335 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1336 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001337 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001338 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1339 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1340 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1341 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1342 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1343 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001344 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
1345 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1346 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1347 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1348 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1349 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1350 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1351 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1352 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1353 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1354 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1355 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1356 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1357 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1358 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1359 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1360 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1361 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001362 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1363 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1364 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001365 * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001366 */
1367__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1368 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001369 IMAGE_DECLARATION(dst),
1370 uint src0_stride_z,
1371 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001372 uint dst_stride_z
1373#if defined(REINTERPRET_OUTPUT_AS_3D)
1374 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001375 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001376#endif // REINTERPRET_OUTPUT_AS_3D
1377 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001378{
1379 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1380
1381 // Compute starting address for matrix A and matrix B
1382 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1383
1384 // Update address for matrix A
1385 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1386
1387 // Update address for matrix B
1388 src_addr.s1 += idx * sizeof(float);
1389
Gian Marcoae2af742018-02-15 12:35:44 +00001390 // Add offset for batched GEMM
1391 src_addr.s0 += get_global_id(2) * src0_stride_z;
1392
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001393#if defined(MATRIX_B_DEPTH)
1394 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1395 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1396#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001397 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001398#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001399
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001400 // Initialize accumulators
1401 float acc00 = 0.0f;
1402 float acc01 = 0.0f;
1403 float acc02 = 0.0f;
1404 float acc03 = 0.0f;
1405
1406#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1407 float acc10 = 0.0f;
1408 float acc11 = 0.0f;
1409 float acc12 = 0.0f;
1410 float acc13 = 0.0f;
1411#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1412
1413#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1414 float acc20 = 0.0f;
1415 float acc21 = 0.0f;
1416 float acc22 = 0.0f;
1417 float acc23 = 0.0f;
1418#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1419
1420#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1421 float acc30 = 0.0f;
1422 float acc31 = 0.0f;
1423 float acc32 = 0.0f;
1424 float acc33 = 0.0f;
1425#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1426
1427 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001428 int i = 0;
1429 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001430 {
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001431 // Load values from matrix A and matrix B
1432 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001433#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001434 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001435#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1436#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001437 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001438#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1439#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001440 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001441#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001442 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1443 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001444
1445 // Multiply and accumulate
1446 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001447 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001448 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001449 acc03 = fma(a0.s0, b0.s3, acc03);
1450
1451#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001452
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001453 acc10 = fma(a1.s0, b0.s0, acc10);
1454 acc11 = fma(a1.s0, b0.s1, acc11);
1455 acc12 = fma(a1.s0, b0.s2, acc12);
1456 acc13 = fma(a1.s0, b0.s3, acc13);
1457
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001458#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1459#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001460
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001461 acc20 = fma(a2.s0, b0.s0, acc20);
1462 acc21 = fma(a2.s0, b0.s1, acc21);
1463 acc22 = fma(a2.s0, b0.s2, acc22);
1464 acc23 = fma(a2.s0, b0.s3, acc23);
1465
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001466#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1467#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001468
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001469 acc30 = fma(a3.s0, b0.s0, acc30);
1470 acc31 = fma(a3.s0, b0.s1, acc31);
1471 acc32 = fma(a3.s0, b0.s2, acc32);
1472 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001473#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001474
1475 // Load values from matrix A and matrix B
1476 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1477 src_addr.s1 += src1_stride_y;
1478
1479 // Multiply and accumulate
1480 acc00 = fma(a0.s1, b0.s0, acc00);
1481 acc01 = fma(a0.s1, b0.s1, acc01);
1482 acc02 = fma(a0.s1, b0.s2, acc02);
1483 acc03 = fma(a0.s1, b0.s3, acc03);
1484
1485#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1486
1487 acc10 = fma(a1.s1, b0.s0, acc10);
1488 acc11 = fma(a1.s1, b0.s1, acc11);
1489 acc12 = fma(a1.s1, b0.s2, acc12);
1490 acc13 = fma(a1.s1, b0.s3, acc13);
1491
1492#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1493#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1494
1495 acc20 = fma(a2.s1, b0.s0, acc20);
1496 acc21 = fma(a2.s1, b0.s1, acc21);
1497 acc22 = fma(a2.s1, b0.s2, acc22);
1498 acc23 = fma(a2.s1, b0.s3, acc23);
1499
1500#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1501#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1502
1503 acc30 = fma(a3.s1, b0.s0, acc30);
1504 acc31 = fma(a3.s1, b0.s1, acc31);
1505 acc32 = fma(a3.s1, b0.s2, acc32);
1506 acc33 = fma(a3.s1, b0.s3, acc33);
1507#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1508
1509 // Load values from matrix A and matrix B
1510 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1511 src_addr.s1 += src1_stride_y;
1512
1513 // Multiply and accumulate
1514 acc00 = fma(a0.s2, b0.s0, acc00);
1515 acc01 = fma(a0.s2, b0.s1, acc01);
1516 acc02 = fma(a0.s2, b0.s2, acc02);
1517 acc03 = fma(a0.s2, b0.s3, acc03);
1518
1519#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1520
1521 acc10 = fma(a1.s2, b0.s0, acc10);
1522 acc11 = fma(a1.s2, b0.s1, acc11);
1523 acc12 = fma(a1.s2, b0.s2, acc12);
1524 acc13 = fma(a1.s2, b0.s3, acc13);
1525
1526#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1527#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1528
1529 acc20 = fma(a2.s2, b0.s0, acc20);
1530 acc21 = fma(a2.s2, b0.s1, acc21);
1531 acc22 = fma(a2.s2, b0.s2, acc22);
1532 acc23 = fma(a2.s2, b0.s3, acc23);
1533
1534#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1535#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1536
1537 acc30 = fma(a3.s2, b0.s0, acc30);
1538 acc31 = fma(a3.s2, b0.s1, acc31);
1539 acc32 = fma(a3.s2, b0.s2, acc32);
1540 acc33 = fma(a3.s2, b0.s3, acc33);
1541#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1542
1543 // Load values from matrix A and matrix B
1544 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1545 src_addr.s1 += src1_stride_y;
1546
1547 // Multiply and accumulate
1548 acc00 = fma(a0.s3, b0.s0, acc00);
1549 acc01 = fma(a0.s3, b0.s1, acc01);
1550 acc02 = fma(a0.s3, b0.s2, acc02);
1551 acc03 = fma(a0.s3, b0.s3, acc03);
1552
1553#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1554
1555 acc10 = fma(a1.s3, b0.s0, acc10);
1556 acc11 = fma(a1.s3, b0.s1, acc11);
1557 acc12 = fma(a1.s3, b0.s2, acc12);
1558 acc13 = fma(a1.s3, b0.s3, acc13);
1559
1560#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1561#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1562
1563 acc20 = fma(a2.s3, b0.s0, acc20);
1564 acc21 = fma(a2.s3, b0.s1, acc21);
1565 acc22 = fma(a2.s3, b0.s2, acc22);
1566 acc23 = fma(a2.s3, b0.s3, acc23);
1567
1568#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1569#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1570
1571 acc30 = fma(a3.s3, b0.s0, acc30);
1572 acc31 = fma(a3.s3, b0.s1, acc31);
1573 acc32 = fma(a3.s3, b0.s2, acc32);
1574 acc33 = fma(a3.s3, b0.s3, acc33);
1575#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1576
1577 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001578 }
1579
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001580 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001581 {
1582 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001583 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001584#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1585 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1586#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1587#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1588 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1589#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1590#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1591 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1592#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1593 // Load values from matrix B
1594 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001595 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001596
1597 // Multiply and accumulate
1598 acc00 = fma(a0, b0.s0, acc00);
1599 acc01 = fma(a0, b0.s1, acc01);
1600 acc02 = fma(a0, b0.s2, acc02);
1601 acc03 = fma(a0, b0.s3, acc03);
1602#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1603 acc10 = fma(a1, b0.s0, acc10);
1604 acc11 = fma(a1, b0.s1, acc11);
1605 acc12 = fma(a1, b0.s2, acc12);
1606 acc13 = fma(a1, b0.s3, acc13);
1607#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1608#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1609 acc20 = fma(a2, b0.s0, acc20);
1610 acc21 = fma(a2, b0.s1, acc21);
1611 acc22 = fma(a2, b0.s2, acc22);
1612 acc23 = fma(a2, b0.s3, acc23);
1613#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1614#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1615 acc30 = fma(a3, b0.s0, acc30);
1616 acc31 = fma(a3, b0.s1, acc31);
1617 acc32 = fma(a3, b0.s2, acc32);
1618 acc33 = fma(a3, b0.s3, acc33);
1619#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001620
1621 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001622 }
1623
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001624 int z = get_global_id(2);
1625
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001626 // Compute destination address
1627 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1628
1629 // Multiply by the weight of matrix-matrix product and store the result
1630#if defined(ALPHA)
1631 acc00 = acc00 * ALPHA;
1632 acc01 = acc01 * ALPHA;
1633 acc02 = acc02 * ALPHA;
1634 acc03 = acc03 * ALPHA;
1635#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001636#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001637 acc10 = acc10 * ALPHA;
1638 acc11 = acc11 * ALPHA;
1639 acc12 = acc12 * ALPHA;
1640 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001641#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
1642#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001643 acc20 = acc20 * ALPHA;
1644 acc21 = acc21 * ALPHA;
1645 acc22 = acc22 * ALPHA;
1646 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001647#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
1648#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001649 acc30 = acc30 * ALPHA;
1650 acc31 = acc31 * ALPHA;
1651 acc32 = acc32 * ALPHA;
1652 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001653#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
1654
1655 // Compute dst address
1656 __global uchar *dst_addr = offset(&dst, 0, 0);
1657
1658#if defined(REINTERPRET_OUTPUT_AS_3D)
1659 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001660 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001661 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001662 // | |
1663 // | plane0 |
1664 // | |
1665 // |__________________|
1666 // |******************|
1667 // | cross_plane_pad |
1668 // |******************|
1669 // | |
1670 // | plane1 |
1671 // | |
1672 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001673
1674 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1675 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1676 zout = min(DEPTH_GEMM3D - 1, zout);
1677
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001678 // Add offset due to the cross plane paddings
1679 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001680
1681 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1682 // multiply dst_stride_z by DEPTH_GEMM3D
1683 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1684
1685 // Store the output block
1686 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
1687#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1688 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
1689#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1690#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1691 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
1692#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1693#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1694 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001695#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001696
1697#else // defined(REINTERPRET_OUTPUT_AS_3D)
1698 // Add offset for batched GEMM
1699 dst_addr += z * dst_stride_z;
1700
1701 // Store the output block
1702 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
1703#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1704 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
1705#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1706#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1707 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
1708#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1709#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1710 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
1711#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1712#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001713}
1714
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item always loads/stores exactly 2 consecutive output columns (vload2 on matrix B, vstore2 on dst), so this kernel
 *       requires -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2: the first column of a work-item advances by NUM_ELEMS_PROCESSED_PER_THREAD_X,
 *       so any other value would skip or duplicate output columns.
 *       -DNUM_ELEMS_PROCESSED_PER_THREAD_Y can be 1, 2, 3 or 4 (one pair of accumulators is kept per output row).
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    Bottom padding of each output plane, in number of rows (it is scaled by dst_stride_y). Only if REINTERPRET_OUTPUT_AS_3D is defined
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
                                                      uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Each work-item computes an output tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows (up to 4) by 2 columns.
    // idx is the first output column (and the first column of matrix B) handled by this work-item.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A: move to the first row of this work-item's tile
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B: move to the first column of this work-item's tile
    src_addr.s1 += idx * sizeof(float);

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has only MATRIX_B_DEPTH batches while matrix A has more: wrap the batch index so that
    // matrix B is reused instead of being read out of bounds
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators: accRC holds output row R, column C of the tile
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: consume 8 columns of matrix A (and the matching 8 rows of matrix B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));

        // Load values from matrix B: 2 columns from each of the next 8 rows
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);
        acc00 = fma(a0.s4, b4.s0, acc00);
        acc00 = fma(a0.s5, b5.s0, acc00);
        acc00 = fma(a0.s6, b6.s0, acc00);
        acc00 = fma(a0.s7, b7.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);
        acc01 = fma(a0.s4, b4.s1, acc01);
        acc01 = fma(a0.s5, b5.s1, acc01);
        acc01 = fma(a0.s6, b6.s1, acc01);
        acc01 = fma(a0.s7, b7.s1, acc01);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        // a0 is reused to hold the 8 values of the next row of matrix A
        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);
        acc10 = fma(a0.s4, b4.s0, acc10);
        acc10 = fma(a0.s5, b5.s0, acc10);
        acc10 = fma(a0.s6, b6.s0, acc10);
        acc10 = fma(a0.s7, b7.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
        acc11 = fma(a0.s4, b4.s1, acc11);
        acc11 = fma(a0.s5, b5.s1, acc11);
        acc11 = fma(a0.s6, b6.s1, acc11);
        acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);
        acc20 = fma(a0.s4, b4.s0, acc20);
        acc20 = fma(a0.s5, b5.s0, acc20);
        acc20 = fma(a0.s6, b6.s0, acc20);
        acc20 = fma(a0.s7, b7.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
        acc21 = fma(a0.s4, b4.s1, acc21);
        acc21 = fma(a0.s5, b5.s1, acc21);
        acc21 = fma(a0.s6, b6.s1, acc21);
        acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);
        acc30 = fma(a0.s4, b4.s0, acc30);
        acc30 = fma(a0.s5, b5.s0, acc30);
        acc30 = fma(a0.s6, b6.s0, acc30);
        acc30 = fma(a0.s7, b7.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
        acc31 = fma(a0.s4, b4.s1, acc31);
        acc31 = fma(a0.s5, b5.s1, acc31);
        acc31 = fma(a0.s6, b6.s1, acc31);
        acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float) * 8;
    }

    // Left-over loop: process the remaining (COLS_A % 8) columns of matrix A one float at a time
    for(; i < (int)COLS_A; ++i)
    {
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings (cross_plane_pad rows per preceding plane)
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
2032
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01002033#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002034/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
2035 *
2036 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
2037 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
2038 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
2039 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
2040 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
2041 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2042 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2043 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002044 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2045 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2046 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2047 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2048 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2049 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002050 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2051 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2052 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2053 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2054 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2055 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2056 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2057 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2058 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2059 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2060 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2061 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2062 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2063 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2064 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2065 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2066 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2067 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002068 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2069 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2070 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002071 * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002072 */
2073__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
2074 IMAGE_DECLARATION(src1),
2075 IMAGE_DECLARATION(dst),
2076 uint src0_stride_z,
2077 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002078 uint dst_stride_z
2079#if defined(REINTERPRET_OUTPUT_AS_3D)
2080 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002081 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002082#endif // REINTERPRET_OUTPUT_AS_3D
2083 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002084{
2085 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2086
2087 // Compute starting address for matrix A and Matrix B
2088 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2089
2090 // Update address for the matrix A
2091 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2092
2093 // Update address for the matrix B
2094 src_addr.s1 += idx * sizeof(half);
2095
2096 // Add offset for batched GEMM
2097 src_addr.s0 += get_global_id(2) * src0_stride_z;
2098
2099#if defined(MATRIX_B_DEPTH)
2100 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2101 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2102#else // defined(MATRIX_B_DEPTH)
2103 src_addr.s1 += get_global_id(2) * src1_stride_z;
2104#endif // defined(MATRIX_B_DEPTH)
2105
2106 half8 acc0 = 0.0h;
2107#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2108 half8 acc1 = 0.0h;
2109#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2110#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2111 half8 acc2 = 0.0h;
2112#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2113#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2114 half8 acc3 = 0.0h;
2115#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2116
2117 int i = 0;
2118 for(; i <= ((int)COLS_A - 4); i += 4)
2119 {
2120 // Load values from matrix A
2121 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2122#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2123 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2124#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2125#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2126 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2127#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2128#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2129 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2130#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2131 // Load values from matrix B
2132 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2133 src_addr.s1 += src1_stride_y;
2134
2135 // Accumulate
2136 acc0 = fma(b0, (half8)a0.s0, acc0);
2137#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2138 acc1 = fma(b0, (half8)a1.s0, acc1);
2139#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2140#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2141 acc2 = fma(b0, (half8)a2.s0, acc2);
2142#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2143#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2144 acc3 = fma(b0, (half8)a3.s0, acc3);
2145#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2146
2147 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2148 src_addr.s1 += src1_stride_y;
2149 acc0 = fma(b0, (half8)a0.s1, acc0);
2150#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2151 acc1 = fma(b0, (half8)a1.s1, acc1);
2152#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2153#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2154 acc2 = fma(b0, (half8)a2.s1, acc2);
2155#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2156#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2157 acc3 = fma(b0, (half8)a3.s1, acc3);
2158#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2159
2160 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2161 src_addr.s1 += src1_stride_y;
2162 acc0 = fma(b0, (half8)a0.s2, acc0);
2163#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2164 acc1 = fma(b0, (half8)a1.s2, acc1);
2165#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2166#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2167 acc2 = fma(b0, (half8)a2.s2, acc2);
2168#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2169#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2170 acc3 = fma(b0, (half8)a3.s2, acc3);
2171#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2172
2173 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2174 src_addr.s1 += src1_stride_y;
2175 acc0 = fma(b0, (half8)a0.s3, acc0);
2176#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2177 acc1 = fma(b0, (half8)a1.s3, acc1);
2178#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2179#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2180 acc2 = fma(b0, (half8)a2.s3, acc2);
2181#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2182#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2183 acc3 = fma(b0, (half8)a3.s3, acc3);
2184#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2185
2186 src_addr.s0 += 4 * sizeof(half);
2187 }
2188
2189 for(; i < (int)COLS_A; ++i)
2190 {
2191 // Load values from matrix A
2192 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2193#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2194 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2195#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2196#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2197 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2198#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2199#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2200 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2201#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2202 // Load values from matrix B
2203 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2204
2205 src_addr += (int2)(sizeof(half), src1_stride_y);
2206
2207 // Accumulate
2208 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
2209#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2210 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
2211#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2212#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2213 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
2214#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2215#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2216 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
2217#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2218 }
2219
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002220 // Multiply by the weight of matrix-matrix product and store the result
2221#if defined(ALPHA)
2222 acc0 = acc0 * (half8)ALPHA;
2223#endif // defined(ALPHA)
2224#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2225 acc1 = acc1 * (half8)ALPHA;
2226#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2227#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2228 acc2 = acc2 * (half8)ALPHA;
2229#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2230#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2231 acc3 = acc3 * (half8)ALPHA;
2232#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2233
2234 int z = get_global_id(2);
2235
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002236 // Compute destination address
2237 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2238
2239 // Compute dst address
2240 __global uchar *dst_addr = offset(&dst, 0, 0);
2241
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002242#if defined(REINTERPRET_OUTPUT_AS_3D)
2243 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002244 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002245 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002246 // | |
2247 // | plane0 |
2248 // | |
2249 // |__________________|
2250 // |******************|
2251 // | cross_plane_pad |
2252 // |******************|
2253 // | |
2254 // | plane1 |
2255 // | |
2256 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002257
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002258 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2259 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2260 zout = min(DEPTH_GEMM3D - 1, zout);
2261
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002262 // Add offset due to the cross plane paddings
2263 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002264
2265 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2266 // multiply dst_stride_z by DEPTH_GEMM3D
2267 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2268
2269 // Store the output block
2270 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
2271#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2272 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
2273#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2274#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2275 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
2276#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2277#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2278 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
2279#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2280
2281#else // defined(REINTERPRET_OUTPUT_AS_3D)
2282 // Add offset for batched GEMM
2283 dst_addr += z * dst_stride_z;
2284
2285 // Store the output block
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002286 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
2287#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002288 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
2289#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2290#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002291 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
2292#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2293#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002294 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
2295#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002296#endif // REINTERPRET_OUTPUT_AS_3D
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002297}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01002298#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002299
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002300#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002301
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002302#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 *     dst = dst + beta * src
 *
 * On entry dst holds the result of the matrix multiplication A x B (presumably already scaled by alpha,
 * as the local name alpha_ab suggests — TODO confirm against the host code) and src holds matrix C.
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 * @note Each work-item reads and writes 4 consecutive F32 elements along X. No bounds checking is
 *       performed in the kernel.
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It holds A x B on entry and the final result on exit. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (the destination is read before being overwritten)
    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);

    // Load values from Matrix C
    float4 c = vload4(0, (__global float *)src.ptr);

    // Computes alpha * axb + beta * c
    float4 out = alpha_ab + (float4)BETA * c;

    // Store final result in axb matrix (in place)
    vstore4(out, 0, (__global float *)dst.ptr);
}
2343
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 *     dst = dst + beta * src
 *
 * On entry dst holds the result of the matrix multiplication A x B (presumably already scaled by alpha,
 * as the local name alpha_ab suggests — TODO confirm against the host code) and src holds matrix C.
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 * @note Each work-item reads and writes 8 consecutive F16 elements along X. No bounds checking is
 *       performed in the kernel.
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It holds A x B on entry and the final result on exit. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (the destination is read before being overwritten)
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Computes alpha * axb + beta * c
    half8 out = alpha_ab + (half8)BETA * c;

    // Store final result in axb matrix (in place)
    vstore8(out, 0, (__global half *)dst.ptr);
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002386#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002387
#if defined(WIDTH_VECTOR_A)
/** Vector-by-matrix multiplication used by the locally connected layer.
 *
 * Each row of A (src0) is multiplied by its own matrix B, which is the 2D slice of src1 selected
 * along Z by the same row index. Work-item (x, y) uses row y of A and slice y of src1, reads
 * 4 consecutive columns of B starting at column 4 * x, and stores the 4 resulting values at
 * its position in dst.
 *
 * @note The width of A needs to be passed at compile time using -DWIDTH_VECTOR_A
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of matrix A in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of matrix A in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in matrix A
 * @param[in]  src1_ptr                           Pointer to the source tensor B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of tensor B in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of tensor B in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_stride_z                      Stride of tensor B in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in tensor B
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    const int row = get_global_id(1);
    const int col = get_global_id(0) * 4;

    // Row `row` of A, indexed element by element
    __global const float *a_row = (__global const float *)(src0_ptr + src0_offset_first_element_in_bytes + src0_stride_y * row);

    // Column `col` of the first row of slice `row` of B; moves down one row of B per element of A
    __global const uchar *b_cursor = src1_ptr + src1_offset_first_element_in_bytes + src1_stride_z * row + col * sizeof(float);

    float4 acc = 0.0f;

    int k = 0;

    // Main loop: two elements of A (and two rows of B) per iteration
    for(; k <= WIDTH_VECTOR_A - 2; k += 2)
    {
        const float2 a_pair = vload2(0, a_row + k);
        const float4 b_row0 = vload4(0, (__global const float *)b_cursor);
        const float4 b_row1 = vload4(0, (__global const float *)(b_cursor + src1_stride_y));

        acc += b_row0 * (float4)a_pair.s0;
        acc += b_row1 * (float4)a_pair.s1;

        b_cursor += 2 * src1_stride_y;
    }

    // Tail: at most one element of A is left over when WIDTH_VECTOR_A is odd
    for(; k < WIDTH_VECTOR_A; ++k)
    {
        const float  a_scalar = a_row[k];
        const float4 b_row    = vload4(0, (__global const float *)b_cursor);

        acc += b_row * (float4)a_scalar;

        b_cursor += src1_stride_y;
    }

    // Write the 4 results at this work-item's position in the destination
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
2455
/** This kernel accumulates each row with the biases vector.
 *
 * Each work-item loads VECTOR_SIZE elements from accum and VECTOR_SIZE elements from biases,
 * adds them with a plain '+' (no saturation is applied for integer types) and writes the sum
 * back to accum in place.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    // Load VECTOR_SIZE elements of the accumulator and the matching VECTOR_SIZE biases
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);

    accum_value = biases_value + accum_value;

    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)