blob: 932e0d681a3b88ddf31c6e037cba881ba6b51f87 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// The kernel only moves data, so the element is handled as an unsigned integer
// of the same byte width selected through ELEMENT_SIZE.
#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE is undefined or is not one of 1, 2, 4
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * Each work-item loads one vector of TRANSPOSE_W elements from the source and stores it in the destination at:
 *  - row:    get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH
 *  - column: (get_global_id(1) * MULT_TRANSPOSE1XW_WIDTH + get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W elements
 *  - batch:  get_global_id(2)
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (i.e. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The size in bytes of a single element must be passed at compile time using -DELEMENT_SIZE (i.e. -DELEMENT_SIZE=4). Supported values: 1, 2 and 4
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for Matrix B - source
    // NOTE(review): the per-work-item source position comes from the helpers.h macro (not visible here)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix B transposed - destination. X and Y are swapped:
    // the work-item's Y index selects the column block, its X index selects the row
    // (x / MULT_TRANSPOSE1XW_WIDTH) and the slot inside the block (x % MULT_TRANSPOSE1XW_WIDTH)
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
                             (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    // Copy one vector of TRANSPOSE_W elements from the source to the computed destination slot
    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010084
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
 *
 * Each work-item loads four vectors of 4 elements from four consecutive source rows, and stores the four
 * columns of that 4x4 tile as vectors spaced (4 * MULT_INTERLEAVE4X4_HEIGHT) elements apart in the destination.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *          NOTE(review): the kernel derives the plane from the row index (y * 4 + i) / HEIGHT_GEMM3D, so this product
 *          looks like the number of rows rather than columns - confirm against the host-side configuration
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                 ,
                                 uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                )
{
    // Work-item coordinates used to compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for the interleaved destination:
    //  - every X index owns 16 * MULT_INTERLEAVE4X4_HEIGHT elements of a destination row
    //  - MULT_INTERLEAVE4X4_HEIGHT consecutive Y indices share one destination row (y / MULT_INTERLEAVE4X4_HEIGHT)
    //    and are placed 4 elements apart inside it (y % MULT_INTERLEAVE4X4_HEIGHT)
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

#if defined(REINTERPRET_INPUT_AS_3D)
    // The source address is computed by hand here: 4 elements along X and 4 rows along Y per work-item
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;

    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D, one value per row of the 4x4 tile,
    // and clamped to the last plane
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (cross_plane_pad * src_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * src_stride_z * DEPTH_GEMM3D;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Compute address for source tensor. The struct is only read on this path, so it is built here
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    __global uchar *input_ptr = src.ptr;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Base of this work-item's output; the four tile columns are written relative to it
    __global DATA_TYPE *dst_base = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

    // Column 0 of the 4x4 tile
    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, (dst_base + 0 * MULT_INTERLEAVE4X4_HEIGHT));

    // Column 1 of the 4x4 tile
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, (dst_base + 4 * MULT_INTERLEAVE4X4_HEIGHT));

    // Column 2 of the 4x4 tile
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, (dst_base + 8 * MULT_INTERLEAVE4X4_HEIGHT));

    // Column 3 of the 4x4 tile
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, (dst_base + 12 * MULT_INTERLEAVE4X4_HEIGHT));
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100204
Gian Marco36a0a462018-01-12 10:21:40 +0000205#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100206/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100207 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100208 *
Gian Marco19835e52018-01-30 13:35:54 +0000209 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
210 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
211 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000212 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
213 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100214 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000215 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
216 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
217 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
218 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
219 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
220 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100221 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
222 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
223 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
224 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
225 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
226 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100227 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100228 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
229 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
230 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
231 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
232 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100233 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100234 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000235 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100236 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000237 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100238 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000239 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
240 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
241 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100242 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100243 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100244__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
245 IMAGE_DECLARATION(src1),
246 IMAGE_DECLARATION(dst),
247 uint src0_stride_z,
248 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000249 uint dst_stride_z
250#if defined(REINTERPRET_OUTPUT_AS_3D)
251 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100252 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000253#endif // REINTERPRET_OUTPUT_AS_3D
254 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100255{
Gian Marco36a0a462018-01-12 10:21:40 +0000256 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
257 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000258 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100259
Gian Marco36a0a462018-01-12 10:21:40 +0000260 // Offset
261 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
262 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100263
Gian Marco36a0a462018-01-12 10:21:40 +0000264 // src_addr_a = address of matrix A
265 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000266 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
267 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
268
269#if defined(MATRIX_B_DEPTH)
270 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
271 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
272#else // defined(MATRIX_B_DEPTH)
273 src1_addr_in_bytes += z * src1_stride_z;
274#endif // defined(MATRIX_B_DEPTH)
275
276 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
277 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100278
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000279 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000280 __global float *src_end_addr_b = src_addr_b + COLS_B;
281
282 src_addr_a += offset_row_a;
283 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100284
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000285 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100286 float4 c00 = 0.0f;
287 float4 c10 = 0.0f;
288 float4 c20 = 0.0f;
289 float4 c30 = 0.0f;
290
Gian Marco36a0a462018-01-12 10:21:40 +0000291 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100292 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000293 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000294 float4 a0 = vload4(0, src_addr_a);
295 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100296
297 c00 += (float4)a0.s0 * b0;
298 c10 += (float4)a0.s1 * b0;
299 c20 += (float4)a0.s2 * b0;
300 c30 += (float4)a0.s3 * b0;
301
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000302 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000303 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
304 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100305
306 c00 += (float4)a0.s0 * b0;
307 c10 += (float4)a0.s1 * b0;
308 c20 += (float4)a0.s2 * b0;
309 c30 += (float4)a0.s3 * b0;
310 }
311
Gian Marco36a0a462018-01-12 10:21:40 +0000312 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100313 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000314 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000315 float4 a0 = vload4(0, src_addr_a);
316 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100317
318 c00 += (float4)a0.s0 * b0;
319 c10 += (float4)a0.s1 * b0;
320 c20 += (float4)a0.s2 * b0;
321 c30 += (float4)a0.s3 * b0;
322 }
323
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000324 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100325 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
326
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000327#if defined(ALPHA)
328 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100329 c00 = c00 * (float4)ALPHA;
330 c10 = c10 * (float4)ALPHA;
331 c20 = c20 * (float4)ALPHA;
332 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000333#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100334
Gian Marcoae2af742018-02-15 12:35:44 +0000335 // Compute dst address
336 __global uchar *dst_addr = offset(&dst, 0, 0);
337
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000338#if defined(REINTERPRET_OUTPUT_AS_3D)
339 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100340 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000341 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100342 // | |
343 // | plane0 |
344 // | |
345 // |__________________|
346 // |******************|
347 // | cross_plane_pad |
348 // |******************|
349 // | |
350 // | plane1 |
351 // | |
352 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000353
354 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
355 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
356 zout = min(DEPTH_GEMM3D - 1, zout);
357
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100358 // Add offset due to the cross plane paddings
359 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000360
361 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
362 // multiply dst_stride_z by DEPTH_GEMM3D
363 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
364
365 // Store 4x4 block
366 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
367 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
368 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
369 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
370
371#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000372 // Add offset for batched GEMM
373 dst_addr += z * dst_stride_z;
374
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000375 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000376 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
377 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
378 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
379 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000380#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100381}
382
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000383/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100385 *
Gian Marco19835e52018-01-30 13:35:54 +0000386 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
387 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
390 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
391 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100392 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000393 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
394 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
395 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
396 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
397 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
398 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100399 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
400 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
401 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
402 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
403 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
404 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100405 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100406 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
407 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
408 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
409 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
410 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100411 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100412 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000413 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100414 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000415 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100416 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000417 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
418 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
419 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100420 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100421 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100422__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
423 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000424 IMAGE_DECLARATION(dst),
425 uint src0_stride_z,
426 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000427 uint dst_stride_z
428#if defined(REINTERPRET_OUTPUT_AS_3D)
429 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100430 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000431#endif // REINTERPRET_OUTPUT_AS_3D
432 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100433{
Gian Marco36a0a462018-01-12 10:21:40 +0000434 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
435 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000436 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +0000437
438 // Offset
439 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
440 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
441
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100442 // src_addr_a = address of matrix A
443 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000444 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
445 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
446
447#if defined(MATRIX_B_DEPTH)
448 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
449 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
450#else // defined(MATRIX_B_DEPTH)
451 src1_addr_in_bytes += z * src1_stride_z;
452#endif // defined(MATRIX_B_DEPTH)
453
454 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
455 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100456
Gian Marco36a0a462018-01-12 10:21:40 +0000457 src_addr_a += offset_row_a;
458 src_addr_b += offset_row_b;
459
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100460 // Reset accumulators
461 float c00 = 0.0f;
462 float c01 = 0.0f;
463 float c02 = 0.0f;
464 float c03 = 0.0f;
465 float c10 = 0.0f;
466 float c11 = 0.0f;
467 float c12 = 0.0f;
468 float c13 = 0.0f;
469 float c20 = 0.0f;
470 float c21 = 0.0f;
471 float c22 = 0.0f;
472 float c23 = 0.0f;
473 float c30 = 0.0f;
474 float c31 = 0.0f;
475 float c32 = 0.0f;
476 float c33 = 0.0f;
477
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100478#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
479
480 int i = 0;
481 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100482 {
483 // Load values from matrix A (interleaved) and matrix B (transposed)
484 float4 a0 = vload4(0, src_addr_a);
485 float4 b0 = vload4(0, src_addr_b);
486
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100487 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
488 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100489
490 c00 = fma(a0.s0, b0.s0, c00);
491 c01 = fma(a0.s0, b0.s1, c01);
492 c02 = fma(a0.s0, b0.s2, c02);
493 c03 = fma(a0.s0, b0.s3, c03);
494
495 c10 = fma(a0.s1, b0.s0, c10);
496 c11 = fma(a0.s1, b0.s1, c11);
497 c12 = fma(a0.s1, b0.s2, c12);
498 c13 = fma(a0.s1, b0.s3, c13);
499
500 c20 = fma(a0.s2, b0.s0, c20);
501 c21 = fma(a0.s2, b0.s1, c21);
502 c22 = fma(a0.s2, b0.s2, c22);
503 c23 = fma(a0.s2, b0.s3, c23);
504
505 c30 = fma(a0.s3, b0.s0, c30);
506 c31 = fma(a0.s3, b0.s1, c31);
507 c32 = fma(a0.s3, b0.s2, c32);
508 c33 = fma(a0.s3, b0.s3, c33);
509
510 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100511 a0 = vload4(0, src_addr_a);
512 b0 = vload4(0, src_addr_b);
513
514 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
515 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100516
517 c00 = fma(a0.s0, b0.s0, c00);
518 c01 = fma(a0.s0, b0.s1, c01);
519 c02 = fma(a0.s0, b0.s2, c02);
520 c03 = fma(a0.s0, b0.s3, c03);
521
522 c10 = fma(a0.s1, b0.s0, c10);
523 c11 = fma(a0.s1, b0.s1, c11);
524 c12 = fma(a0.s1, b0.s2, c12);
525 c13 = fma(a0.s1, b0.s3, c13);
526
527 c20 = fma(a0.s2, b0.s0, c20);
528 c21 = fma(a0.s2, b0.s1, c21);
529 c22 = fma(a0.s2, b0.s2, c22);
530 c23 = fma(a0.s2, b0.s3, c23);
531
532 c30 = fma(a0.s3, b0.s0, c30);
533 c31 = fma(a0.s3, b0.s1, c31);
534 c32 = fma(a0.s3, b0.s2, c32);
535 c33 = fma(a0.s3, b0.s3, c33);
536
537 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100538 a0 = vload4(0, src_addr_a);
539 b0 = vload4(0, src_addr_b);
540
541 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
542 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
543
544 c00 = fma(a0.s0, b0.s0, c00);
545 c01 = fma(a0.s0, b0.s1, c01);
546 c02 = fma(a0.s0, b0.s2, c02);
547 c03 = fma(a0.s0, b0.s3, c03);
548
549 c10 = fma(a0.s1, b0.s0, c10);
550 c11 = fma(a0.s1, b0.s1, c11);
551 c12 = fma(a0.s1, b0.s2, c12);
552 c13 = fma(a0.s1, b0.s3, c13);
553
554 c20 = fma(a0.s2, b0.s0, c20);
555 c21 = fma(a0.s2, b0.s1, c21);
556 c22 = fma(a0.s2, b0.s2, c22);
557 c23 = fma(a0.s2, b0.s3, c23);
558
559 c30 = fma(a0.s3, b0.s0, c30);
560 c31 = fma(a0.s3, b0.s1, c31);
561 c32 = fma(a0.s3, b0.s2, c32);
562 c33 = fma(a0.s3, b0.s3, c33);
563
564 // Load values from matrix A (interleaved) and matrix B (transposed)
565 a0 = vload4(0, src_addr_a);
566 b0 = vload4(0, src_addr_b);
567
568 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
569 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100570
571 c00 = fma(a0.s0, b0.s0, c00);
572 c01 = fma(a0.s0, b0.s1, c01);
573 c02 = fma(a0.s0, b0.s2, c02);
574 c03 = fma(a0.s0, b0.s3, c03);
575
576 c10 = fma(a0.s1, b0.s0, c10);
577 c11 = fma(a0.s1, b0.s1, c11);
578 c12 = fma(a0.s1, b0.s2, c12);
579 c13 = fma(a0.s1, b0.s3, c13);
580
581 c20 = fma(a0.s2, b0.s0, c20);
582 c21 = fma(a0.s2, b0.s1, c21);
583 c22 = fma(a0.s2, b0.s2, c22);
584 c23 = fma(a0.s2, b0.s3, c23);
585
586 c30 = fma(a0.s3, b0.s0, c30);
587 c31 = fma(a0.s3, b0.s1, c31);
588 c32 = fma(a0.s3, b0.s2, c32);
589 c33 = fma(a0.s3, b0.s3, c33);
590 }
591
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100592 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100593 {
594 // Load values from matrix A (interleaved) and matrix B (transposed)
595 float4 a0 = vload4(0, src_addr_a);
596 float4 b0 = vload4(0, src_addr_b);
597
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100598 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
599 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
600
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100601 c00 = fma(a0.s0, b0.s0, c00);
602 c01 = fma(a0.s0, b0.s1, c01);
603 c02 = fma(a0.s0, b0.s2, c02);
604 c03 = fma(a0.s0, b0.s3, c03);
605
606 c10 = fma(a0.s1, b0.s0, c10);
607 c11 = fma(a0.s1, b0.s1, c11);
608 c12 = fma(a0.s1, b0.s2, c12);
609 c13 = fma(a0.s1, b0.s3, c13);
610
611 c20 = fma(a0.s2, b0.s0, c20);
612 c21 = fma(a0.s2, b0.s1, c21);
613 c22 = fma(a0.s2, b0.s2, c22);
614 c23 = fma(a0.s2, b0.s3, c23);
615
616 c30 = fma(a0.s3, b0.s0, c30);
617 c31 = fma(a0.s3, b0.s1, c31);
618 c32 = fma(a0.s3, b0.s2, c32);
619 c33 = fma(a0.s3, b0.s3, c33);
620 }
621
622 // Compute destination address
623 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
624
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000625#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100626 // Multiply by the weight of matrix product
627 c00 = c00 * ALPHA;
628 c01 = c01 * ALPHA;
629 c02 = c02 * ALPHA;
630 c03 = c03 * ALPHA;
631 c10 = c10 * ALPHA;
632 c11 = c11 * ALPHA;
633 c12 = c12 * ALPHA;
634 c13 = c13 * ALPHA;
635 c20 = c20 * ALPHA;
636 c21 = c21 * ALPHA;
637 c22 = c22 * ALPHA;
638 c23 = c23 * ALPHA;
639 c30 = c30 * ALPHA;
640 c31 = c31 * ALPHA;
641 c32 = c32 * ALPHA;
642 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000643#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100644
Gian Marcoae2af742018-02-15 12:35:44 +0000645 // Compute dst address
646 __global uchar *dst_addr = offset(&dst, 0, 0);
647
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000648#if defined(REINTERPRET_OUTPUT_AS_3D)
649 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100650 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000651 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100652 // | |
653 // | plane0 |
654 // | |
655 // |__________________|
656 // |******************|
657 // | cross_plane_pad |
658 // |******************|
659 // | |
660 // | plane1 |
661 // | |
662 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000663
664 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
665 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
666 zout = min(DEPTH_GEMM3D - 1, zout);
667
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100668 // Add offset due to the cross plane paddings
669 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000670
671 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
672 // multiply dst_stride_z by DEPTH_GEMM3D
673 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
674
675 // Store 4x4 block
676 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
677 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
678 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
679 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
680
681#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000682 // Add offset for batched GEMM
683 dst_addr += z * dst_stride_z;
684
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100685 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000686 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
687 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
688 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
689 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000690#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100691}
692
Georgios Pinitas84225582018-05-14 12:00:05 +0100693// Undefine local defines
694#undef COLS_MTX_B
695
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100696#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100697/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100698 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100699 *
Gian Marco19835e52018-01-30 13:35:54 +0000700 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
701 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
702 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000703 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
704 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100705 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000706 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
707 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
708 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
709 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
710 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
711 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100712 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
713 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
714 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
715 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
716 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
717 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100718 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100719 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
720 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
721 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
722 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
723 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100724 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100725 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000726 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100727 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000728 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100729 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000730 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
731 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
732 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100733 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100734 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100735__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
736 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000737 IMAGE_DECLARATION(dst),
738 uint src0_stride_z,
739 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000740 uint dst_stride_z
741#if defined(REINTERPRET_OUTPUT_AS_3D)
742 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100743 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000744#endif // REINTERPRET_OUTPUT_AS_3D
745 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100746{
Gian Marco36a0a462018-01-12 10:21:40 +0000747 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
748 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000749 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100750
Gian Marco36a0a462018-01-12 10:21:40 +0000751 // Offset
752 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
753 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100754
Gian Marco36a0a462018-01-12 10:21:40 +0000755 // src_addr_a = address of matrix A
756 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000757 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
758 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
759
760#if defined(MATRIX_B_DEPTH)
761 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
762 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
763#else // defined(MATRIX_B_DEPTH)
764 src1_addr_in_bytes += z * src1_stride_z;
765#endif // defined(MATRIX_B_DEPTH)
766
767 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
768 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100769
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000770 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000771 __global half *src_end_addr_b = src_addr_b + COLS_B;
772
773 src_addr_a += offset_row_a;
774 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100775
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000776 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100777 half8 c00 = 0.0f;
778 half8 c10 = 0.0f;
779 half8 c20 = 0.0f;
780 half8 c30 = 0.0f;
781
Gian Marco36a0a462018-01-12 10:21:40 +0000782 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100783 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000784 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000785 half4 a0 = vload4(0, src_addr_a);
786 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100787
788 c00 += (half8)a0.s0 * b0;
789 c10 += (half8)a0.s1 * b0;
790 c20 += (half8)a0.s2 * b0;
791 c30 += (half8)a0.s3 * b0;
792
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000793 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000794 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
795 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100796
797 c00 += (half8)a0.s0 * b0;
798 c10 += (half8)a0.s1 * b0;
799 c20 += (half8)a0.s2 * b0;
800 c30 += (half8)a0.s3 * b0;
801 }
802
Gian Marco36a0a462018-01-12 10:21:40 +0000803 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100804 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000805 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000806 half4 a0 = vload4(0, src_addr_a);
807 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100808
809 c00 += (half8)a0.s0 * b0;
810 c10 += (half8)a0.s1 * b0;
811 c20 += (half8)a0.s2 * b0;
812 c30 += (half8)a0.s3 * b0;
813 }
814
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000815 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100816 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
817
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000818#if defined(ALPHA)
819 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100820 c00 = c00 * (half8)ALPHA;
821 c10 = c10 * (half8)ALPHA;
822 c20 = c20 * (half8)ALPHA;
823 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000824#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100825
Gian Marcoae2af742018-02-15 12:35:44 +0000826 // Compute dst address
827 __global uchar *dst_addr = offset(&dst, 0, 0);
828
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000829#if defined(REINTERPRET_OUTPUT_AS_3D)
830 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100831 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000832 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100833 // | |
834 // | plane0 |
835 // | |
836 // |__________________|
837 // |******************|
838 // | cross_plane_pad |
839 // |******************|
840 // | |
841 // | plane1 |
842 // | |
843 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000844
845 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
846 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
847 zout = min(DEPTH_GEMM3D - 1, zout);
848
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100849 // Add offset due to the cross plane paddings
850 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000851
852 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
853 // multiply dst_stride_z by DEPTH_GEMM3D
854 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
855
856 // Store 4x8 block
857 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
858 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
859 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
860 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
861
862#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000863 // Add offset for batched GEMM
864 dst_addr += z * dst_stride_z;
865
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000866 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +0000867 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
868 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
869 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
870 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000871#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100872}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100873
874/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
875 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
876 *
877 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
878 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
879 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
880 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
881 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
882 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000883 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
884 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
885 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
886 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
887 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
888 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100889 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
890 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
891 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
892 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
893 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
894 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
895 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
896 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
897 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
898 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
899 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
900 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
901 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
902 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
903 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
904 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
905 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
906 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
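 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)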
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100908 */
909__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
910 IMAGE_DECLARATION(src1),
911 IMAGE_DECLARATION(dst),
912 uint src0_stride_z,
913 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000914 uint dst_stride_z
915#if defined(REINTERPRET_OUTPUT_AS_3D)
916 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100917 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000918#endif // REINTERPRET_OUTPUT_AS_3D
919 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100920{
921 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
922 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
923 int z = get_global_id(2);
924
925 // Offset
926 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
927 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
928
929 // src_addr_a = address of matrix A
930 // src_addr_b = address of matrix B
931 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
932 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
933
934#if defined(MATRIX_B_DEPTH)
935 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
936 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
937#else // defined(MATRIX_B_DEPTH)
938 src1_addr_in_bytes += z * src1_stride_z;
939#endif // defined(MATRIX_B_DEPTH)
940
941 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
942 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
943
944 // Compute end row address for matrix B
945 __global half *src_end_addr_b = src_addr_b + COLS_B;
946
947 src_addr_a += offset_row_a;
948 src_addr_b += offset_row_b;
949
950 // Reset accumulators
951 half8 c00 = 0.0f;
952 half8 c10 = 0.0f;
953 half8 c20 = 0.0f;
954 half8 c30 = 0.0f;
955
956#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
957
958 int i = 0;
959 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
960 {
961#if MULT_INTERLEAVE4X4_HEIGHT == 1
962 // Load values from matrix A (interleaved) and matrix B (transposed)
963 half8 a0 = vload8(0, src_addr_a);
964 half8 b0 = vload8(0, src_addr_b);
965
966 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
967 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
968
969 c00 = fma((half8)a0.s0, b0, c00);
970 c10 = fma((half8)a0.s1, b0, c10);
971 c20 = fma((half8)a0.s2, b0, c20);
972 c30 = fma((half8)a0.s3, b0, c30);
973
974 // Load values from matrix B (transposed)
975 b0 = vload8(0, src_addr_b);
976
977 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
978
979 c00 = fma((half8)a0.s4, b0, c00);
980 c10 = fma((half8)a0.s5, b0, c10);
981 c20 = fma((half8)a0.s6, b0, c20);
982 c30 = fma((half8)a0.s7, b0, c30);
983
984 // Load values from matrix A (interleaved) and matrix B (transposed)
985 a0 = vload8(0, src_addr_a);
986 b0 = vload8(0, src_addr_b);
987
988 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
989 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
990
991 c00 = fma((half8)a0.s0, b0, c00);
992 c10 = fma((half8)a0.s1, b0, c10);
993 c20 = fma((half8)a0.s2, b0, c20);
994 c30 = fma((half8)a0.s3, b0, c30);
995
996 // Load values from matrix B (transposed)
997 b0 = vload8(0, src_addr_b);
998
999 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1000
1001 c00 = fma((half8)a0.s4, b0, c00);
1002 c10 = fma((half8)a0.s5, b0, c10);
1003 c20 = fma((half8)a0.s6, b0, c20);
1004 c30 = fma((half8)a0.s7, b0, c30);
1005#else // MULT_INTERLEAVE4X4_HEIGHT == 1
1006 // Load values from matrix A (interleaved) and matrix B (transposed)
1007 half4 a0 = vload4(0, src_addr_a);
1008 half8 b0 = vload8(0, src_addr_b);
1009
1010 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1011 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1012
1013 c00 = fma((half8)a0.s0, b0, c00);
1014 c10 = fma((half8)a0.s1, b0, c10);
1015 c20 = fma((half8)a0.s2, b0, c20);
1016 c30 = fma((half8)a0.s3, b0, c30);
1017
1018 // Load values from matrix A (interleaved) and matrix B (transposed)
1019 a0 = vload4(0, src_addr_a);
1020 b0 = vload8(0, src_addr_b);
1021
1022 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1023 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1024
1025 c00 = fma((half8)a0.s0, b0, c00);
1026 c10 = fma((half8)a0.s1, b0, c10);
1027 c20 = fma((half8)a0.s2, b0, c20);
1028 c30 = fma((half8)a0.s3, b0, c30);
1029
1030 // Load values from matrix A (interleaved) and matrix B (transposed)
1031 a0 = vload4(0, src_addr_a);
1032 b0 = vload8(0, src_addr_b);
1033
1034 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1035 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1036
1037 c00 = fma((half8)a0.s0, b0, c00);
1038 c10 = fma((half8)a0.s1, b0, c10);
1039 c20 = fma((half8)a0.s2, b0, c20);
1040 c30 = fma((half8)a0.s3, b0, c30);
1041
1042 // Load values from matrix A (interleaved) and matrix B (transposed)
1043 a0 = vload4(0, src_addr_a);
1044 b0 = vload8(0, src_addr_b);
1045
1046 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1047 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1048
1049 c00 = fma((half8)a0.s0, b0, c00);
1050 c10 = fma((half8)a0.s1, b0, c10);
1051 c20 = fma((half8)a0.s2, b0, c20);
1052 c30 = fma((half8)a0.s3, b0, c30);
1053#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
1054 }
1055
1056 for(; i < (int)(COLS_MTX_B); ++i)
1057 {
1058 // Load values from matrix A (interleaved) and matrix B (transposed)
1059 half4 a0 = vload4(0, src_addr_a);
1060 half8 b0 = vload8(0, src_addr_b);
1061
1062 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1063 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1064
1065 c00 = fma((half8)a0.s0, b0, c00);
1066 c10 = fma((half8)a0.s1, b0, c10);
1067 c20 = fma((half8)a0.s2, b0, c20);
1068 c30 = fma((half8)a0.s3, b0, c30);
1069 }
1070
1071 // Compute destination address
1072 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1073
1074#if defined(ALPHA)
1075 // Multiply by the weight of matrix product
1076 c00 = c00 * (half8)ALPHA;
1077 c10 = c10 * (half8)ALPHA;
1078 c20 = c20 * (half8)ALPHA;
1079 c30 = c30 * (half8)ALPHA;
1080#endif // defined(ALPHA)
1081
1082 // Compute dst address
1083 __global uchar *dst_addr = offset(&dst, 0, 0);
1084
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001085#if defined(REINTERPRET_OUTPUT_AS_3D)
1086 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001087 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001088 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001089 // | |
1090 // | plane0 |
1091 // | |
1092 // |__________________|
1093 // |******************|
1094 // | cross_plane_pad |
1095 // |******************|
1096 // | |
1097 // | plane1 |
1098 // | |
1099 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001100
1101 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
1102 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
1103 zout = min(DEPTH_GEMM3D - 1, zout);
1104
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001105 // Add offset due to the cross plane paddings
1106 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001107
1108 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1109 // multiply dst_stride_z by DEPTH_GEMM3D
1110 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1111
1112 // Store 4x8 block
1113 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
1114 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
1115 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
1116 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
1117
1118#else // defined(REINTERPRET_OUTPUT_AS_3D)
1119 // Add offset for batched GEMM
1120 dst_addr += z * dst_stride_z;
1121
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001122 // Store 4x8 block
1123 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
1124 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
1125 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
1126 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001127#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001128}
Georgios Pinitas84225582018-05-14 12:00:05 +01001129
1130// Undefine local defines
1131#undef COLS_MTX_B
1132
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01001133#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001134
Gian Marco36a0a462018-01-12 10:21:40 +00001135#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001136
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001137#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Each work-item accumulates a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows by NUM_ELEMS_PROCESSED_PER_THREAD_X columns of the output.
 * The reduction over the COLS_A elements of a row of A is done two elements at a time, followed by a one-element loop for the leftover.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note Only rows 0..3 of the tile have an accumulator (acc0..acc3) and a plane offset (uint4), so values of NUM_ELEMS_PROCESSED_PER_THREAD_Y greater than 4 are not handled by this kernel
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          The plane index is obtained by dividing the row index by HEIGHT_GEMM3D and is clamped to DEPTH_GEMM3D - 1.
 *          NOTE(review): the previous wording was "(HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped"; since the kernel
 *          derives the plane from the row index, the product presumably matches the number of rows - confirm against the host code.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D).
 *                                                The kernel multiplies it by src0_stride_y to obtain the byte offset added per plane
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D).
 *                                                The kernel multiplies it by dst_stride_y to obtain the byte offset added per plane
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                     )
{
    // First output column handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (s0 = byte offset into src0_ptr, s1 = byte offset into src1_ptr)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A: jump to the first row of this work-item's tile
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B: jump to the first column of this work-item's tile
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is widened explicitly so that the call matches the min(gentype, gentype) overload
    zin = min((uint4)(DEPTH_GEMM3D - 1), zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator vector per output row of the tile
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume two elements of each row of A (and two rows of B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A (each row carries its own cross plane offset)
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (two consecutive rows)
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Leftover loop: consume the remaining elements of the row of A one at a time
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A (each row carries its own cross plane offset)
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(ALPHA)
    // Multiply by the weight of matrix-matrix product
    acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    // Batch index
    int z = get_global_id(2);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is widened explicitly so that the call matches the min(gentype, gentype) overload
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store output block
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store output block
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001449
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01001450/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001451 *
1452 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1453 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1454 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1455 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1456 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001457 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1458 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001459 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001460 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1461 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001462 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1463 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1464 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1465 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1466 *
 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
1468 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1469 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1470 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1471 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1472 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1473 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1474 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1475 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1476 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1477 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1478 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1479 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1480 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1481 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1482 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1483 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1484 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001485 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1486 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1487 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001488 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1489 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001490 */
1491__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1492 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001493 IMAGE_DECLARATION(dst),
1494 uint src0_stride_z,
1495 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001496 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001497#if defined(REINTERPRET_INPUT_AS_3D)
1498 ,
1499 uint src_cross_plane_pad
1500#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001501#if defined(REINTERPRET_OUTPUT_AS_3D)
1502 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001503 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001504#endif // REINTERPRET_OUTPUT_AS_3D
1505 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001506{
1507 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1508
1509 // Compute starting address for matrix A and matrix B
1510 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1511
1512 // Update address for matrix A
1513 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1514
1515 // Update address for matrix B
1516 src_addr.s1 += idx * sizeof(float);
1517
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001518#if defined(REINTERPRET_INPUT_AS_3D)
1519 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1520 // in order to take into account the presence of possible cross plane paddings
1521 //
1522 // | |
1523 // | plane0 |
1524 // | |
1525 // |__________________|
1526 // |******************|
1527 // | cross_plane_pad |
1528 // |******************|
1529 // | |
1530 // | plane1 |
1531 // | |
1532 // |__________________|
1533
1534 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1535 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1536 zin = min(DEPTH_GEMM3D - 1, zin);
1537
1538 // Add offset due to the cross plane paddings
1539 zin *= (src_cross_plane_pad * src0_stride_y);
1540
1541 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1542 // multiply src0_stride_z by DEPTH_GEMM3D
1543 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1544
1545#else // defined(REINTERPRET_INPUT_AS_3D)
1546
Gian Marcoae2af742018-02-15 12:35:44 +00001547 // Add offset for batched GEMM
1548 src_addr.s0 += get_global_id(2) * src0_stride_z;
1549
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001550#endif // defined(REINTERPRET_INPUT_AS_3D)
1551
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001552#if defined(MATRIX_B_DEPTH)
1553 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1554 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1555#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001556 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001557#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001558
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001559 // Initialize accumulators
1560 float acc00 = 0.0f;
1561 float acc01 = 0.0f;
1562 float acc02 = 0.0f;
1563 float acc03 = 0.0f;
1564
1565#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1566 float acc10 = 0.0f;
1567 float acc11 = 0.0f;
1568 float acc12 = 0.0f;
1569 float acc13 = 0.0f;
1570#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1571
1572#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1573 float acc20 = 0.0f;
1574 float acc21 = 0.0f;
1575 float acc22 = 0.0f;
1576 float acc23 = 0.0f;
1577#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1578
1579#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1580 float acc30 = 0.0f;
1581 float acc31 = 0.0f;
1582 float acc32 = 0.0f;
1583 float acc33 = 0.0f;
1584#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1585
1586 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001587 int i = 0;
1588 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001589 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001590#if defined(REINTERPRET_INPUT_AS_3D)
1591 // Load values from matrix A and matrix B
1592 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1593#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1594 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1595#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1596#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1597 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1598#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1599#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1600 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1601#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1602#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001603 // Load values from matrix A and matrix B
1604 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001605#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001606 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001607#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1608#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001609 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001610#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1611#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001612 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001613#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001614#endif // defined(REINTERPRET_INPUT_AS_3D)
1615
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001616 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1617 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001618
1619 // Multiply and accumulate
1620 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001621 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001622 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001623 acc03 = fma(a0.s0, b0.s3, acc03);
1624
1625#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001626
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001627 acc10 = fma(a1.s0, b0.s0, acc10);
1628 acc11 = fma(a1.s0, b0.s1, acc11);
1629 acc12 = fma(a1.s0, b0.s2, acc12);
1630 acc13 = fma(a1.s0, b0.s3, acc13);
1631
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001632#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1633#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001634
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001635 acc20 = fma(a2.s0, b0.s0, acc20);
1636 acc21 = fma(a2.s0, b0.s1, acc21);
1637 acc22 = fma(a2.s0, b0.s2, acc22);
1638 acc23 = fma(a2.s0, b0.s3, acc23);
1639
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001640#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1641#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001642
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001643 acc30 = fma(a3.s0, b0.s0, acc30);
1644 acc31 = fma(a3.s0, b0.s1, acc31);
1645 acc32 = fma(a3.s0, b0.s2, acc32);
1646 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001647#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001648
1649 // Load values from matrix A and matrix B
1650 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1651 src_addr.s1 += src1_stride_y;
1652
1653 // Multiply and accumulate
1654 acc00 = fma(a0.s1, b0.s0, acc00);
1655 acc01 = fma(a0.s1, b0.s1, acc01);
1656 acc02 = fma(a0.s1, b0.s2, acc02);
1657 acc03 = fma(a0.s1, b0.s3, acc03);
1658
1659#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1660
1661 acc10 = fma(a1.s1, b0.s0, acc10);
1662 acc11 = fma(a1.s1, b0.s1, acc11);
1663 acc12 = fma(a1.s1, b0.s2, acc12);
1664 acc13 = fma(a1.s1, b0.s3, acc13);
1665
1666#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1667#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1668
1669 acc20 = fma(a2.s1, b0.s0, acc20);
1670 acc21 = fma(a2.s1, b0.s1, acc21);
1671 acc22 = fma(a2.s1, b0.s2, acc22);
1672 acc23 = fma(a2.s1, b0.s3, acc23);
1673
1674#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1675#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1676
1677 acc30 = fma(a3.s1, b0.s0, acc30);
1678 acc31 = fma(a3.s1, b0.s1, acc31);
1679 acc32 = fma(a3.s1, b0.s2, acc32);
1680 acc33 = fma(a3.s1, b0.s3, acc33);
1681#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1682
1683 // Load values from matrix A and matrix B
1684 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1685 src_addr.s1 += src1_stride_y;
1686
1687 // Multiply and accumulate
1688 acc00 = fma(a0.s2, b0.s0, acc00);
1689 acc01 = fma(a0.s2, b0.s1, acc01);
1690 acc02 = fma(a0.s2, b0.s2, acc02);
1691 acc03 = fma(a0.s2, b0.s3, acc03);
1692
1693#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1694
1695 acc10 = fma(a1.s2, b0.s0, acc10);
1696 acc11 = fma(a1.s2, b0.s1, acc11);
1697 acc12 = fma(a1.s2, b0.s2, acc12);
1698 acc13 = fma(a1.s2, b0.s3, acc13);
1699
1700#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1701#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1702
1703 acc20 = fma(a2.s2, b0.s0, acc20);
1704 acc21 = fma(a2.s2, b0.s1, acc21);
1705 acc22 = fma(a2.s2, b0.s2, acc22);
1706 acc23 = fma(a2.s2, b0.s3, acc23);
1707
1708#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1709#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1710
1711 acc30 = fma(a3.s2, b0.s0, acc30);
1712 acc31 = fma(a3.s2, b0.s1, acc31);
1713 acc32 = fma(a3.s2, b0.s2, acc32);
1714 acc33 = fma(a3.s2, b0.s3, acc33);
1715#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1716
1717 // Load values from matrix A and matrix B
1718 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1719 src_addr.s1 += src1_stride_y;
1720
1721 // Multiply and accumulate
1722 acc00 = fma(a0.s3, b0.s0, acc00);
1723 acc01 = fma(a0.s3, b0.s1, acc01);
1724 acc02 = fma(a0.s3, b0.s2, acc02);
1725 acc03 = fma(a0.s3, b0.s3, acc03);
1726
1727#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1728
1729 acc10 = fma(a1.s3, b0.s0, acc10);
1730 acc11 = fma(a1.s3, b0.s1, acc11);
1731 acc12 = fma(a1.s3, b0.s2, acc12);
1732 acc13 = fma(a1.s3, b0.s3, acc13);
1733
1734#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1735#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1736
1737 acc20 = fma(a2.s3, b0.s0, acc20);
1738 acc21 = fma(a2.s3, b0.s1, acc21);
1739 acc22 = fma(a2.s3, b0.s2, acc22);
1740 acc23 = fma(a2.s3, b0.s3, acc23);
1741
1742#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1743#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1744
1745 acc30 = fma(a3.s3, b0.s0, acc30);
1746 acc31 = fma(a3.s3, b0.s1, acc31);
1747 acc32 = fma(a3.s3, b0.s2, acc32);
1748 acc33 = fma(a3.s3, b0.s3, acc33);
1749#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1750
1751 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001752 }
1753
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001754 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001755 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001756#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001757 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001758 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1759#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1760 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1761#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1762#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1763 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1764#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1765#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1766 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1767#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1768#else // defined(REINTERPRET_INPUT_AS_3D)
1769 // Load values from matrix A
1770 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001771#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1772 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1773#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1774#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1775 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1776#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1777#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1778 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1779#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001780#endif // defined(REINTERPRET_INPUT_AS_3D)
1781
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001782 // Load values from matrix B
1783 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001784 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001785
1786 // Multiply and accumulate
1787 acc00 = fma(a0, b0.s0, acc00);
1788 acc01 = fma(a0, b0.s1, acc01);
1789 acc02 = fma(a0, b0.s2, acc02);
1790 acc03 = fma(a0, b0.s3, acc03);
1791#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1792 acc10 = fma(a1, b0.s0, acc10);
1793 acc11 = fma(a1, b0.s1, acc11);
1794 acc12 = fma(a1, b0.s2, acc12);
1795 acc13 = fma(a1, b0.s3, acc13);
1796#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1797#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1798 acc20 = fma(a2, b0.s0, acc20);
1799 acc21 = fma(a2, b0.s1, acc21);
1800 acc22 = fma(a2, b0.s2, acc22);
1801 acc23 = fma(a2, b0.s3, acc23);
1802#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1803#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1804 acc30 = fma(a3, b0.s0, acc30);
1805 acc31 = fma(a3, b0.s1, acc31);
1806 acc32 = fma(a3, b0.s2, acc32);
1807 acc33 = fma(a3, b0.s3, acc33);
1808#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001809
1810 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001811 }
1812
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001813 int z = get_global_id(2);
1814
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001815 // Compute destination address
1816 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1817
1818 // Multiply by the weight of matrix-matrix product and store the result
1819#if defined(ALPHA)
1820 acc00 = acc00 * ALPHA;
1821 acc01 = acc01 * ALPHA;
1822 acc02 = acc02 * ALPHA;
1823 acc03 = acc03 * ALPHA;
1824#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001825#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001826 acc10 = acc10 * ALPHA;
1827 acc11 = acc11 * ALPHA;
1828 acc12 = acc12 * ALPHA;
1829 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001830#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
1831#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001832 acc20 = acc20 * ALPHA;
1833 acc21 = acc21 * ALPHA;
1834 acc22 = acc22 * ALPHA;
1835 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001836#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
1837#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001838 acc30 = acc30 * ALPHA;
1839 acc31 = acc31 * ALPHA;
1840 acc32 = acc32 * ALPHA;
1841 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001842#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
1843
1844 // Compute dst address
1845 __global uchar *dst_addr = offset(&dst, 0, 0);
1846
1847#if defined(REINTERPRET_OUTPUT_AS_3D)
1848 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001849 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001850 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001851 // | |
1852 // | plane0 |
1853 // | |
1854 // |__________________|
1855 // |******************|
1856 // | cross_plane_pad |
1857 // |******************|
1858 // | |
1859 // | plane1 |
1860 // | |
1861 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001862
1863 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1864 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1865 zout = min(DEPTH_GEMM3D - 1, zout);
1866
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001867 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001868 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001869
1870 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1871 // multiply dst_stride_z by DEPTH_GEMM3D
1872 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1873
1874 // Store the output block
1875 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
1876#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1877 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
1878#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1879#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1880 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
1881#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1882#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1883 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001884#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001885
1886#else // defined(REINTERPRET_OUTPUT_AS_3D)
1887 // Add offset for batched GEMM
1888 dst_addr += z * dst_stride_z;
1889
1890 // Store the output block
1891 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
1892#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1893 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
1894#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1895#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1896 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
1897#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1898#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1899 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
1900#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1901#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001902}
1903
1904/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1905 *
1906 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1907 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
1908 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1909 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
1910 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1911 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001912 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
1913 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001914 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001915 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1916 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001917 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1918 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1919 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1920 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
1921 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001922 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
1923 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1924 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
1925 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1926 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
1927 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1928 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1929 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1930 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
1931 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1932 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
1933 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1934 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
1935 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1936 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
1937 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1938 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
1939 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001940 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1941 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1942 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001943 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1944 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001945 */
1946__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
1947 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001948 IMAGE_DECLARATION(dst),
1949 uint src0_stride_z,
1950 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001951 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001952#if defined(REINTERPRET_INPUT_AS_3D)
1953 ,
1954 uint src_cross_plane_pad
1955#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001956#if defined(REINTERPRET_OUTPUT_AS_3D)
1957 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001958 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001959#endif // REINTERPRET_OUTPUT_AS_3D
1960 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001961{
1962 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1963 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1964
1965 // Compute starting address for matrix A and Matrix B
1966 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1967
1968 // Update address for the matrix A
1969 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1970
1971 // Update address for the matrix B
1972 src_addr.s1 += idx * sizeof(float);
1973
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001974#if defined(REINTERPRET_INPUT_AS_3D)
1975 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1976 // in order to take into account the presence of possible cross plane paddings
1977 //
1978 // | |
1979 // | plane0 |
1980 // | |
1981 // |__________________|
1982 // |******************|
1983 // | cross_plane_pad |
1984 // |******************|
1985 // | |
1986 // | plane1 |
1987 // | |
1988 // |__________________|
1989
1990 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1991 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1992 zin = min(DEPTH_GEMM3D - 1, zin);
1993
1994 // Add offset due to the cross plane paddings
1995 zin *= (src_cross_plane_pad * src0_stride_y);
1996
1997 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1998 // multiply src0_stride_z by DEPTH_GEMM3D
1999 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2000
2001#else // defined(REINTERPRET_INPUT_AS_3D)
2002
Gian Marcoae2af742018-02-15 12:35:44 +00002003 // Add offset for batched GEMM
2004 src_addr.s0 += get_global_id(2) * src0_stride_z;
2005
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002006#endif // defined(REINTERPRET_INPUT_AS_3D)
2007
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002008#if defined(MATRIX_B_DEPTH)
2009 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2010 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2011#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002012 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002013#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002014
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002015 // Initialize accumulators
2016 float acc00 = 0.0f;
2017 float acc01 = 0.0f;
2018
2019#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2020 float acc10 = 0.0f;
2021 float acc11 = 0.0f;
2022#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2023#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2024 float acc20 = 0.0f;
2025 float acc21 = 0.0f;
2026#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2027#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2028 float acc30 = 0.0f;
2029 float acc31 = 0.0f;
2030#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2031
2032 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002033 int i = 0;
2034 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002035 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002036#if defined(REINTERPRET_INPUT_AS_3D)
2037 // Load values from matrix A
2038 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
2039#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002040 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002041 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002042#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002043
2044 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002045 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2046 src_addr.s1 += src1_stride_y;
2047 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2048 src_addr.s1 += src1_stride_y;
2049 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2050 src_addr.s1 += src1_stride_y;
2051 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2052 src_addr.s1 += src1_stride_y;
2053 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2054 src_addr.s1 += src1_stride_y;
2055 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2056 src_addr.s1 += src1_stride_y;
2057 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2058 src_addr.s1 += src1_stride_y;
2059 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2060 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002061
2062 // Multiply and accumulate
2063 acc00 = fma(a0.s0, b0.s0, acc00);
2064 acc00 = fma(a0.s1, b1.s0, acc00);
2065 acc00 = fma(a0.s2, b2.s0, acc00);
2066 acc00 = fma(a0.s3, b3.s0, acc00);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002067 acc00 = fma(a0.s4, b4.s0, acc00);
2068 acc00 = fma(a0.s5, b5.s0, acc00);
2069 acc00 = fma(a0.s6, b6.s0, acc00);
2070 acc00 = fma(a0.s7, b7.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002071
2072 acc01 = fma(a0.s0, b0.s1, acc01);
2073 acc01 = fma(a0.s1, b1.s1, acc01);
2074 acc01 = fma(a0.s2, b2.s1, acc01);
2075 acc01 = fma(a0.s3, b3.s1, acc01);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002076 acc01 = fma(a0.s4, b4.s1, acc01);
2077 acc01 = fma(a0.s5, b5.s1, acc01);
2078 acc01 = fma(a0.s6, b6.s1, acc01);
2079 acc01 = fma(a0.s7, b7.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002080
2081#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002082#if defined(REINTERPRET_INPUT_AS_3D)
2083 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2084#else // defined(REINTERPRET_INPUT_AS_3D)
2085 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2086#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002087 acc10 = fma(a0.s0, b0.s0, acc10);
2088 acc10 = fma(a0.s1, b1.s0, acc10);
2089 acc10 = fma(a0.s2, b2.s0, acc10);
2090 acc10 = fma(a0.s3, b3.s0, acc10);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002091 acc10 = fma(a0.s4, b4.s0, acc10);
2092 acc10 = fma(a0.s5, b5.s0, acc10);
2093 acc10 = fma(a0.s6, b6.s0, acc10);
2094 acc10 = fma(a0.s7, b7.s0, acc10);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002095
2096 acc11 = fma(a0.s0, b0.s1, acc11);
2097 acc11 = fma(a0.s1, b1.s1, acc11);
2098 acc11 = fma(a0.s2, b2.s1, acc11);
2099 acc11 = fma(a0.s3, b3.s1, acc11);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002100 acc11 = fma(a0.s4, b4.s1, acc11);
2101 acc11 = fma(a0.s5, b5.s1, acc11);
2102 acc11 = fma(a0.s6, b6.s1, acc11);
2103 acc11 = fma(a0.s7, b7.s1, acc11);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002104#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2105#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002106#if defined(REINTERPRET_INPUT_AS_3D)
2107 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2108#else // defined(REINTERPRET_INPUT_AS_3D)
2109 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2110#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002111 acc20 = fma(a0.s0, b0.s0, acc20);
2112 acc20 = fma(a0.s1, b1.s0, acc20);
2113 acc20 = fma(a0.s2, b2.s0, acc20);
2114 acc20 = fma(a0.s3, b3.s0, acc20);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002115 acc20 = fma(a0.s4, b4.s0, acc20);
2116 acc20 = fma(a0.s5, b5.s0, acc20);
2117 acc20 = fma(a0.s6, b6.s0, acc20);
2118 acc20 = fma(a0.s7, b7.s0, acc20);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002119
2120 acc21 = fma(a0.s0, b0.s1, acc21);
2121 acc21 = fma(a0.s1, b1.s1, acc21);
2122 acc21 = fma(a0.s2, b2.s1, acc21);
2123 acc21 = fma(a0.s3, b3.s1, acc21);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002124 acc21 = fma(a0.s4, b4.s1, acc21);
2125 acc21 = fma(a0.s5, b5.s1, acc21);
2126 acc21 = fma(a0.s6, b6.s1, acc21);
2127 acc21 = fma(a0.s7, b7.s1, acc21);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002128#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2129#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002130#if defined(REINTERPRET_INPUT_AS_3D)
2131 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2132#else // defined(REINTERPRET_INPUT_AS_3D)
2133 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2134#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002135 acc30 = fma(a0.s0, b0.s0, acc30);
2136 acc30 = fma(a0.s1, b1.s0, acc30);
2137 acc30 = fma(a0.s2, b2.s0, acc30);
2138 acc30 = fma(a0.s3, b3.s0, acc30);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002139 acc30 = fma(a0.s4, b4.s0, acc30);
2140 acc30 = fma(a0.s5, b5.s0, acc30);
2141 acc30 = fma(a0.s6, b6.s0, acc30);
2142 acc30 = fma(a0.s7, b7.s0, acc30);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002143
2144 acc31 = fma(a0.s0, b0.s1, acc31);
2145 acc31 = fma(a0.s1, b1.s1, acc31);
2146 acc31 = fma(a0.s2, b2.s1, acc31);
2147 acc31 = fma(a0.s3, b3.s1, acc31);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002148 acc31 = fma(a0.s4, b4.s1, acc31);
2149 acc31 = fma(a0.s5, b5.s1, acc31);
2150 acc31 = fma(a0.s6, b6.s1, acc31);
2151 acc31 = fma(a0.s7, b7.s1, acc31);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002152#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002153
2154 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002155 }
2156 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002157 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002158 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002159#if defined(REINTERPRET_INPUT_AS_3D)
2160 // Load values from matrix A
2161 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2162#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2163 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2164#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2165#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2166 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2167#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2168#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2169 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2170#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2171#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002172 // Load values from matrix A
2173 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2174#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2175 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2176#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2177#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2178 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2179#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2180#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2181 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2182#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002183#endif // defined(REINTERPRET_INPUT_AS_3D)
2184
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002185 // Load values from matrix B
2186 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002187 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002188
2189 // Multiply and accumulate
2190 acc00 = fma(a0, b0.s0, acc00);
2191 acc01 = fma(a0, b0.s1, acc01);
2192#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2193 acc10 = fma(a1, b0.s0, acc10);
2194 acc11 = fma(a1, b0.s1, acc11);
2195#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2196#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2197 acc20 = fma(a2, b0.s0, acc20);
2198 acc21 = fma(a2, b0.s1, acc21);
2199#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2200#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2201 acc30 = fma(a3, b0.s0, acc30);
2202 acc31 = fma(a3, b0.s1, acc31);
2203#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002204
2205 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002206 }
2207
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002208 // Multiply by the weight of matrix-matrix product and store the result
2209#if defined(ALPHA)
2210 acc00 = acc00 * ALPHA;
2211 acc01 = acc01 * ALPHA;
2212#endif // defined(ALPHA)
2213#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2214 acc10 = acc10 * ALPHA;
2215 acc11 = acc11 * ALPHA;
2216#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2217#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2218 acc20 = acc20 * ALPHA;
2219 acc21 = acc21 * ALPHA;
2220#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2221#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2222 acc30 = acc30 * ALPHA;
2223 acc31 = acc31 * ALPHA;
2224#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2225
2226 int z = get_global_id(2);
2227
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002228 // Compute destination address
2229 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2230
Gian Marcoae2af742018-02-15 12:35:44 +00002231 // Compute dst address
2232 __global uchar *dst_addr = offset(&dst, 0, 0);
2233
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002234#if defined(REINTERPRET_OUTPUT_AS_3D)
2235 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002236 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002237 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002238 // | |
2239 // | plane0 |
2240 // | |
2241 // |__________________|
2242 // |******************|
2243 // | cross_plane_pad |
2244 // |******************|
2245 // | |
2246 // | plane1 |
2247 // | |
2248 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00002249
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002250 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2251 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2252 zout = min(DEPTH_GEMM3D - 1, zout);
2253
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002254 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002255 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002256
2257 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2258 // multiply dst_stride_z by DEPTH_GEMM3D
2259 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2260
2261 // Store the output block
2262 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002263#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002264 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002265#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2266#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002267 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002268#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2269#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002270 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002271#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002272
2273#else // defined(REINTERPRET_OUTPUT_AS_3D)
2274 // Add offset for batched GEMM
2275 dst_addr += z * dst_stride_z;
2276
2277 // Store the output block
2278 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2279#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2280 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2281#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2282#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2283 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2284#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2285#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2286 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
2287#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2288#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002289}
2290
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01002291#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002292/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
2293 *
2294 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
2295 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
2296 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
2297 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
2298 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
2299 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2300 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2301 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002302 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2303 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002304 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2305 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2306 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2307 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2308 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002309 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2310 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2311 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2312 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2313 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2314 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2315 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2316 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2317 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2318 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2319 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2320 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2321 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2322 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2323 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2324 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2325 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2326 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002327 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2328 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2329 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002330 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2331 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002332 */
2333__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
2334 IMAGE_DECLARATION(src1),
2335 IMAGE_DECLARATION(dst),
2336 uint src0_stride_z,
2337 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002338 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002339#if defined(REINTERPRET_INPUT_AS_3D)
2340 ,
2341 uint src_cross_plane_pad
2342#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002343#if defined(REINTERPRET_OUTPUT_AS_3D)
2344 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002345 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002346#endif // REINTERPRET_OUTPUT_AS_3D
2347 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002348{
2349 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2350
2351 // Compute starting address for matrix A and Matrix B
2352 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2353
2354 // Update address for the matrix A
2355 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2356
2357 // Update address for the matrix B
2358 src_addr.s1 += idx * sizeof(half);
2359
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002360#if defined(REINTERPRET_INPUT_AS_3D)
2361 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2362 // in order to take into account the presence of possible cross plane paddings
2363 //
2364 // | |
2365 // | plane0 |
2366 // | |
2367 // |__________________|
2368 // |******************|
2369 // | cross_plane_pad |
2370 // |******************|
2371 // | |
2372 // | plane1 |
2373 // | |
2374 // |__________________|
2375
2376 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2377 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2378 zin = min(DEPTH_GEMM3D - 1, zin);
2379
2380 // Add offset due to the cross plane paddings
2381 zin *= (src_cross_plane_pad * src0_stride_y);
2382
2383 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2384 // multiply src0_stride_z by DEPTH_GEMM3D
2385 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2386
2387#else // defined(REINTERPRET_INPUT_AS_3D)
2388
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002389 // Add offset for batched GEMM
2390 src_addr.s0 += get_global_id(2) * src0_stride_z;
2391
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002392#endif // defined(REINTERPRET_INPUT_AS_3D)
2393
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002394#if defined(MATRIX_B_DEPTH)
2395 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2396 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2397#else // defined(MATRIX_B_DEPTH)
2398 src_addr.s1 += get_global_id(2) * src1_stride_z;
2399#endif // defined(MATRIX_B_DEPTH)
2400
2401 half8 acc0 = 0.0h;
2402#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2403 half8 acc1 = 0.0h;
2404#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2405#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2406 half8 acc2 = 0.0h;
2407#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2408#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2409 half8 acc3 = 0.0h;
2410#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2411
2412 int i = 0;
2413 for(; i <= ((int)COLS_A - 4); i += 4)
2414 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002415#if defined(REINTERPRET_INPUT_AS_3D)
2416 // Load values from matrix A
2417 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2418#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2419 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2420#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2421#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2422 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2423#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2424#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2425 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2426#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2427#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002428 // Load values from matrix A
2429 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2430#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2431 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2432#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2433#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2434 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2435#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2436#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2437 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2438#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002439#endif // defined(REINTERPRET_INPUT_AS_3D)
2440
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002441 // Load values from matrix B
2442 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2443 src_addr.s1 += src1_stride_y;
2444
2445 // Accumulate
2446 acc0 = fma(b0, (half8)a0.s0, acc0);
2447#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2448 acc1 = fma(b0, (half8)a1.s0, acc1);
2449#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2450#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2451 acc2 = fma(b0, (half8)a2.s0, acc2);
2452#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2453#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2454 acc3 = fma(b0, (half8)a3.s0, acc3);
2455#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2456
2457 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2458 src_addr.s1 += src1_stride_y;
2459 acc0 = fma(b0, (half8)a0.s1, acc0);
2460#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2461 acc1 = fma(b0, (half8)a1.s1, acc1);
2462#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2463#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2464 acc2 = fma(b0, (half8)a2.s1, acc2);
2465#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2466#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2467 acc3 = fma(b0, (half8)a3.s1, acc3);
2468#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2469
2470 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2471 src_addr.s1 += src1_stride_y;
2472 acc0 = fma(b0, (half8)a0.s2, acc0);
2473#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2474 acc1 = fma(b0, (half8)a1.s2, acc1);
2475#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2476#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2477 acc2 = fma(b0, (half8)a2.s2, acc2);
2478#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2479#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2480 acc3 = fma(b0, (half8)a3.s2, acc3);
2481#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2482
2483 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2484 src_addr.s1 += src1_stride_y;
2485 acc0 = fma(b0, (half8)a0.s3, acc0);
2486#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2487 acc1 = fma(b0, (half8)a1.s3, acc1);
2488#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2489#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2490 acc2 = fma(b0, (half8)a2.s3, acc2);
2491#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2492#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2493 acc3 = fma(b0, (half8)a3.s3, acc3);
2494#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2495
2496 src_addr.s0 += 4 * sizeof(half);
2497 }
2498
2499 for(; i < (int)COLS_A; ++i)
2500 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002501#if defined(REINTERPRET_INPUT_AS_3D)
2502 // Load values from matrix A
2503 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2504#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2505 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2506#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2507#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2508 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2509#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2510#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2511 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2512#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2513#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002514 // Load values from matrix A
2515 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2516#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2517 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2518#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2519#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2520 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2521#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2522#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2523 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2524#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002525#endif // defined(REINTERPRET_INPUT_AS_3D)
2526
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002527 // Load values from matrix B
2528 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2529
2530 src_addr += (int2)(sizeof(half), src1_stride_y);
2531
2532 // Accumulate
2533 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
2534#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2535 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
2536#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2537#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2538 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
2539#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2540#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2541 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
2542#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2543 }
2544
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002545 // Multiply by the weight of matrix-matrix product and store the result
2546#if defined(ALPHA)
2547 acc0 = acc0 * (half8)ALPHA;
2548#endif // defined(ALPHA)
2549#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2550 acc1 = acc1 * (half8)ALPHA;
2551#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2552#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2553 acc2 = acc2 * (half8)ALPHA;
2554#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2555#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2556 acc3 = acc3 * (half8)ALPHA;
2557#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2558
2559 int z = get_global_id(2);
2560
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002561 // Compute destination address
2562 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2563
2564 // Compute dst address
2565 __global uchar *dst_addr = offset(&dst, 0, 0);
2566
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002567#if defined(REINTERPRET_OUTPUT_AS_3D)
2568 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002569 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002570 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002571 // | |
2572 // | plane0 |
2573 // | |
2574 // |__________________|
2575 // |******************|
2576 // | cross_plane_pad |
2577 // |******************|
2578 // | |
2579 // | plane1 |
2580 // | |
2581 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002582
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002583 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2584 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2585 zout = min(DEPTH_GEMM3D - 1, zout);
2586
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002587 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002588 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002589
2590 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2591 // multiply dst_stride_z by DEPTH_GEMM3D
2592 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2593
2594 // Store the output block
2595 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
2596#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2597 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
2598#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2599#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2600 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
2601#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2602#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2603 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
2604#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2605
2606#else // defined(REINTERPRET_OUTPUT_AS_3D)
2607 // Add offset for batched GEMM
2608 dst_addr += z * dst_stride_z;
2609
2610 // Store the output block
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002611 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
2612#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002613 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
2614#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2615#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002616 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
2617#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2618#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002619 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
2620#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002621#endif // REINTERPRET_OUTPUT_AS_3D
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002622}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01002623#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002624
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002625#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002626
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002627#if defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002628/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
2629 *
Gian Marco19835e52018-01-30 13:35:54 +00002630 * @note The beta's value need to be passed at compile time using -DBETA
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002631 *
2632 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
2633 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2634 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2635 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2636 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002637 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
2638 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002639 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002640 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002641 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2642 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2643 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2644 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002645 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2646 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002647 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2648 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002649__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
2650 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002651{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002652 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002653 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
2654 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002655
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002656 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002657 float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
2658
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002659 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002660 float4 c = vload4(0, (__global float *)src.ptr);
2661
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002662 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002663 float4 out = alpha_ab + (float4)BETA * c;
2664
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002665 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002666 vstore4(out, 0, (__global float *)dst.ptr);
2667}
2668
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01002669#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002670/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
2671 *
Gian Marco19835e52018-01-30 13:35:54 +00002672 * @note The beta's value need to be passed at compile time using -DBETA
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002673 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002674 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
2675 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
2676 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2677 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
2678 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002681 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002682 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002683 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2684 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2685 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2686 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002687 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2688 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002689 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2690 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002691__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
2692 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002693{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002694 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002695 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
2696 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002697
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002698 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002699 half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
2700
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002701 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002702 half8 c = vload8(0, (__global half *)src.ptr);
2703
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002704 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002705 half8 out = alpha_ab + (half8)BETA * c;
2706
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002707 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002708 vstore8(out, 0, (__global half *)dst.ptr);
2709}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01002710#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002711#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002712
#if defined(WIDTH_VECTOR_A)
/** Vector-by-matrix multiplication between each row of A (src0) and matrix B (src1), used for locally connected layer
 *
 * Work-item (x, y) accumulates 4 consecutive output values: it walks the WIDTH_VECTOR_A floats of row y of src0
 * and, for each of them, a 4-wide slice (starting at column 4 * x) of the matching row of plane y of src1.
 *
 * @note The width of A needs to be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix A in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix A in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix A
 * @param[in]  src1_ptr                           Pointer to the source matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix B in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix B in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix B
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    // First of the 4 output columns handled by this work-item, and the row it belongs to
    const int col = get_global_id(0) * 4;
    const int row = get_global_id(1);

    // Byte offset of the current element of vector A (row `row` of src0)
    int a_off = src0_offset_first_element_in_bytes + src0_stride_y * row;

    // Byte offset of the current 4-wide slice of matrix B (plane `row` of src1, column `col`)
    int b_off = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_off += col * sizeof(float);

    // Byte offset one past the last element of vector A
    const int a_end = a_off + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: consume two elements of A and two rows of B per iteration
    while(a_off <= (a_end - 2 * (int)sizeof(float)))
    {
        const float2 a_pair = vload2(0, (__global float *)(src0_ptr + a_off));
        const float4 b_row0 = vload4(0, (__global float *)(src1_ptr + b_off));
        const float4 b_row1 = vload4(0, (__global float *)(src1_ptr + b_off + src1_stride_y));

        acc += b_row0 * (float4)a_pair.s0;
        acc += b_row1 * (float4)a_pair.s1;

        a_off += (int)(2 * sizeof(float));
        b_off += (int)(2 * src1_stride_y);
    }

    // Tail loop: handles the remaining element when WIDTH_VECTOR_A is odd
    while(a_off < a_end)
    {
        const float  a_val = *((__global float *)(src0_ptr + a_off));
        const float4 b_row = vload4(0, (__global float *)(src1_ptr + b_off));

        acc += b_row * (float4)a_val;

        a_off += (int)sizeof(float);
        b_off += (int)src1_stride_y;
    }

    // Resolve the destination address for this work-item and store the 4 accumulated values
    Image out_img = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&out_img, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
2780
/** Adds the biases vector to each row of the accumulate tensor, in place.
 *
 * Each work-item loads VECTOR_SIZE elements from the accumulate tensor and VECTOR_SIZE elements from the
 * biases vector, sums them element-wise and writes the result back to the accumulate tensor.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Resolve this work-item's position in the accumulate tensor and in the biases vector
    Image  acc_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector bias_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *acc_addr  = (__global DATA_TYPE *)acc_img.ptr;
    __global DATA_TYPE *bias_addr = (__global DATA_TYPE *)bias_vec.ptr;

    // Start from VECTOR_SIZE bias elements, then add the matching accumulate elements on top
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    sum = VLOAD(VECTOR_SIZE)(0, bias_addr);
    sum += VLOAD(VECTOR_SIZE)(0, acc_addr);

    // Write the sum back over the accumulate elements that were read
    VSTORE(VECTOR_SIZE)
    (sum, 0, acc_addr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)