blob: 5d5cab6578f6be0742698983645cf3633ef5180b [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco36a0a462018-01-12 10:21:40 +000026#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
27
Gian Marco19835e52018-01-30 13:35:54 +000028#if ELEMENT_SIZE == 1
Gian Marco36a0a462018-01-12 10:21:40 +000029#define DATA_TYPE uchar
Gian Marco19835e52018-01-30 13:35:54 +000030#elif ELEMENT_SIZE == 2
31#define DATA_TYPE ushort
32#elif ELEMENT_SIZE == 4
33#define DATA_TYPE uint
34#else // ELEMENT_SIZE == 1
35#error "Element size not supported"
36#endif // ELEMENT_SIZE
Gian Marco36a0a462018-01-12 10:21:40 +000037
38/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039 *
Gian Marco19835e52018-01-30 13:35:54 +000040 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (i.e. -DTRANSPOSE_W)
41 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
Gian Marco36a0a462018-01-12 10:21:40 +000042 *
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010043 * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +010044 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
45 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
46 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
47 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marcoae2af742018-02-15 12:35:44 +000048 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
49 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010050 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +010051 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +010052 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +000053 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010054 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +000055 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marcoae2af742018-02-15 12:35:44 +000056 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
57 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010058 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
59 */
Gian Marcoae2af742018-02-15 12:35:44 +000060__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
61 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +010062{
63 uint x = get_global_id(0);
64 uint y = get_global_id(1);
Gian Marcoae2af742018-02-15 12:35:44 +000065 uint z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010066
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +010067 // Compute address for Matrix B - source
Gian Marcoae2af742018-02-15 12:35:44 +000068 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010069
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +010070 // Compute address for Matrix B transposed - destination. X and Y are swapped
Gian Marco36a0a462018-01-12 10:21:40 +000071 uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
72 (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010073
Gian Marcoae2af742018-02-15 12:35:44 +000074 // Add offset for batched GEMM
75 dst_addr_in_bytes += z * dst_stride_z;
76
Gian Marco36a0a462018-01-12 10:21:40 +000077 VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
78 b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010079
Gian Marco36a0a462018-01-12 10:21:40 +000080 VSTORE(TRANSPOSE_W)
81 (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010082}
Gian Marco36a0a462018-01-12 10:21:40 +000083#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010084
Gian Marco36a0a462018-01-12 10:21:40 +000085#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
86
Gian Marco Iodice4b908652018-10-18 10:21:02 +010087/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block. If -DUNROLL_BLOCK is passed at compile time, the 4x4 block
88 * will be simply unrolled.
Anthony Barbier6ff3b192017-09-04 18:44:23 +010089 *
Gian Marco19835e52018-01-30 13:35:54 +000090 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
91 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +010092 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
93 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
94 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
95 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
96 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
Gian Marco19835e52018-01-30 13:35:54 +000097 *
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010098 * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +010099 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
100 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
101 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
102 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marcoae2af742018-02-15 12:35:44 +0000103 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
104 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100105 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100106 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100107 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
108 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
109 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
110 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Gian Marcoae2af742018-02-15 12:35:44 +0000111 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
112 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100113 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100114 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100115 */
Gian Marcoae2af742018-02-15 12:35:44 +0000116__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100117 TENSOR3D_DECLARATION(dst)
118#if defined(REINTERPRET_INPUT_AS_3D)
119 ,
120 uint cross_plane_pad
121#endif // REINTERPRET_INPUT_AS_3D
122 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100123{
Gian Marco36a0a462018-01-12 10:21:40 +0000124 // Compute source and destination addresses
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100125 uint x = get_global_id(0);
126 uint y = get_global_id(1);
Gian Marcoae2af742018-02-15 12:35:44 +0000127 uint z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100128
Gian Marcoae2af742018-02-15 12:35:44 +0000129 // Compute address for source tensor
130 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100131
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000132 // Compute address for Matrix B transposed - destination. X and Y are swapped
Gian Marco36a0a462018-01-12 10:21:40 +0000133 uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
134 (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100135
Gian Marcoae2af742018-02-15 12:35:44 +0000136 // Add offset for batched GEMM
137 dst_addr_in_bytes += z * dst_stride_z;
138
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100139#if defined(REINTERPRET_INPUT_AS_3D)
140 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;
141
142 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
143 // in order to take into account the presence of possible cross plane paddings
144 //
145 // | |
146 // | plane0 |
147 // | |
148 // |__________________|
149 // |******************|
150 // | cross_plane_pad |
151 // |******************|
152 // | |
153 // | plane1 |
154 // | |
155 // |__________________|
156
157 // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D
158 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
159 zin = min(DEPTH_GEMM3D - 1, zin);
160
161 // Add offset due to the cross plane paddings
162 zin *= (cross_plane_pad * src_stride_y);
163
164 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
165 // multiply src_stride_z by DEPTH_GEMM3D
166 input_ptr += z * src_stride_z * DEPTH_GEMM3D;
167
168 // Load values from Matrix A
169 VEC_DATA_TYPE(DATA_TYPE, 4)
170 a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
171 VEC_DATA_TYPE(DATA_TYPE, 4)
172 a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
173 VEC_DATA_TYPE(DATA_TYPE, 4)
174 a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
175 VEC_DATA_TYPE(DATA_TYPE, 4)
176 a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
177#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000178 __global uchar *input_ptr = src.ptr;
179
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000180 // Load values from Matrix A
Gian Marco36a0a462018-01-12 10:21:40 +0000181 VEC_DATA_TYPE(DATA_TYPE, 4)
Gian Marcoae2af742018-02-15 12:35:44 +0000182 a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
Gian Marco36a0a462018-01-12 10:21:40 +0000183 VEC_DATA_TYPE(DATA_TYPE, 4)
Gian Marcoae2af742018-02-15 12:35:44 +0000184 a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
Gian Marco36a0a462018-01-12 10:21:40 +0000185 VEC_DATA_TYPE(DATA_TYPE, 4)
Gian Marcoae2af742018-02-15 12:35:44 +0000186 a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
Gian Marco36a0a462018-01-12 10:21:40 +0000187 VEC_DATA_TYPE(DATA_TYPE, 4)
Gian Marcoae2af742018-02-15 12:35:44 +0000188 a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100189#endif // defined(REINTERPRET_INPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100190
Gian Marco Iodice4b908652018-10-18 10:21:02 +0100191#if defined(UNROLL_BLOCK)
192 vstore4(a0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 0 * MULT_INTERLEAVE4X4_HEIGHT));
193 vstore4(a1, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 4 * MULT_INTERLEAVE4X4_HEIGHT));
194 vstore4(a2, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 8 * MULT_INTERLEAVE4X4_HEIGHT));
195 vstore4(a3, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 12 * MULT_INTERLEAVE4X4_HEIGHT));
196#else // defined(UNROLL_BLOCK)
Gian Marco36a0a462018-01-12 10:21:40 +0000197 VEC_DATA_TYPE(DATA_TYPE, 4)
198 val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
199 vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 0 * MULT_INTERLEAVE4X4_HEIGHT));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100200
Gian Marco36a0a462018-01-12 10:21:40 +0000201 val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
202 vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 4 * MULT_INTERLEAVE4X4_HEIGHT));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100203
Gian Marco36a0a462018-01-12 10:21:40 +0000204 val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
205 vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 8 * MULT_INTERLEAVE4X4_HEIGHT));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100206
Gian Marco36a0a462018-01-12 10:21:40 +0000207 val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
208 vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 12 * MULT_INTERLEAVE4X4_HEIGHT));
Gian Marco Iodice4b908652018-10-18 10:21:02 +0100209#endif // defined(UNROLL_BLOCK)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100210}
Gian Marco36a0a462018-01-12 10:21:40 +0000211#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100212
Gian Marco36a0a462018-01-12 10:21:40 +0000213#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100214/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100215 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100216 *
Gian Marco19835e52018-01-30 13:35:54 +0000217 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
218 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
219 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000220 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
221 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100222 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000223 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
224 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
225 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
226 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
227 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
228 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100229 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
230 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
231 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
232 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
233 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
234 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100235 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100236 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
237 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
238 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
239 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
240 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100241 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100242 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000243 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100244 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000245 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100246 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000247 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
248 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
249 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100250 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100251 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100252__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
253 IMAGE_DECLARATION(src1),
254 IMAGE_DECLARATION(dst),
255 uint src0_stride_z,
256 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000257 uint dst_stride_z
258#if defined(REINTERPRET_OUTPUT_AS_3D)
259 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100260 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000261#endif // REINTERPRET_OUTPUT_AS_3D
262 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100263{
Gian Marco36a0a462018-01-12 10:21:40 +0000264 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
265 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000266 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100267
Gian Marco36a0a462018-01-12 10:21:40 +0000268 // Offset
269 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
270 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100271
Gian Marco36a0a462018-01-12 10:21:40 +0000272 // src_addr_a = address of matrix A
273 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000274 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
275 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
276
277#if defined(MATRIX_B_DEPTH)
278 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
279 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
280#else // defined(MATRIX_B_DEPTH)
281 src1_addr_in_bytes += z * src1_stride_z;
282#endif // defined(MATRIX_B_DEPTH)
283
284 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
285 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100286
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000287 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000288 __global float *src_end_addr_b = src_addr_b + COLS_B;
289
290 src_addr_a += offset_row_a;
291 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100292
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000293 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100294 float4 c00 = 0.0f;
295 float4 c10 = 0.0f;
296 float4 c20 = 0.0f;
297 float4 c30 = 0.0f;
298
Gian Marco36a0a462018-01-12 10:21:40 +0000299 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100300 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000301 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000302 float4 a0 = vload4(0, src_addr_a);
303 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100304
305 c00 += (float4)a0.s0 * b0;
306 c10 += (float4)a0.s1 * b0;
307 c20 += (float4)a0.s2 * b0;
308 c30 += (float4)a0.s3 * b0;
309
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000310 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000311 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
312 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100313
314 c00 += (float4)a0.s0 * b0;
315 c10 += (float4)a0.s1 * b0;
316 c20 += (float4)a0.s2 * b0;
317 c30 += (float4)a0.s3 * b0;
318 }
319
Gian Marco36a0a462018-01-12 10:21:40 +0000320 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100321 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000322 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000323 float4 a0 = vload4(0, src_addr_a);
324 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100325
326 c00 += (float4)a0.s0 * b0;
327 c10 += (float4)a0.s1 * b0;
328 c20 += (float4)a0.s2 * b0;
329 c30 += (float4)a0.s3 * b0;
330 }
331
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000332 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100333 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
334
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000335#if defined(ALPHA)
336 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100337 c00 = c00 * (float4)ALPHA;
338 c10 = c10 * (float4)ALPHA;
339 c20 = c20 * (float4)ALPHA;
340 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000341#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100342
Gian Marcoae2af742018-02-15 12:35:44 +0000343 // Compute dst address
344 __global uchar *dst_addr = offset(&dst, 0, 0);
345
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000346#if defined(REINTERPRET_OUTPUT_AS_3D)
347 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100348 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000349 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100350 // | |
351 // | plane0 |
352 // | |
353 // |__________________|
354 // |******************|
355 // | cross_plane_pad |
356 // |******************|
357 // | |
358 // | plane1 |
359 // | |
360 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000361
362 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
363 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
364 zout = min(DEPTH_GEMM3D - 1, zout);
365
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100366 // Add offset due to the cross plane paddings
367 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000368
369 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
370 // multiply dst_stride_z by DEPTH_GEMM3D
371 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
372
373 // Store 4x4 block
374 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
375 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
376 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
377 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
378
379#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000380 // Add offset for batched GEMM
381 dst_addr += z * dst_stride_z;
382
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000383 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000384 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
385 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
386 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
387 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000388#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100389}
390
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000391/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100393 *
Gian Marco19835e52018-01-30 13:35:54 +0000394 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
395 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
396 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000397 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
398 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
399 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100400 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100430__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
431 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000432 IMAGE_DECLARATION(dst),
433 uint src0_stride_z,
434 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000435 uint dst_stride_z
436#if defined(REINTERPRET_OUTPUT_AS_3D)
437 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100438 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000439#endif // REINTERPRET_OUTPUT_AS_3D
440 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100441{
Gian Marco36a0a462018-01-12 10:21:40 +0000442 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
443 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000444 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +0000445
446 // Offset
447 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
448 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
449
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100450 // src_addr_a = address of matrix A
451 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000452 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
453 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
454
455#if defined(MATRIX_B_DEPTH)
456 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
457 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
458#else // defined(MATRIX_B_DEPTH)
459 src1_addr_in_bytes += z * src1_stride_z;
460#endif // defined(MATRIX_B_DEPTH)
461
462 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
463 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100464
Gian Marco36a0a462018-01-12 10:21:40 +0000465 src_addr_a += offset_row_a;
466 src_addr_b += offset_row_b;
467
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100468 // Reset accumulators
469 float c00 = 0.0f;
470 float c01 = 0.0f;
471 float c02 = 0.0f;
472 float c03 = 0.0f;
473 float c10 = 0.0f;
474 float c11 = 0.0f;
475 float c12 = 0.0f;
476 float c13 = 0.0f;
477 float c20 = 0.0f;
478 float c21 = 0.0f;
479 float c22 = 0.0f;
480 float c23 = 0.0f;
481 float c30 = 0.0f;
482 float c31 = 0.0f;
483 float c32 = 0.0f;
484 float c33 = 0.0f;
485
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100486#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
487
488 int i = 0;
489 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100490 {
491 // Load values from matrix A (interleaved) and matrix B (transposed)
492 float4 a0 = vload4(0, src_addr_a);
493 float4 b0 = vload4(0, src_addr_b);
494
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100495 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
496 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100497
498 c00 = fma(a0.s0, b0.s0, c00);
499 c01 = fma(a0.s0, b0.s1, c01);
500 c02 = fma(a0.s0, b0.s2, c02);
501 c03 = fma(a0.s0, b0.s3, c03);
502
503 c10 = fma(a0.s1, b0.s0, c10);
504 c11 = fma(a0.s1, b0.s1, c11);
505 c12 = fma(a0.s1, b0.s2, c12);
506 c13 = fma(a0.s1, b0.s3, c13);
507
508 c20 = fma(a0.s2, b0.s0, c20);
509 c21 = fma(a0.s2, b0.s1, c21);
510 c22 = fma(a0.s2, b0.s2, c22);
511 c23 = fma(a0.s2, b0.s3, c23);
512
513 c30 = fma(a0.s3, b0.s0, c30);
514 c31 = fma(a0.s3, b0.s1, c31);
515 c32 = fma(a0.s3, b0.s2, c32);
516 c33 = fma(a0.s3, b0.s3, c33);
517
518 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100519 a0 = vload4(0, src_addr_a);
520 b0 = vload4(0, src_addr_b);
521
522 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
523 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100524
525 c00 = fma(a0.s0, b0.s0, c00);
526 c01 = fma(a0.s0, b0.s1, c01);
527 c02 = fma(a0.s0, b0.s2, c02);
528 c03 = fma(a0.s0, b0.s3, c03);
529
530 c10 = fma(a0.s1, b0.s0, c10);
531 c11 = fma(a0.s1, b0.s1, c11);
532 c12 = fma(a0.s1, b0.s2, c12);
533 c13 = fma(a0.s1, b0.s3, c13);
534
535 c20 = fma(a0.s2, b0.s0, c20);
536 c21 = fma(a0.s2, b0.s1, c21);
537 c22 = fma(a0.s2, b0.s2, c22);
538 c23 = fma(a0.s2, b0.s3, c23);
539
540 c30 = fma(a0.s3, b0.s0, c30);
541 c31 = fma(a0.s3, b0.s1, c31);
542 c32 = fma(a0.s3, b0.s2, c32);
543 c33 = fma(a0.s3, b0.s3, c33);
544
545 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100546 a0 = vload4(0, src_addr_a);
547 b0 = vload4(0, src_addr_b);
548
549 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
550 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
551
552 c00 = fma(a0.s0, b0.s0, c00);
553 c01 = fma(a0.s0, b0.s1, c01);
554 c02 = fma(a0.s0, b0.s2, c02);
555 c03 = fma(a0.s0, b0.s3, c03);
556
557 c10 = fma(a0.s1, b0.s0, c10);
558 c11 = fma(a0.s1, b0.s1, c11);
559 c12 = fma(a0.s1, b0.s2, c12);
560 c13 = fma(a0.s1, b0.s3, c13);
561
562 c20 = fma(a0.s2, b0.s0, c20);
563 c21 = fma(a0.s2, b0.s1, c21);
564 c22 = fma(a0.s2, b0.s2, c22);
565 c23 = fma(a0.s2, b0.s3, c23);
566
567 c30 = fma(a0.s3, b0.s0, c30);
568 c31 = fma(a0.s3, b0.s1, c31);
569 c32 = fma(a0.s3, b0.s2, c32);
570 c33 = fma(a0.s3, b0.s3, c33);
571
572 // Load values from matrix A (interleaved) and matrix B (transposed)
573 a0 = vload4(0, src_addr_a);
574 b0 = vload4(0, src_addr_b);
575
576 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
577 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100578
579 c00 = fma(a0.s0, b0.s0, c00);
580 c01 = fma(a0.s0, b0.s1, c01);
581 c02 = fma(a0.s0, b0.s2, c02);
582 c03 = fma(a0.s0, b0.s3, c03);
583
584 c10 = fma(a0.s1, b0.s0, c10);
585 c11 = fma(a0.s1, b0.s1, c11);
586 c12 = fma(a0.s1, b0.s2, c12);
587 c13 = fma(a0.s1, b0.s3, c13);
588
589 c20 = fma(a0.s2, b0.s0, c20);
590 c21 = fma(a0.s2, b0.s1, c21);
591 c22 = fma(a0.s2, b0.s2, c22);
592 c23 = fma(a0.s2, b0.s3, c23);
593
594 c30 = fma(a0.s3, b0.s0, c30);
595 c31 = fma(a0.s3, b0.s1, c31);
596 c32 = fma(a0.s3, b0.s2, c32);
597 c33 = fma(a0.s3, b0.s3, c33);
598 }
599
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100600 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100601 {
602 // Load values from matrix A (interleaved) and matrix B (transposed)
603 float4 a0 = vload4(0, src_addr_a);
604 float4 b0 = vload4(0, src_addr_b);
605
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100606 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
607 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
608
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100609 c00 = fma(a0.s0, b0.s0, c00);
610 c01 = fma(a0.s0, b0.s1, c01);
611 c02 = fma(a0.s0, b0.s2, c02);
612 c03 = fma(a0.s0, b0.s3, c03);
613
614 c10 = fma(a0.s1, b0.s0, c10);
615 c11 = fma(a0.s1, b0.s1, c11);
616 c12 = fma(a0.s1, b0.s2, c12);
617 c13 = fma(a0.s1, b0.s3, c13);
618
619 c20 = fma(a0.s2, b0.s0, c20);
620 c21 = fma(a0.s2, b0.s1, c21);
621 c22 = fma(a0.s2, b0.s2, c22);
622 c23 = fma(a0.s2, b0.s3, c23);
623
624 c30 = fma(a0.s3, b0.s0, c30);
625 c31 = fma(a0.s3, b0.s1, c31);
626 c32 = fma(a0.s3, b0.s2, c32);
627 c33 = fma(a0.s3, b0.s3, c33);
628 }
629
630 // Compute destination address
631 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
632
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000633#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100634 // Multiply by the weight of matrix product
635 c00 = c00 * ALPHA;
636 c01 = c01 * ALPHA;
637 c02 = c02 * ALPHA;
638 c03 = c03 * ALPHA;
639 c10 = c10 * ALPHA;
640 c11 = c11 * ALPHA;
641 c12 = c12 * ALPHA;
642 c13 = c13 * ALPHA;
643 c20 = c20 * ALPHA;
644 c21 = c21 * ALPHA;
645 c22 = c22 * ALPHA;
646 c23 = c23 * ALPHA;
647 c30 = c30 * ALPHA;
648 c31 = c31 * ALPHA;
649 c32 = c32 * ALPHA;
650 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000651#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100652
Gian Marcoae2af742018-02-15 12:35:44 +0000653 // Compute dst address
654 __global uchar *dst_addr = offset(&dst, 0, 0);
655
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000656#if defined(REINTERPRET_OUTPUT_AS_3D)
657 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100658 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000659 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100660 // | |
661 // | plane0 |
662 // | |
663 // |__________________|
664 // |******************|
665 // | cross_plane_pad |
666 // |******************|
667 // | |
668 // | plane1 |
669 // | |
670 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000671
672 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
673 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
674 zout = min(DEPTH_GEMM3D - 1, zout);
675
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100676 // Add offset due to the cross plane paddings
677 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000678
679 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
680 // multiply dst_stride_z by DEPTH_GEMM3D
681 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
682
683 // Store 4x4 block
684 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
685 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
686 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
687 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
688
689#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000690 // Add offset for batched GEMM
691 dst_addr += z * dst_stride_z;
692
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100693 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000694 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
695 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
696 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
697 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000698#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100699}
700
Georgios Pinitas84225582018-05-14 12:00:05 +0100701// Undefine local defines
702#undef COLS_MTX_B
703
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100704#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100705/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100706 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100707 *
Gian Marco19835e52018-01-30 13:35:54 +0000708 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
709 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
710 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000711 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
712 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100713 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000714 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
715 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
716 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
717 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
718 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
719 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100720 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
721 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
722 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
723 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
724 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
725 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100726 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100727 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
728 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
729 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
730 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
731 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100732 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100733 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000734 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100735 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000736 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100737 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000738 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
739 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
740 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100741 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100742 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100743__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
744 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000745 IMAGE_DECLARATION(dst),
746 uint src0_stride_z,
747 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000748 uint dst_stride_z
749#if defined(REINTERPRET_OUTPUT_AS_3D)
750 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100751 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000752#endif // REINTERPRET_OUTPUT_AS_3D
753 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100754{
Gian Marco36a0a462018-01-12 10:21:40 +0000755 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
756 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000757 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100758
Gian Marco36a0a462018-01-12 10:21:40 +0000759 // Offset
760 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
761 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100762
Gian Marco36a0a462018-01-12 10:21:40 +0000763 // src_addr_a = address of matrix A
764 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000765 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
766 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
767
768#if defined(MATRIX_B_DEPTH)
769 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
770 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
771#else // defined(MATRIX_B_DEPTH)
772 src1_addr_in_bytes += z * src1_stride_z;
773#endif // defined(MATRIX_B_DEPTH)
774
775 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
776 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100777
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000778 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000779 __global half *src_end_addr_b = src_addr_b + COLS_B;
780
781 src_addr_a += offset_row_a;
782 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100783
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000784 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100785 half8 c00 = 0.0f;
786 half8 c10 = 0.0f;
787 half8 c20 = 0.0f;
788 half8 c30 = 0.0f;
789
Gian Marco36a0a462018-01-12 10:21:40 +0000790 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100791 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000792 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000793 half4 a0 = vload4(0, src_addr_a);
794 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100795
796 c00 += (half8)a0.s0 * b0;
797 c10 += (half8)a0.s1 * b0;
798 c20 += (half8)a0.s2 * b0;
799 c30 += (half8)a0.s3 * b0;
800
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000801 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000802 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
803 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100804
805 c00 += (half8)a0.s0 * b0;
806 c10 += (half8)a0.s1 * b0;
807 c20 += (half8)a0.s2 * b0;
808 c30 += (half8)a0.s3 * b0;
809 }
810
Gian Marco36a0a462018-01-12 10:21:40 +0000811 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100812 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000813 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000814 half4 a0 = vload4(0, src_addr_a);
815 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100816
817 c00 += (half8)a0.s0 * b0;
818 c10 += (half8)a0.s1 * b0;
819 c20 += (half8)a0.s2 * b0;
820 c30 += (half8)a0.s3 * b0;
821 }
822
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000823 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100824 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
825
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000826#if defined(ALPHA)
827 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100828 c00 = c00 * (half8)ALPHA;
829 c10 = c10 * (half8)ALPHA;
830 c20 = c20 * (half8)ALPHA;
831 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000832#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100833
Gian Marcoae2af742018-02-15 12:35:44 +0000834 // Compute dst address
835 __global uchar *dst_addr = offset(&dst, 0, 0);
836
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000837#if defined(REINTERPRET_OUTPUT_AS_3D)
838 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100839 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000840 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100841 // | |
842 // | plane0 |
843 // | |
844 // |__________________|
845 // |******************|
846 // | cross_plane_pad |
847 // |******************|
848 // | |
849 // | plane1 |
850 // | |
851 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000852
853 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
854 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
855 zout = min(DEPTH_GEMM3D - 1, zout);
856
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100857 // Add offset due to the cross plane paddings
858 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000859
860 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
861 // multiply dst_stride_z by DEPTH_GEMM3D
862 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
863
864 // Store 4x8 block
865 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
866 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
867 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
868 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
869
870#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000871 // Add offset for batched GEMM
872 dst_addr += z * dst_stride_z;
873
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000874 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +0000875 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
876 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
877 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
878 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000879#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100880}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100881
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
917__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
918 IMAGE_DECLARATION(src1),
919 IMAGE_DECLARATION(dst),
920 uint src0_stride_z,
921 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000922 uint dst_stride_z
923#if defined(REINTERPRET_OUTPUT_AS_3D)
924 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100925 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000926#endif // REINTERPRET_OUTPUT_AS_3D
927 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100928{
929 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
930 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
931 int z = get_global_id(2);
932
933 // Offset
934 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
935 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
936
937 // src_addr_a = address of matrix A
938 // src_addr_b = address of matrix B
939 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
940 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
941
942#if defined(MATRIX_B_DEPTH)
943 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
944 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
945#else // defined(MATRIX_B_DEPTH)
946 src1_addr_in_bytes += z * src1_stride_z;
947#endif // defined(MATRIX_B_DEPTH)
948
949 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
950 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
951
952 // Compute end row address for matrix B
953 __global half *src_end_addr_b = src_addr_b + COLS_B;
954
955 src_addr_a += offset_row_a;
956 src_addr_b += offset_row_b;
957
958 // Reset accumulators
959 half8 c00 = 0.0f;
960 half8 c10 = 0.0f;
961 half8 c20 = 0.0f;
962 half8 c30 = 0.0f;
963
964#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
965
966 int i = 0;
967 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
968 {
969#if MULT_INTERLEAVE4X4_HEIGHT == 1
970 // Load values from matrix A (interleaved) and matrix B (transposed)
971 half8 a0 = vload8(0, src_addr_a);
972 half8 b0 = vload8(0, src_addr_b);
973
974 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
975 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
976
977 c00 = fma((half8)a0.s0, b0, c00);
978 c10 = fma((half8)a0.s1, b0, c10);
979 c20 = fma((half8)a0.s2, b0, c20);
980 c30 = fma((half8)a0.s3, b0, c30);
981
982 // Load values from matrix B (transposed)
983 b0 = vload8(0, src_addr_b);
984
985 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
986
987 c00 = fma((half8)a0.s4, b0, c00);
988 c10 = fma((half8)a0.s5, b0, c10);
989 c20 = fma((half8)a0.s6, b0, c20);
990 c30 = fma((half8)a0.s7, b0, c30);
991
992 // Load values from matrix A (interleaved) and matrix B (transposed)
993 a0 = vload8(0, src_addr_a);
994 b0 = vload8(0, src_addr_b);
995
996 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
997 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
998
999 c00 = fma((half8)a0.s0, b0, c00);
1000 c10 = fma((half8)a0.s1, b0, c10);
1001 c20 = fma((half8)a0.s2, b0, c20);
1002 c30 = fma((half8)a0.s3, b0, c30);
1003
1004 // Load values from matrix B (transposed)
1005 b0 = vload8(0, src_addr_b);
1006
1007 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1008
1009 c00 = fma((half8)a0.s4, b0, c00);
1010 c10 = fma((half8)a0.s5, b0, c10);
1011 c20 = fma((half8)a0.s6, b0, c20);
1012 c30 = fma((half8)a0.s7, b0, c30);
1013#else // MULT_INTERLEAVE4X4_HEIGHT == 1
1014 // Load values from matrix A (interleaved) and matrix B (transposed)
1015 half4 a0 = vload4(0, src_addr_a);
1016 half8 b0 = vload8(0, src_addr_b);
1017
1018 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1019 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1020
1021 c00 = fma((half8)a0.s0, b0, c00);
1022 c10 = fma((half8)a0.s1, b0, c10);
1023 c20 = fma((half8)a0.s2, b0, c20);
1024 c30 = fma((half8)a0.s3, b0, c30);
1025
1026 // Load values from matrix A (interleaved) and matrix B (transposed)
1027 a0 = vload4(0, src_addr_a);
1028 b0 = vload8(0, src_addr_b);
1029
1030 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1031 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1032
1033 c00 = fma((half8)a0.s0, b0, c00);
1034 c10 = fma((half8)a0.s1, b0, c10);
1035 c20 = fma((half8)a0.s2, b0, c20);
1036 c30 = fma((half8)a0.s3, b0, c30);
1037
1038 // Load values from matrix A (interleaved) and matrix B (transposed)
1039 a0 = vload4(0, src_addr_a);
1040 b0 = vload8(0, src_addr_b);
1041
1042 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1043 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1044
1045 c00 = fma((half8)a0.s0, b0, c00);
1046 c10 = fma((half8)a0.s1, b0, c10);
1047 c20 = fma((half8)a0.s2, b0, c20);
1048 c30 = fma((half8)a0.s3, b0, c30);
1049
1050 // Load values from matrix A (interleaved) and matrix B (transposed)
1051 a0 = vload4(0, src_addr_a);
1052 b0 = vload8(0, src_addr_b);
1053
1054 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1055 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1056
1057 c00 = fma((half8)a0.s0, b0, c00);
1058 c10 = fma((half8)a0.s1, b0, c10);
1059 c20 = fma((half8)a0.s2, b0, c20);
1060 c30 = fma((half8)a0.s3, b0, c30);
1061#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
1062 }
1063
1064 for(; i < (int)(COLS_MTX_B); ++i)
1065 {
1066 // Load values from matrix A (interleaved) and matrix B (transposed)
1067 half4 a0 = vload4(0, src_addr_a);
1068 half8 b0 = vload8(0, src_addr_b);
1069
1070 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1071 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1072
1073 c00 = fma((half8)a0.s0, b0, c00);
1074 c10 = fma((half8)a0.s1, b0, c10);
1075 c20 = fma((half8)a0.s2, b0, c20);
1076 c30 = fma((half8)a0.s3, b0, c30);
1077 }
1078
1079 // Compute destination address
1080 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1081
1082#if defined(ALPHA)
1083 // Multiply by the weight of matrix product
1084 c00 = c00 * (half8)ALPHA;
1085 c10 = c10 * (half8)ALPHA;
1086 c20 = c20 * (half8)ALPHA;
1087 c30 = c30 * (half8)ALPHA;
1088#endif // defined(ALPHA)
1089
1090 // Compute dst address
1091 __global uchar *dst_addr = offset(&dst, 0, 0);
1092
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001093#if defined(REINTERPRET_OUTPUT_AS_3D)
1094 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001095 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001096 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001097 // | |
1098 // | plane0 |
1099 // | |
1100 // |__________________|
1101 // |******************|
1102 // | cross_plane_pad |
1103 // |******************|
1104 // | |
1105 // | plane1 |
1106 // | |
1107 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001108
1109 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
1110 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
1111 zout = min(DEPTH_GEMM3D - 1, zout);
1112
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001113 // Add offset due to the cross plane paddings
1114 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001115
1116 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1117 // multiply dst_stride_z by DEPTH_GEMM3D
1118 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1119
1120 // Store 4x8 block
1121 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
1122 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
1123 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
1124 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
1125
1126#else // defined(REINTERPRET_OUTPUT_AS_3D)
1127 // Add offset for batched GEMM
1128 dst_addr += z * dst_stride_z;
1129
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001130 // Store 4x8 block
1131 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
1132 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
1133 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
1134 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001135#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001136}
Georgios Pinitas84225582018-05-14 12:00:05 +01001137
1138// Undefine local defines
1139#undef COLS_MTX_B
1140
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01001141#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001142
Gian Marco36a0a462018-01-12 10:21:40 +00001143#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001144
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001145#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)

// Local helpers (undefined again right after the kernel). They expand inside the kernel body and rely on its
// local variables: GEMM_MM_FP_LHS_ROW(row) is the typed address of row `row` of the matrix A tile at the current
// position, GEMM_MM_FP_DST_ROW(row) is the typed address of row `row` of the destination tile.
// `row` must be one of the literals 0, 1, 2, 3 because it is pasted into a vector component name.
#if defined(REINTERPRET_INPUT_AS_3D)
#define GEMM_MM_FP_LHS_ROW(row) ((__global DATA_TYPE *)(src0_ptr + offsets.s0 + row * src0_stride_y + pad_in.s##row))
#else // defined(REINTERPRET_INPUT_AS_3D)
#define GEMM_MM_FP_LHS_ROW(row) ((__global DATA_TYPE *)(src0_ptr + offsets.s0 + row * src0_stride_y))
#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(REINTERPRET_OUTPUT_AS_3D)
#define GEMM_MM_FP_DST_ROW(row) ((__global DATA_TYPE *)(out_base + row * dst_stride_y + pad_out.s##row))
#else // defined(REINTERPRET_OUTPUT_AS_3D)
#define GEMM_MM_FP_DST_ROW(row) ((__global DATA_TYPE *)(out_base + row * dst_stride_y))
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

/** Matrix by matrix multiplication between matrix A (src0) and matrix B (src1) when neither matrix has been reshaped.
 *
 * Each work-item accumulates a tile of NUM_ELEMS_PROCESSED_PER_THREAD_X output columns for up to four consecutive
 * output rows (rows 1, 2 and 3 are only processed when NUM_ELEMS_PROCESSED_PER_THREAD_Y is greater than 1, 2 and 3).
 * The row of matrix A is consumed two elements at a time, with a scalar loop for the left-over element.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using
 *       -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A. If -DALPHA is passed, the
 *       accumulated result is multiplied by ALPHA before being stored
 * @note If MATRIX_B_DEPTH is passed at compile time (i.e. -DMATRIX_B_DEPTH=16), matrix B is addressed with
 *       (get_global_id(2) % MATRIX_B_DEPTH) instead of get_global_id(2). This is meant for the case where matrix B has
 *       3 dimensions and matrix A more than 3 (e.g. batched matrix multiplication for 2D Winograd with multiple
 *       inputs: a = [K, M, 16, Batches], b = [N, K, 16]) and avoids out-of-bounds reads
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the tensor reinterpreted as 3D. The plane of a row is its row index divided by this value
 *       -# DEPTH_GEMM3D: The depth of the tensor reinterpreted as 3D. The plane index is clamped to DEPTH_GEMM3D - 1
 *       NOTE(review): the previous documentation stated (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns of matrix A NOT reshaped;
 *       since the division is applied to the row index this presumably refers to its rows - to be confirmed.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom padding of the input tensor planes; it is multiplied by src0_stride_y to form a byte offset (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom padding of the output tensor planes; it is multiplied by dst_stride_y to form a byte offset (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                     )
{
    // First output column owned by this work-item
    int first_col = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Byte offsets of the current position: .s0 walks matrix A, .s1 walks matrix B
    int2 offsets = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Matrix A: jump to the first row of the tile
    offsets.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Matrix B: jump to the first column of the tile
    offsets.s1 += first_col * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The 2D input tile is read from a 3D tensor, so every row needs an extra offset that accounts for the
    // cross plane padding of all the planes that precede the plane the row belongs to:
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // Plane of each of the four rows: row index (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y + 0..3)
    // divided by HEIGHT_GEMM3D, clamped to the last plane
    uint4 pad_in = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    pad_in       = min(DEPTH_GEMM3D - 1, pad_in);

    // Turn the plane index into the byte offset introduced by the cross plane paddings
    pad_in *= (src_cross_plane_pad * src0_stride_y);

    // Batched GEMM: batches live in the fourth dimension, hence src0_stride_z is multiplied by DEPTH_GEMM3D
    offsets.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Batched GEMM: move matrix A to the current batch
    offsets.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while matrix A has more than 3: wrap the batch index instead of sliding past B
    offsets.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    offsets.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the matrix A row
    int lhs_row_end = offsets.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator vector per output row
    VECTOR_TYPE sum0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE sum1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE sum2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE sum3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: two elements of each matrix A row (and two rows of matrix B) per iteration
    for(; offsets.s0 <= (lhs_row_end - 2 * (int)sizeof(DATA_TYPE)); offsets += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
        // Two consecutive values from every matrix A row of the tile
        VEC_DATA_TYPE(DATA_TYPE, 2)
        lhs0 = vload2(0, GEMM_MM_FP_LHS_ROW(0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        lhs1 = vload2(0, GEMM_MM_FP_LHS_ROW(1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        lhs2 = vload2(0, GEMM_MM_FP_LHS_ROW(2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        lhs3 = vload2(0, GEMM_MM_FP_LHS_ROW(3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // The two matching rows of matrix B
        VECTOR_TYPE rhs0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + offsets.s1));
        VECTOR_TYPE rhs1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + offsets.s1 + src1_stride_y));

        // Multiply-accumulate, first element before second for every row
        sum0 += rhs0 * (VECTOR_TYPE)lhs0.s0;
        sum0 += rhs1 * (VECTOR_TYPE)lhs0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        sum1 += rhs0 * (VECTOR_TYPE)lhs1.s0;
        sum1 += rhs1 * (VECTOR_TYPE)lhs1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        sum2 += rhs0 * (VECTOR_TYPE)lhs2.s0;
        sum2 += rhs1 * (VECTOR_TYPE)lhs2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        sum3 += rhs0 * (VECTOR_TYPE)lhs3.s0;
        sum3 += rhs1 * (VECTOR_TYPE)lhs3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: one element of each matrix A row (and one row of matrix B) per iteration
    for(; offsets.s0 < lhs_row_end; offsets += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
        // One value from every matrix A row of the tile
        DATA_TYPE lhs0 = *(GEMM_MM_FP_LHS_ROW(0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE lhs1 = *(GEMM_MM_FP_LHS_ROW(1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE lhs2 = *(GEMM_MM_FP_LHS_ROW(2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE lhs3 = *(GEMM_MM_FP_LHS_ROW(3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // The matching row of matrix B
        VECTOR_TYPE rhs0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + offsets.s1));

        // Multiply-accumulate
        sum0 += rhs0 * (VECTOR_TYPE)lhs0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        sum1 += rhs0 * (VECTOR_TYPE)lhs1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        sum2 += rhs0 * (VECTOR_TYPE)lhs2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        sum3 += rhs0 * (VECTOR_TYPE)lhs3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Destination tile of this work-item
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Byte address of the first element of the destination tile
    __global uchar *out_base = offset(&dst, 0, 0);

#if defined(ALPHA)
    // Scale the accumulated products by the weight of the matrix-matrix product
    sum0 = sum0 * (VECTOR_TYPE)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    sum1 = sum1 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    sum2 = sum2 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    sum3 = sum3 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    int batch = get_global_id(2);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is written to a 3D tensor, so every row needs an extra offset that accounts for the
    // cross plane padding of all the planes that precede the plane the row belongs to:
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // Plane of each of the four rows: row index (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y + 0..3)
    // divided by HEIGHT_GEMM3D, clamped to the last plane
    uint4 pad_out = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    pad_out       = min(DEPTH_GEMM3D - 1, pad_out);

    // Turn the plane index into the byte offset introduced by the cross plane paddings
    pad_out *= (dst_cross_plane_pad * dst_stride_y);

    // Batched GEMM: batches live in the fourth dimension, hence dst_stride_z is multiplied by DEPTH_GEMM3D
    out_base += batch * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Batched GEMM: move the destination to the current batch
    out_base += batch * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Write the output tile, one row per accumulator
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (sum0, 0, GEMM_MM_FP_DST_ROW(0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (sum1, 0, GEMM_MM_FP_DST_ROW(1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (sum2, 0, GEMM_MM_FP_DST_ROW(2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (sum3, 0, GEMM_MM_FP_DST_ROW(3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}

// Undefine local helpers
#undef GEMM_MM_FP_LHS_ROW
#undef GEMM_MM_FP_DST_ROW
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001457
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01001458/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001459 *
1460 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1461 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1462 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1463 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1464 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001465 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1466 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001467 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001468 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1469 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001470 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1471 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1472 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1473 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1474 *
 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
1476 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1477 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1478 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1479 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1480 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1481 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1482 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1483 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1484 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1485 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1486 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1487 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1488 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1489 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1490 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1491 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1492 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001493 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1494 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1495 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001496 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1497 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001498 */
1499__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1500 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001501 IMAGE_DECLARATION(dst),
1502 uint src0_stride_z,
1503 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001504 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001505#if defined(REINTERPRET_INPUT_AS_3D)
1506 ,
1507 uint src_cross_plane_pad
1508#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001509#if defined(REINTERPRET_OUTPUT_AS_3D)
1510 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001511 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001512#endif // REINTERPRET_OUTPUT_AS_3D
1513 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001514{
1515 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1516
1517 // Compute starting address for matrix A and matrix B
1518 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1519
1520 // Update address for matrix A
1521 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1522
1523 // Update address for matrix B
1524 src_addr.s1 += idx * sizeof(float);
1525
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001526#if defined(REINTERPRET_INPUT_AS_3D)
1527 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1528 // in order to take into account the presence of possible cross plane paddings
1529 //
1530 // | |
1531 // | plane0 |
1532 // | |
1533 // |__________________|
1534 // |******************|
1535 // | cross_plane_pad |
1536 // |******************|
1537 // | |
1538 // | plane1 |
1539 // | |
1540 // |__________________|
1541
1542 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1543 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1544 zin = min(DEPTH_GEMM3D - 1, zin);
1545
1546 // Add offset due to the cross plane paddings
1547 zin *= (src_cross_plane_pad * src0_stride_y);
1548
1549 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1550 // multiply src0_stride_z by DEPTH_GEMM3D
1551 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1552
1553#else // defined(REINTERPRET_INPUT_AS_3D)
1554
Gian Marcoae2af742018-02-15 12:35:44 +00001555 // Add offset for batched GEMM
1556 src_addr.s0 += get_global_id(2) * src0_stride_z;
1557
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001558#endif // defined(REINTERPRET_INPUT_AS_3D)
1559
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001560#if defined(MATRIX_B_DEPTH)
1561 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1562 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1563#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001564 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001565#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001566
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001567 // Initialize accumulators
1568 float acc00 = 0.0f;
1569 float acc01 = 0.0f;
1570 float acc02 = 0.0f;
1571 float acc03 = 0.0f;
1572
1573#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1574 float acc10 = 0.0f;
1575 float acc11 = 0.0f;
1576 float acc12 = 0.0f;
1577 float acc13 = 0.0f;
1578#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1579
1580#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1581 float acc20 = 0.0f;
1582 float acc21 = 0.0f;
1583 float acc22 = 0.0f;
1584 float acc23 = 0.0f;
1585#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1586
1587#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1588 float acc30 = 0.0f;
1589 float acc31 = 0.0f;
1590 float acc32 = 0.0f;
1591 float acc33 = 0.0f;
1592#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1593
1594 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001595 int i = 0;
1596 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001597 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001598#if defined(REINTERPRET_INPUT_AS_3D)
1599 // Load values from matrix A and matrix B
1600 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1601#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1602 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1603#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1604#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1605 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1606#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1607#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1608 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1609#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1610#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001611 // Load values from matrix A and matrix B
1612 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001613#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001614 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001615#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1616#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001617 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001618#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1619#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001620 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001621#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001622#endif // defined(REINTERPRET_INPUT_AS_3D)
1623
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001624 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1625 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001626
1627 // Multiply and accumulate
1628 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001629 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001630 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001631 acc03 = fma(a0.s0, b0.s3, acc03);
1632
1633#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001634
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001635 acc10 = fma(a1.s0, b0.s0, acc10);
1636 acc11 = fma(a1.s0, b0.s1, acc11);
1637 acc12 = fma(a1.s0, b0.s2, acc12);
1638 acc13 = fma(a1.s0, b0.s3, acc13);
1639
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001640#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1641#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001642
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001643 acc20 = fma(a2.s0, b0.s0, acc20);
1644 acc21 = fma(a2.s0, b0.s1, acc21);
1645 acc22 = fma(a2.s0, b0.s2, acc22);
1646 acc23 = fma(a2.s0, b0.s3, acc23);
1647
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001648#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1649#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001650
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001651 acc30 = fma(a3.s0, b0.s0, acc30);
1652 acc31 = fma(a3.s0, b0.s1, acc31);
1653 acc32 = fma(a3.s0, b0.s2, acc32);
1654 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001655#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001656
1657 // Load values from matrix A and matrix B
1658 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1659 src_addr.s1 += src1_stride_y;
1660
1661 // Multiply and accumulate
1662 acc00 = fma(a0.s1, b0.s0, acc00);
1663 acc01 = fma(a0.s1, b0.s1, acc01);
1664 acc02 = fma(a0.s1, b0.s2, acc02);
1665 acc03 = fma(a0.s1, b0.s3, acc03);
1666
1667#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1668
1669 acc10 = fma(a1.s1, b0.s0, acc10);
1670 acc11 = fma(a1.s1, b0.s1, acc11);
1671 acc12 = fma(a1.s1, b0.s2, acc12);
1672 acc13 = fma(a1.s1, b0.s3, acc13);
1673
1674#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1675#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1676
1677 acc20 = fma(a2.s1, b0.s0, acc20);
1678 acc21 = fma(a2.s1, b0.s1, acc21);
1679 acc22 = fma(a2.s1, b0.s2, acc22);
1680 acc23 = fma(a2.s1, b0.s3, acc23);
1681
1682#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1683#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1684
1685 acc30 = fma(a3.s1, b0.s0, acc30);
1686 acc31 = fma(a3.s1, b0.s1, acc31);
1687 acc32 = fma(a3.s1, b0.s2, acc32);
1688 acc33 = fma(a3.s1, b0.s3, acc33);
1689#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1690
1691 // Load values from matrix A and matrix B
1692 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1693 src_addr.s1 += src1_stride_y;
1694
1695 // Multiply and accumulate
1696 acc00 = fma(a0.s2, b0.s0, acc00);
1697 acc01 = fma(a0.s2, b0.s1, acc01);
1698 acc02 = fma(a0.s2, b0.s2, acc02);
1699 acc03 = fma(a0.s2, b0.s3, acc03);
1700
1701#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1702
1703 acc10 = fma(a1.s2, b0.s0, acc10);
1704 acc11 = fma(a1.s2, b0.s1, acc11);
1705 acc12 = fma(a1.s2, b0.s2, acc12);
1706 acc13 = fma(a1.s2, b0.s3, acc13);
1707
1708#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1709#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1710
1711 acc20 = fma(a2.s2, b0.s0, acc20);
1712 acc21 = fma(a2.s2, b0.s1, acc21);
1713 acc22 = fma(a2.s2, b0.s2, acc22);
1714 acc23 = fma(a2.s2, b0.s3, acc23);
1715
1716#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1717#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1718
1719 acc30 = fma(a3.s2, b0.s0, acc30);
1720 acc31 = fma(a3.s2, b0.s1, acc31);
1721 acc32 = fma(a3.s2, b0.s2, acc32);
1722 acc33 = fma(a3.s2, b0.s3, acc33);
1723#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1724
1725 // Load values from matrix A and matrix B
1726 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1727 src_addr.s1 += src1_stride_y;
1728
1729 // Multiply and accumulate
1730 acc00 = fma(a0.s3, b0.s0, acc00);
1731 acc01 = fma(a0.s3, b0.s1, acc01);
1732 acc02 = fma(a0.s3, b0.s2, acc02);
1733 acc03 = fma(a0.s3, b0.s3, acc03);
1734
1735#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1736
1737 acc10 = fma(a1.s3, b0.s0, acc10);
1738 acc11 = fma(a1.s3, b0.s1, acc11);
1739 acc12 = fma(a1.s3, b0.s2, acc12);
1740 acc13 = fma(a1.s3, b0.s3, acc13);
1741
1742#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1743#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1744
1745 acc20 = fma(a2.s3, b0.s0, acc20);
1746 acc21 = fma(a2.s3, b0.s1, acc21);
1747 acc22 = fma(a2.s3, b0.s2, acc22);
1748 acc23 = fma(a2.s3, b0.s3, acc23);
1749
1750#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1751#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1752
1753 acc30 = fma(a3.s3, b0.s0, acc30);
1754 acc31 = fma(a3.s3, b0.s1, acc31);
1755 acc32 = fma(a3.s3, b0.s2, acc32);
1756 acc33 = fma(a3.s3, b0.s3, acc33);
1757#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1758
1759 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001760 }
1761
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001762 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001763 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001764#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001765 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001766 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1767#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1768 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1769#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1770#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1771 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1772#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1773#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1774 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1775#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1776#else // defined(REINTERPRET_INPUT_AS_3D)
1777 // Load values from matrix A
1778 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001779#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1780 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1781#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1782#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1783 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1784#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1785#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1786 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1787#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001788#endif // defined(REINTERPRET_INPUT_AS_3D)
1789
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001790 // Load values from matrix B
1791 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001792 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001793
1794 // Multiply and accumulate
1795 acc00 = fma(a0, b0.s0, acc00);
1796 acc01 = fma(a0, b0.s1, acc01);
1797 acc02 = fma(a0, b0.s2, acc02);
1798 acc03 = fma(a0, b0.s3, acc03);
1799#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1800 acc10 = fma(a1, b0.s0, acc10);
1801 acc11 = fma(a1, b0.s1, acc11);
1802 acc12 = fma(a1, b0.s2, acc12);
1803 acc13 = fma(a1, b0.s3, acc13);
1804#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1805#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1806 acc20 = fma(a2, b0.s0, acc20);
1807 acc21 = fma(a2, b0.s1, acc21);
1808 acc22 = fma(a2, b0.s2, acc22);
1809 acc23 = fma(a2, b0.s3, acc23);
1810#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1811#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1812 acc30 = fma(a3, b0.s0, acc30);
1813 acc31 = fma(a3, b0.s1, acc31);
1814 acc32 = fma(a3, b0.s2, acc32);
1815 acc33 = fma(a3, b0.s3, acc33);
1816#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01001817
1818 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001819 }
1820
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001821 int z = get_global_id(2);
1822
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001823 // Compute destination address
1824 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1825
1826 // Multiply by the weight of matrix-matrix product and store the result
1827#if defined(ALPHA)
1828 acc00 = acc00 * ALPHA;
1829 acc01 = acc01 * ALPHA;
1830 acc02 = acc02 * ALPHA;
1831 acc03 = acc03 * ALPHA;
1832#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001833#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001834 acc10 = acc10 * ALPHA;
1835 acc11 = acc11 * ALPHA;
1836 acc12 = acc12 * ALPHA;
1837 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001838#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
1839#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001840 acc20 = acc20 * ALPHA;
1841 acc21 = acc21 * ALPHA;
1842 acc22 = acc22 * ALPHA;
1843 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001844#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
1845#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001846 acc30 = acc30 * ALPHA;
1847 acc31 = acc31 * ALPHA;
1848 acc32 = acc32 * ALPHA;
1849 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001850#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
1851
1852 // Compute dst address
1853 __global uchar *dst_addr = offset(&dst, 0, 0);
1854
1855#if defined(REINTERPRET_OUTPUT_AS_3D)
1856 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001857 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001858 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001859 // | |
1860 // | plane0 |
1861 // | |
1862 // |__________________|
1863 // |******************|
1864 // | cross_plane_pad |
1865 // |******************|
1866 // | |
1867 // | plane1 |
1868 // | |
1869 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001870
1871 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1872 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1873 zout = min(DEPTH_GEMM3D - 1, zout);
1874
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001875 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001876 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001877
1878 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1879 // multiply dst_stride_z by DEPTH_GEMM3D
1880 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1881
1882 // Store the output block
1883 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
1884#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1885 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
1886#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1887#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1888 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
1889#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1890#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1891 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001892#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001893
1894#else // defined(REINTERPRET_OUTPUT_AS_3D)
1895 // Add offset for batched GEMM
1896 dst_addr += z * dst_stride_z;
1897
1898 // Store the output block
1899 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
1900#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1901 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
1902#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1903#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1904 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
1905#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1906#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1907 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
1908#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1909#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001910}
1911
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A (NOT reshaped)
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
1954__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
1955 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001956 IMAGE_DECLARATION(dst),
1957 uint src0_stride_z,
1958 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001959 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001960#if defined(REINTERPRET_INPUT_AS_3D)
1961 ,
1962 uint src_cross_plane_pad
1963#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001964#if defined(REINTERPRET_OUTPUT_AS_3D)
1965 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001966 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001967#endif // REINTERPRET_OUTPUT_AS_3D
1968 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001969{
1970 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1971 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1972
1973 // Compute starting address for matrix A and Matrix B
1974 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1975
1976 // Update address for the matrix A
1977 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1978
1979 // Update address for the matrix B
1980 src_addr.s1 += idx * sizeof(float);
1981
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001982#if defined(REINTERPRET_INPUT_AS_3D)
1983 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1984 // in order to take into account the presence of possible cross plane paddings
1985 //
1986 // | |
1987 // | plane0 |
1988 // | |
1989 // |__________________|
1990 // |******************|
1991 // | cross_plane_pad |
1992 // |******************|
1993 // | |
1994 // | plane1 |
1995 // | |
1996 // |__________________|
1997
1998 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1999 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2000 zin = min(DEPTH_GEMM3D - 1, zin);
2001
2002 // Add offset due to the cross plane paddings
2003 zin *= (src_cross_plane_pad * src0_stride_y);
2004
2005 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2006 // multiply src0_stride_z by DEPTH_GEMM3D
2007 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2008
2009#else // defined(REINTERPRET_INPUT_AS_3D)
2010
Gian Marcoae2af742018-02-15 12:35:44 +00002011 // Add offset for batched GEMM
2012 src_addr.s0 += get_global_id(2) * src0_stride_z;
2013
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002014#endif // defined(REINTERPRET_INPUT_AS_3D)
2015
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002016#if defined(MATRIX_B_DEPTH)
2017 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2018 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2019#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002020 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002021#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002022
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002023 // Initialize accumulators
2024 float acc00 = 0.0f;
2025 float acc01 = 0.0f;
2026
2027#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2028 float acc10 = 0.0f;
2029 float acc11 = 0.0f;
2030#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2031#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2032 float acc20 = 0.0f;
2033 float acc21 = 0.0f;
2034#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2035#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2036 float acc30 = 0.0f;
2037 float acc31 = 0.0f;
2038#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2039
2040 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002041 int i = 0;
2042 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002043 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002044#if defined(REINTERPRET_INPUT_AS_3D)
2045 // Load values from matrix A
2046 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
2047#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002048 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002049 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002050#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002051
2052 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002053 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2054 src_addr.s1 += src1_stride_y;
2055 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2056 src_addr.s1 += src1_stride_y;
2057 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2058 src_addr.s1 += src1_stride_y;
2059 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2060 src_addr.s1 += src1_stride_y;
2061 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2062 src_addr.s1 += src1_stride_y;
2063 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2064 src_addr.s1 += src1_stride_y;
2065 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2066 src_addr.s1 += src1_stride_y;
2067 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2068 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002069
2070 // Multiply and accumulate
2071 acc00 = fma(a0.s0, b0.s0, acc00);
2072 acc00 = fma(a0.s1, b1.s0, acc00);
2073 acc00 = fma(a0.s2, b2.s0, acc00);
2074 acc00 = fma(a0.s3, b3.s0, acc00);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002075 acc00 = fma(a0.s4, b4.s0, acc00);
2076 acc00 = fma(a0.s5, b5.s0, acc00);
2077 acc00 = fma(a0.s6, b6.s0, acc00);
2078 acc00 = fma(a0.s7, b7.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002079
2080 acc01 = fma(a0.s0, b0.s1, acc01);
2081 acc01 = fma(a0.s1, b1.s1, acc01);
2082 acc01 = fma(a0.s2, b2.s1, acc01);
2083 acc01 = fma(a0.s3, b3.s1, acc01);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002084 acc01 = fma(a0.s4, b4.s1, acc01);
2085 acc01 = fma(a0.s5, b5.s1, acc01);
2086 acc01 = fma(a0.s6, b6.s1, acc01);
2087 acc01 = fma(a0.s7, b7.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002088
2089#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002090#if defined(REINTERPRET_INPUT_AS_3D)
2091 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2092#else // defined(REINTERPRET_INPUT_AS_3D)
2093 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2094#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002095 acc10 = fma(a0.s0, b0.s0, acc10);
2096 acc10 = fma(a0.s1, b1.s0, acc10);
2097 acc10 = fma(a0.s2, b2.s0, acc10);
2098 acc10 = fma(a0.s3, b3.s0, acc10);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002099 acc10 = fma(a0.s4, b4.s0, acc10);
2100 acc10 = fma(a0.s5, b5.s0, acc10);
2101 acc10 = fma(a0.s6, b6.s0, acc10);
2102 acc10 = fma(a0.s7, b7.s0, acc10);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002103
2104 acc11 = fma(a0.s0, b0.s1, acc11);
2105 acc11 = fma(a0.s1, b1.s1, acc11);
2106 acc11 = fma(a0.s2, b2.s1, acc11);
2107 acc11 = fma(a0.s3, b3.s1, acc11);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002108 acc11 = fma(a0.s4, b4.s1, acc11);
2109 acc11 = fma(a0.s5, b5.s1, acc11);
2110 acc11 = fma(a0.s6, b6.s1, acc11);
2111 acc11 = fma(a0.s7, b7.s1, acc11);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002112#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2113#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002114#if defined(REINTERPRET_INPUT_AS_3D)
2115 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2116#else // defined(REINTERPRET_INPUT_AS_3D)
2117 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2118#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002119 acc20 = fma(a0.s0, b0.s0, acc20);
2120 acc20 = fma(a0.s1, b1.s0, acc20);
2121 acc20 = fma(a0.s2, b2.s0, acc20);
2122 acc20 = fma(a0.s3, b3.s0, acc20);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002123 acc20 = fma(a0.s4, b4.s0, acc20);
2124 acc20 = fma(a0.s5, b5.s0, acc20);
2125 acc20 = fma(a0.s6, b6.s0, acc20);
2126 acc20 = fma(a0.s7, b7.s0, acc20);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002127
2128 acc21 = fma(a0.s0, b0.s1, acc21);
2129 acc21 = fma(a0.s1, b1.s1, acc21);
2130 acc21 = fma(a0.s2, b2.s1, acc21);
2131 acc21 = fma(a0.s3, b3.s1, acc21);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002132 acc21 = fma(a0.s4, b4.s1, acc21);
2133 acc21 = fma(a0.s5, b5.s1, acc21);
2134 acc21 = fma(a0.s6, b6.s1, acc21);
2135 acc21 = fma(a0.s7, b7.s1, acc21);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002136#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2137#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002138#if defined(REINTERPRET_INPUT_AS_3D)
2139 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2140#else // defined(REINTERPRET_INPUT_AS_3D)
2141 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2142#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002143 acc30 = fma(a0.s0, b0.s0, acc30);
2144 acc30 = fma(a0.s1, b1.s0, acc30);
2145 acc30 = fma(a0.s2, b2.s0, acc30);
2146 acc30 = fma(a0.s3, b3.s0, acc30);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002147 acc30 = fma(a0.s4, b4.s0, acc30);
2148 acc30 = fma(a0.s5, b5.s0, acc30);
2149 acc30 = fma(a0.s6, b6.s0, acc30);
2150 acc30 = fma(a0.s7, b7.s0, acc30);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002151
2152 acc31 = fma(a0.s0, b0.s1, acc31);
2153 acc31 = fma(a0.s1, b1.s1, acc31);
2154 acc31 = fma(a0.s2, b2.s1, acc31);
2155 acc31 = fma(a0.s3, b3.s1, acc31);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002156 acc31 = fma(a0.s4, b4.s1, acc31);
2157 acc31 = fma(a0.s5, b5.s1, acc31);
2158 acc31 = fma(a0.s6, b6.s1, acc31);
2159 acc31 = fma(a0.s7, b7.s1, acc31);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002160#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002161
2162 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002163 }
2164 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002165 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002166 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002167#if defined(REINTERPRET_INPUT_AS_3D)
2168 // Load values from matrix A
2169 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2170#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2171 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2172#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2173#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2174 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2175#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2176#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2177 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2178#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2179#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002180 // Load values from matrix A
2181 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2182#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2183 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2184#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2185#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2186 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2187#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2188#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2189 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2190#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002191#endif // defined(REINTERPRET_INPUT_AS_3D)
2192
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002193 // Load values from matrix B
2194 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002195 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002196
2197 // Multiply and accumulate
2198 acc00 = fma(a0, b0.s0, acc00);
2199 acc01 = fma(a0, b0.s1, acc01);
2200#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2201 acc10 = fma(a1, b0.s0, acc10);
2202 acc11 = fma(a1, b0.s1, acc11);
2203#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2204#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2205 acc20 = fma(a2, b0.s0, acc20);
2206 acc21 = fma(a2, b0.s1, acc21);
2207#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2208#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2209 acc30 = fma(a3, b0.s0, acc30);
2210 acc31 = fma(a3, b0.s1, acc31);
2211#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002212
2213 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002214 }
2215
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002216 // Multiply by the weight of matrix-matrix product and store the result
2217#if defined(ALPHA)
2218 acc00 = acc00 * ALPHA;
2219 acc01 = acc01 * ALPHA;
2220#endif // defined(ALPHA)
2221#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2222 acc10 = acc10 * ALPHA;
2223 acc11 = acc11 * ALPHA;
2224#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2225#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2226 acc20 = acc20 * ALPHA;
2227 acc21 = acc21 * ALPHA;
2228#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2229#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2230 acc30 = acc30 * ALPHA;
2231 acc31 = acc31 * ALPHA;
2232#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2233
2234 int z = get_global_id(2);
2235
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002236 // Compute destination address
2237 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2238
Gian Marcoae2af742018-02-15 12:35:44 +00002239 // Compute dst address
2240 __global uchar *dst_addr = offset(&dst, 0, 0);
2241
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002242#if defined(REINTERPRET_OUTPUT_AS_3D)
2243 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002244 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002245 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002246 // | |
2247 // | plane0 |
2248 // | |
2249 // |__________________|
2250 // |******************|
2251 // | cross_plane_pad |
2252 // |******************|
2253 // | |
2254 // | plane1 |
2255 // | |
2256 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00002257
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002258 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2259 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2260 zout = min(DEPTH_GEMM3D - 1, zout);
2261
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002262 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002263 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002264
2265 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2266 // multiply dst_stride_z by DEPTH_GEMM3D
2267 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2268
2269 // Store the output block
2270 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002271#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002272 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002273#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2274#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002275 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002276#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2277#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002278 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002279#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002280
2281#else // defined(REINTERPRET_OUTPUT_AS_3D)
2282 // Add offset for batched GEMM
2283 dst_addr += z * dst_stride_z;
2284
2285 // Store the output block
2286 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2287#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2288 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2289#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2290#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2291 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2292#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2293#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2294 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
2295#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2296#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002297}
2298
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01002299#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002300/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
2301 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00002302 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulating the result in a 32 floating point variable.
2303 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
2304 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
2305 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
2306 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
2307 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2308 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2309 *
2310 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2311 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2312 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2313 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2314 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2315 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2316 *
2317 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2318 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2319 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2320 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2321 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2322 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2323 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2324 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2325 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2326 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2327 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2328 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2329 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2330 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2331 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2332 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2333 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2334 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2335 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2336 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2337 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2338 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2339 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2340 */
2341__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
2342 IMAGE_DECLARATION(src1),
2343 IMAGE_DECLARATION(dst),
2344 uint src0_stride_z,
2345 uint src1_stride_z,
2346 uint dst_stride_z
2347#if defined(REINTERPRET_INPUT_AS_3D)
2348 ,
2349 uint src_cross_plane_pad
2350#endif // REINTERPRET_INPUT_AS_3D
2351#if defined(REINTERPRET_OUTPUT_AS_3D)
2352 ,
2353 uint dst_cross_plane_pad
2354#endif // REINTERPRET_OUTPUT_AS_3D
2355 )
2356{
2357 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2358
2359 // Compute starting address for matrix A and Matrix B
2360 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2361
2362 // Update address for the matrix A
2363 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2364
2365 // Update address for the matrix B
2366 src_addr.s1 += idx * sizeof(half);
2367
2368#if defined(REINTERPRET_INPUT_AS_3D)
2369 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2370 // in order to take into account the presence of possible cross plane paddings
2371 //
2372 // | |
2373 // | plane0 |
2374 // | |
2375 // |__________________|
2376 // |******************|
2377 // | cross_plane_pad |
2378 // |******************|
2379 // | |
2380 // | plane1 |
2381 // | |
2382 // |__________________|
2383
2384 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2385 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2386 zin = min(DEPTH_GEMM3D - 1, zin);
2387
2388 // Add offset due to the cross plane paddings
2389 zin *= (src_cross_plane_pad * src0_stride_y);
2390
2391 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2392 // multiply src0_stride_z by DEPTH_GEMM3D
2393 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2394
2395#else // defined(REINTERPRET_INPUT_AS_3D)
2396
2397 // Add offset for batched GEMM
2398 src_addr.s0 += get_global_id(2) * src0_stride_z;
2399
2400#endif // defined(REINTERPRET_INPUT_AS_3D)
2401
2402#if defined(MATRIX_B_DEPTH)
2403 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2404 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2405#else // defined(MATRIX_B_DEPTH)
2406 src_addr.s1 += get_global_id(2) * src1_stride_z;
2407#endif // defined(MATRIX_B_DEPTH)
2408
2409 float8 acc0 = 0.0h;
2410#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2411 float8 acc1 = 0.0h;
2412#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2413#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2414 float8 acc2 = 0.0h;
2415#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2416#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2417 float8 acc3 = 0.0h;
2418#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2419
2420 int i = 0;
2421 for(; i <= ((int)COLS_A - 4); i += 4)
2422 {
2423#if defined(REINTERPRET_INPUT_AS_3D)
2424 // Load values from matrix A
2425 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2426#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2427 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2428#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2429#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2430 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2431#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2432#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2433 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2434#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2435#else // defined(REINTERPRET_INPUT_AS_3D)
2436 // Load values from matrix A
2437 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2438#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2439 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2440#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2441#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2442 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2443#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2444#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2445 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2446#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2447#endif // defined(REINTERPRET_INPUT_AS_3D)
2448
2449 // Load values from matrix B
2450 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2451 src_addr.s1 += src1_stride_y;
2452
2453 // Accumulate
2454 acc0 = fma(b0, (float8)a0.s0, acc0);
2455#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2456 acc1 = fma(b0, (float8)a1.s0, acc1);
2457#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2458#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2459 acc2 = fma(b0, (float8)a2.s0, acc2);
2460#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2461#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2462 acc3 = fma(b0, (float8)a3.s0, acc3);
2463#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2464
2465 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2466 src_addr.s1 += src1_stride_y;
2467 acc0 = fma(b0, (float8)a0.s1, acc0);
2468#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2469 acc1 = fma(b0, (float8)a1.s1, acc1);
2470#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2471#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2472 acc2 = fma(b0, (float8)a2.s1, acc2);
2473#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2474#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2475 acc3 = fma(b0, (float8)a3.s1, acc3);
2476#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2477
2478 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2479 src_addr.s1 += src1_stride_y;
2480 acc0 = fma(b0, (float8)a0.s2, acc0);
2481#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2482 acc1 = fma(b0, (float8)a1.s2, acc1);
2483#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2484#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2485 acc2 = fma(b0, (float8)a2.s2, acc2);
2486#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2487#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2488 acc3 = fma(b0, (float8)a3.s2, acc3);
2489#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2490
2491 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2492 src_addr.s1 += src1_stride_y;
2493 acc0 = fma(b0, (float8)a0.s3, acc0);
2494#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2495 acc1 = fma(b0, (float8)a1.s3, acc1);
2496#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2497#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2498 acc2 = fma(b0, (float8)a2.s3, acc2);
2499#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2500#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2501 acc3 = fma(b0, (float8)a3.s3, acc3);
2502#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2503
2504 src_addr.s0 += 4 * sizeof(half);
2505 }
2506
2507 for(; i < (int)COLS_A; ++i)
2508 {
2509#if defined(REINTERPRET_INPUT_AS_3D)
2510 // Load values from matrix A
2511 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2512#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2513 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2514#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2515#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2516 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2517#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2518#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2519 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2520#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2521#else // defined(REINTERPRET_INPUT_AS_3D)
2522 // Load values from matrix A
2523 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2524#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2525 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2526#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2527#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2528 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2529#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2530#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2531 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2532#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2533#endif // defined(REINTERPRET_INPUT_AS_3D)
2534
2535 // Load values from matrix B
2536 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2537
2538 src_addr += (int2)(sizeof(half), src1_stride_y);
2539
2540 // Accumulate
2541 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
2542#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2543 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
2544#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2545#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2546 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
2547#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2548#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2549 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
2550#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2551 }
2552
2553 // Multiply by the weight of matrix-matrix product and store the result
2554#if defined(ALPHA)
2555 half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
2556#else //defined(ALPHA)
2557 half8 hacc0 = convert_half8(acc0);
2558#endif // defined(ALPHA)
2559#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2560#if defined(ALPHA)
2561 half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
2562#else //defined(ALPHA)
2563 half8 hacc1 = convert_half8(acc1);
2564#endif //defined(ALPHA)
2565#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y
2566
2567#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2568#if defined(ALPHA)
2569 half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
2570#else //defined(ALPHA)
2571 half8 hacc2 = convert_half8(acc2);
2572#endif //defined(ALPHA)
2573#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2574
2575#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2576#if defined(ALPHA)
2577 half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
2578#else //defined(ALPHA)
2579 half8 hacc3 = convert_half8(acc3);
2580#endif // defined(ALPHA)
2581#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2582
2583 int z = get_global_id(2);
2584
2585 // Compute destination address
2586 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2587
2588 // Compute dst address
2589 __global uchar *dst_addr = offset(&dst, 0, 0);
2590
2591#if defined(REINTERPRET_OUTPUT_AS_3D)
2592 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2593 // in order to take into account the presence of possible cross plane paddings
2594 //
2595 // | |
2596 // | plane0 |
2597 // | |
2598 // |__________________|
2599 // |******************|
2600 // | cross_plane_pad |
2601 // |******************|
2602 // | |
2603 // | plane1 |
2604 // | |
2605 // |__________________|
2606
2607 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2608 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2609 zout = min(DEPTH_GEMM3D - 1, zout);
2610
2611 // Add offset due to the cross plane paddings
2612 zout *= (dst_cross_plane_pad * dst_stride_y);
2613
2614 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2615 // multiply dst_stride_z by DEPTH_GEMM3D
2616 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2617
2618 // Store the output block
2619 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
2620#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2621 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
2622#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2623#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2624 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
2625#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2626#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2627 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
2628#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2629
2630#else // defined(REINTERPRET_OUTPUT_AS_3D)
2631 // Add offset for batched GEMM
2632 dst_addr += z * dst_stride_z;
2633
2634 // Store the output block
2635 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
2636#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2637 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
2638#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2639#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2640 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
2641#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2642#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2643 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
2644#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2645#endif // REINTERPRET_OUTPUT_AS_3D
2646}
2647
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2649 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002650 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
2651 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
2652 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
2653 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
2654 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
2655 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2656 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2657 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002658 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2659 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002660 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2661 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2662 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2663 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2664 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002665 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2666 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2667 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2668 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2669 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2670 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2671 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2672 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2673 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2674 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2675 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2676 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2677 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2678 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2679 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2680 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2681 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2682 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002683 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2684 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2685 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002686 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2687 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002688 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                 )
{
    // Each work-item accumulates a tile of up to 4 rows (NUM_ELEMS_PROCESSED_PER_THREAD_Y) by 8 half columns.
    // x_first is the first column (in elements) of matrix B read by this work-item.
    const int x_first = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Running byte offsets into the inputs: .s0 walks matrix A, .s1 walks matrix B
    int2 offs = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Matrix A: jump to the first row of the tile
    offs.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Matrix B: jump to the first column of the tile
    offs.s1 += x_first * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Matrix A is read as a 2D tile out of a 3D tensor. Consecutive planes are separated by
    // src_cross_plane_pad padding rows, so every tile row needs an extra byte offset that depends
    // on the plane it falls into:
    //
    //  |      plane0      |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |      plane1      |
    //  |__________________|

    // Lane n holds the plane index of tile row n: row index divided by HEIGHT_GEMM3D, clamped to the last plane
    uint4 a_pad = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    a_pad       = min(DEPTH_GEMM3D - 1, a_pad);

    // Turn the plane index into the byte offset introduced by the cross plane paddings
    a_pad *= (src_cross_plane_pad * src0_stride_y);

    // Batched GEMM: the batches live in the fourth dimension, hence the stride is scaled by DEPTH_GEMM3D
    offs.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Batched GEMM: select the batch of matrix A
    offs.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Wrap the batch index so that matrix B is not slid when it has 3 dimensions and matrix A more than 3
    offs.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    offs.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half8 accumulator per tile row
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of A (and therefore 4 rows of B) per iteration
    int k = 0;
    for(; k <= ((int)COLS_A - 4); k += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // 4 consecutive values from every tile row of A, shifted by the row's cross plane padding
        half4 a0 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 0 * src0_stride_y + a_pad.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 1 * src0_stride_y + a_pad.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 2 * src0_stride_y + a_pad.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 3 * src0_stride_y + a_pad.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // 4 consecutive values from every tile row of A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Step 0: row k of B times column k of the A tile
        half8 b_row = vload8(0, (__global half *)(src1_ptr + offs.s1));
        offs.s1 += src1_stride_y;

        acc0 = fma(b_row, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b_row, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b_row, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b_row, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Step 1: row k + 1 of B
        b_row = vload8(0, (__global half *)(src1_ptr + offs.s1));
        offs.s1 += src1_stride_y;
        acc0 = fma(b_row, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b_row, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b_row, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b_row, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Step 2: row k + 2 of B
        b_row = vload8(0, (__global half *)(src1_ptr + offs.s1));
        offs.s1 += src1_stride_y;
        acc0 = fma(b_row, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b_row, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b_row, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b_row, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Step 3: row k + 3 of B
        b_row = vload8(0, (__global half *)(src1_ptr + offs.s1));
        offs.s1 += src1_stride_y;
        acc0 = fma(b_row, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b_row, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b_row, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b_row, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Matrix A: advance by the 4 columns just consumed
        offs.s0 += 4 * sizeof(half);
    }

    // Leftover loop: handle the remaining (COLS_A % 4) columns one at a time
    for(; k < (int)COLS_A; ++k)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // One value from every tile row of A, shifted by the row's cross plane padding
        half a0 = *((__global half *)(src0_ptr + offs.s0 + 0 * src0_stride_y + a_pad.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + offs.s0 + 1 * src0_stride_y + a_pad.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + offs.s0 + 2 * src0_stride_y + a_pad.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + offs.s0 + 3 * src0_stride_y + a_pad.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // One value from every tile row of A
        half a0 = *((__global half *)(src0_ptr + offs.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + offs.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + offs.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + offs.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Matching row of B
        half8 b_row = vload8(0, (__global half *)(src1_ptr + offs.s1));

        // Accumulate
        acc0 = fma(b_row, (half8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b_row, (half8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b_row, (half8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b_row, (half8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Next column of A, next row of B
        offs.s0 += sizeof(half);
        offs.s1 += src1_stride_y;
    }

#if defined(ALPHA)
    // Scale the matrix-matrix product by its weight
    acc0 *= (half8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 *= (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 *= (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 *= (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    // Batch handled by this work-item
    int batch = get_global_id(2);

    // Address of the first element of the output tile
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global uchar *out_ptr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is written into a 3D tensor. Consecutive planes are separated by
    // dst_cross_plane_pad padding rows, so every tile row needs an extra byte offset that depends
    // on the plane it falls into:
    //
    //  |      plane0      |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |      plane1      |
    //  |__________________|

    // Lane n holds the plane index of tile row n: row index divided by HEIGHT_GEMM3D, clamped to the last plane
    uint4 out_pad = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    out_pad       = min(DEPTH_GEMM3D - 1, out_pad);

    // Turn the plane index into the byte offset introduced by the cross plane paddings
    out_pad *= (dst_cross_plane_pad * dst_stride_y);

    // Batched GEMM: the batches live in the fourth dimension, hence the stride is scaled by DEPTH_GEMM3D
    out_ptr += batch * dst_stride_z * DEPTH_GEMM3D;

    // Write the tile, one row of 8 halves at a time
    vstore8(acc0, 0, (__global half *)(out_ptr + 0 * dst_stride_y + out_pad.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(out_ptr + 1 * dst_stride_y + out_pad.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(out_ptr + 2 * dst_stride_y + out_pad.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(out_ptr + 3 * dst_stride_y + out_pad.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: select the output batch
    out_ptr += batch * dst_stride_z;

    // Write the tile, one row of 8 halves at a time
    vstore8(acc0, 0, (__global half *)(out_ptr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(out_ptr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(out_ptr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(out_ptr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01002979#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01002980
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002981#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002982
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002983#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition dst = dst + BETA * src, where dst holds the (alpha weighted) result of A x B and src is matrix C
 *
 * @note The value of beta needs to be passed at compile time using -DBETA
 * @note Each work-item updates 4 consecutive F32 elements
 *
 * @param[in]      src_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It is read and then overwritten. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Resolve the position of this work-item inside both tensors
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global float *c_ptr  = (__global float *)src.ptr;
    __global float *ab_ptr = (__global float *)dst.ptr;

    // Current content of the destination (A x B) and the matching values of matrix C
    const float4 ab_vals = vload4(0, ab_ptr);
    const float4 c_vals  = vload4(0, c_ptr);

    // axb + beta * c, written back over the A x B values
    const float4 result = ab_vals + (float4)BETA * c_vals;
    vstore4(result, 0, ab_ptr);
}
3024
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01003025#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition dst = dst + BETA * src, where dst holds the (alpha weighted) result of A x B and src is matrix C
 *
 * @note The value of beta needs to be passed at compile time using -DBETA
 * @note Each work-item updates 8 consecutive F16 elements
 *
 * @param[in]      src_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It is read and then overwritten. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Resolve the position of this work-item inside both tensors
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global half *c_ptr  = (__global half *)src.ptr;
    __global half *ab_ptr = (__global half *)dst.ptr;

    // Current content of the destination (A x B) and the matching values of matrix C
    const half8 ab_vals = vload8(0, ab_ptr);
    const half8 c_vals  = vload8(0, c_ptr);

    // axb + beta * c, written back over the A x B values
    const half8 result = ab_vals + (half8)BETA * c_vals;
    vstore8(result, 0, ab_ptr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01003066#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003067#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003068
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003069#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
 *
 * @note The width of A needs to be passed at compile time using -DWIDTH_VECTOR_A
 * @note The input A and matrix B must not be reshaped
 * @note Each work-item produces 4 consecutive F32 output values; row y of A is multiplied by slice y (along Z) of B
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    // First output column and row handled by this work-item
    int col = get_global_id(0) * 4;
    int row = get_global_id(1);

    // Running byte offsets: .s0 walks row `row` of A, .s1 walks slice `row` of B starting at column `col`
    int2 offs = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * row, src1_offset_first_element_in_bytes + src1_stride_z * row));
    offs.s1 += col * sizeof(float);

    // Byte offset one past the last element of the row of A
    int a_row_end = offs.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 sum = 0.0f;

    // Main loop: two elements of A and two rows of B per iteration
    while(offs.s0 <= (a_row_end - 2 * (int)sizeof(float)))
    {
        float2 a_pair = vload2(0, (__global float *)(src0_ptr + offs.s0));
        float4 b_top  = vload4(0, (__global float *)(src1_ptr + offs.s1));
        float4 b_next = vload4(0, (__global float *)(src1_ptr + offs.s1 + src1_stride_y));

        sum += b_top * (float4)a_pair.s0;
        sum += b_next * (float4)a_pair.s1;

        offs += (int2)(2 * sizeof(float), 2 * src1_stride_y);
    }

    // Leftover loop: at most one remaining element of A
    while(offs.s0 < a_row_end)
    {
        float  a_val = *((__global float *)(src0_ptr + offs.s0));
        float4 b_val = vload4(0, (__global float *)(src1_ptr + offs.s1));

        sum += b_val * (float4)a_val;

        offs += (int2)(sizeof(float), src1_stride_y);
    }

    // Write the 4 results at this work-item's position in the destination
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(sum, 0, (__global float *)(offset(&dst, 0, 0)));
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003135#endif // defined(WIDTH_VECTOR_A)
3136
/** This kernel adds the biases vector to each row of the accumulate tensor, in place.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Resolve the position of this work-item inside the accumulate tensor and the biases vector
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *acc_ptr  = (__global DATA_TYPE *)accum.ptr;
    __global DATA_TYPE *bias_ptr = (__global DATA_TYPE *)biases.ptr;

    // VECTOR_SIZE elements from each operand
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    acc_vals = VLOAD(VECTOR_SIZE)(0, acc_ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    bias_vals = VLOAD(VECTOR_SIZE)(0, bias_ptr);

    // Add the biases and write the sum back over the accumulate values
    acc_vals = bias_vals + acc_vals;
    VSTORE(VECTOR_SIZE)
    (acc_vals, 0, acc_ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)