/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "helpers.h"
25
Gian Marco Iodice368da832017-07-03 12:33:49 +010026#ifdef FIXED_POINT_POSITION
27#include "fixed_point.h"
28#endif // FIXED_POINT_POSITION
29
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030/** This OpenCL kernel computes the "vector" 1x4 transposition of input matrix
31 *
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +010032 * @param[in] src_ptr Pointer to the source matrix. Supported data types: U32/S32/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +010033 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
34 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
35 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
36 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
37 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +010038 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
40 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
41 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
42 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
43 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
44 */
Gian Marco Iodice9f89bae2017-06-22 12:09:49 +010045__kernel void gemm_transpose1x4(IMAGE_DECLARATION(src),
46 IMAGE_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +010047{
48 uint x = get_global_id(0);
49 uint y = get_global_id(1);
50
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +010051 // Compute address for Matrix B - source
Anthony Barbier6ff3b192017-09-04 18:44:23 +010052 Image src = CONVERT_TO_IMAGE_STRUCT(src);
53
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +010054 // Compute address for Matrix B transposed - destination. X and Y are swapped
Anthony Barbier6ff3b192017-09-04 18:44:23 +010055 uint dst_addr_in_bytes = y * 16 + ((x * dst_stride_y + dst_offset_first_element_in_bytes));
56
Gian Marco Iodice9f89bae2017-06-22 12:09:49 +010057 uint4 b0 = vload4(0, (__global uint *)src.ptr);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010058
Gian Marco Iodice9f89bae2017-06-22 12:09:49 +010059 vstore4(b0, 0, (__global uint *)(dst_ptr + dst_addr_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010060}
61
62/** This OpenCL kernel computes the "vector" 1x8 transposition of input matrix
63 *
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +010064 * @param[in] src_ptr Pointer to the source matrix. Supported data types: U16/S16/QS16/F16
Anthony Barbier6ff3b192017-09-04 18:44:23 +010065 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
66 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
67 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
68 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
69 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +010070 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +010071 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
72 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
73 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
74 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
75 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
76 */
Gian Marco Iodice9f89bae2017-06-22 12:09:49 +010077__kernel void gemm_transpose1x8(IMAGE_DECLARATION(src),
78 IMAGE_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +010079{
80 uint x = get_global_id(0);
81 uint y = get_global_id(1);
82
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +000083 // Compute address for Matrix B - source
Anthony Barbier6ff3b192017-09-04 18:44:23 +010084 Image src = CONVERT_TO_IMAGE_STRUCT(src);
85
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +000086 // Compute address for Matrix B transposed - destination. X and Y are swapped
Anthony Barbier6ff3b192017-09-04 18:44:23 +010087 uint dst_addr_in_bytes = y * 16 + ((x * dst_stride_y + dst_offset_first_element_in_bytes));
88
Gian Marco Iodice9f89bae2017-06-22 12:09:49 +010089 ushort8 b0 = vload8(0, (__global ushort *)src.ptr);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010090
Gian Marco Iodice9f89bae2017-06-22 12:09:49 +010091 vstore8(b0, 0, (__global ushort *)(dst_ptr + dst_addr_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +010092}
93
94/** This OpenCL kernel computes the "vector" 1x16 transposition of input matrix
95 *
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +010096 * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QS8
Anthony Barbier6ff3b192017-09-04 18:44:23 +010097 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
98 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
99 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
100 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
101 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100102 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100103 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
104 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
105 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
106 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
107 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
108 */
Gian Marco Iodice9f89bae2017-06-22 12:09:49 +0100109__kernel void gemm_transpose1x16(IMAGE_DECLARATION(src),
110 IMAGE_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100111{
112 uint x = get_global_id(0);
113 uint y = get_global_id(1);
114
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000115 // Compute address for Matrix B - source
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100116 Image src = CONVERT_TO_IMAGE_STRUCT(src);
117
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000118 // Compute address for Matrix B transposed - destination. X and Y are swapped
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100119 uint dst_addr_in_bytes = y * 16 + ((x * dst_stride_y + dst_offset_first_element_in_bytes));
120
121 uchar16 b0 = vload16(0, (__global uchar *)src.ptr);
122
123 vstore16(b0, 0, (__global uchar *)(dst_ptr + dst_addr_in_bytes));
124}
125
126/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
127 *
128 * @param[in] src_ptr Pointer to the source matrix. Supported data types: U32/S32/F32
129 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
130 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
131 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
132 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
133 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100134 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100135 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
136 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
137 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
138 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
139 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
140 */
141__kernel void gemm_interleave4x4_32bit(IMAGE_DECLARATION(src),
142 IMAGE_DECLARATION(dst))
143{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000144 // Compute source and destination addresses
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100145 Image src = CONVERT_TO_IMAGE_STRUCT(src);
146 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
147
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000148 // Load values from Matrix A
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100149 uint4 a0 = vload4(0, (__global uint *)(offset(&src, 0, 0)));
150 uint4 a1 = vload4(0, (__global uint *)(offset(&src, 0, 1)));
151 uint4 a2 = vload4(0, (__global uint *)(offset(&src, 0, 2)));
152 uint4 a3 = vload4(0, (__global uint *)(offset(&src, 0, 3)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100153
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100154 uint4 val0 = (uint4)(a0.s0, a1.s0, a2.s0, a3.s0);
155 vstore4(val0, 0, ((__global uint *)dst.ptr) + 0);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100156
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100157 val0 = (uint4)(a0.s1, a1.s1, a2.s1, a3.s1);
158 vstore4(val0, 0, ((__global uint *)dst.ptr) + 4);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100159
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100160 val0 = (uint4)(a0.s2, a1.s2, a2.s2, a3.s2);
161 vstore4(val0, 0, ((__global uint *)dst.ptr) + 8);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100162
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100163 val0 = (uint4)(a0.s3, a1.s3, a2.s3, a3.s3);
164 vstore4(val0, 0, ((__global uint *)dst.ptr) + 12);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100165}
166
167/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
168 *
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100169 * @param[in] src_ptr Pointer to the source matrix. Supported data types: U16/S16/QS16/F16
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100170 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
171 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
172 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
173 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
174 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100175 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100176 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
177 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
178 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
179 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
180 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
181 */
182__kernel void gemm_interleave4x4_16bit(IMAGE_DECLARATION(src),
183 IMAGE_DECLARATION(dst))
184{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000185 // Compute source and destination addresses
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100186 Image src = CONVERT_TO_IMAGE_STRUCT(src);
187 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
188
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000189 // Load values from Matrix A
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100190 ushort8 a0 = vload8(0, (__global ushort *)(offset(&src, 0, 0)));
191 ushort8 a1 = vload8(0, (__global ushort *)(offset(&src, 0, 1)));
192 ushort8 a2 = vload8(0, (__global ushort *)(offset(&src, 0, 2)));
193 ushort8 a3 = vload8(0, (__global ushort *)(offset(&src, 0, 3)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100194
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100195 ushort8 val0 = (ushort8)((ushort4)(a0.s0, a1.s0, a2.s0, a3.s0), (ushort4)(a0.s1, a1.s1, a2.s1, a3.s1));
196 vstore8(val0, 0, ((__global ushort *)dst.ptr) + 0);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100197
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100198 val0 = (ushort8)((ushort4)(a0.s2, a1.s2, a2.s2, a3.s2), (ushort4)(a0.s3, a1.s3, a2.s3, a3.s3));
199 vstore8(val0, 0, ((__global ushort *)dst.ptr) + 8);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100200
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100201 val0 = (ushort8)((ushort4)(a0.s4, a1.s4, a2.s4, a3.s4), (ushort4)(a0.s5, a1.s5, a2.s5, a3.s5));
202 vstore8(val0, 0, ((__global ushort *)dst.ptr) + 16);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100203
Gian Marco Iodiceb93f5de2017-07-05 15:48:39 +0100204 val0 = (ushort8)((ushort4)(a0.s6, a1.s6, a2.s6, a3.s6), (ushort4)(a0.s7, a1.s7, a2.s7, a3.s7));
205 vstore8(val0, 0, ((__global ushort *)dst.ptr) + 24);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100206}
207
208/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
209 *
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100210 * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QS8
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100211 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
212 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
213 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
214 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
215 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100216 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100217 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
218 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
219 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
220 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
221 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
222 */
223__kernel void gemm_interleave4x4_8bit(IMAGE_DECLARATION(src),
224 IMAGE_DECLARATION(dst))
225{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000226 // Compute source and destination addresses
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100227 Image src = CONVERT_TO_IMAGE_STRUCT(src);
228 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
229
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000230 // Load values from Matrix A
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100231 uchar16 a0 = vload16(0, (__global uchar *)(offset(&src, 0, 0)));
232 uchar16 a1 = vload16(0, (__global uchar *)(offset(&src, 0, 1)));
233 uchar16 a2 = vload16(0, (__global uchar *)(offset(&src, 0, 2)));
234 uchar16 a3 = vload16(0, (__global uchar *)(offset(&src, 0, 3)));
235
236 uchar16 val0 = (uchar16)((uchar4)(a0.s0, a1.s0, a2.s0, a3.s0), (uchar4)(a0.s1, a1.s1, a2.s1, a3.s1),
237 (uchar4)(a0.s2, a1.s2, a2.s2, a3.s2), (uchar4)(a0.s3, a1.s3, a2.s3, a3.s3));
238 vstore16(val0, 0, ((__global uchar *)dst.ptr) + 0);
239
240 val0 = (uchar16)((uchar4)(a0.s4, a1.s4, a2.s4, a3.s4), (uchar4)(a0.s5, a1.s5, a2.s5, a3.s5),
241 (uchar4)(a0.s6, a1.s6, a2.s6, a3.s6), (uchar4)(a0.s7, a1.s7, a2.s7, a3.s7));
242 vstore16(val0, 0, ((__global uchar *)dst.ptr) + 16);
243
244 val0 = (uchar16)((uchar4)(a0.s8, a1.s8, a2.s8, a3.s8), (uchar4)(a0.s9, a1.s9, a2.s9, a3.s9),
245 (uchar4)(a0.sA, a1.sA, a2.sA, a3.sA), (uchar4)(a0.sB, a1.sB, a2.sB, a3.sB));
246 vstore16(val0, 0, ((__global uchar *)dst.ptr) + 32);
247
248 val0 = (uchar16)((uchar4)(a0.sC, a1.sC, a2.sC, a3.sC), (uchar4)(a0.sD, a1.sD, a2.sD, a3.sD),
249 (uchar4)(a0.sE, a1.sE, a2.sE, a3.sE), (uchar4)(a0.sF, a1.sF, a2.sF, a3.sF));
250 vstore16(val0, 0, ((__global uchar *)dst.ptr) + 48);
251}
252
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000253#if defined(COLS_B)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100254/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100255 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100256 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000257 * @attention The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100258 *
259 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
260 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
261 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
262 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
263 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
264 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100265 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100266 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
267 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
268 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
269 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
270 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100271 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100272 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
273 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
274 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
275 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
276 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
277 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100278__kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0),
279 IMAGE_DECLARATION(src1),
280 IMAGE_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100281{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000282 // src_addr.s0 = address of matrix A
283 // src_addr.s1 = address of matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100284
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000285 // Compute address for matrix A and B
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100286 int2 src_addr = (int2)(get_global_id(1), get_global_id(0)) * (int2)((src0_stride_y),
287 (src1_stride_y));
288
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000289 // Add offset_first_element_in_bytes
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100290 src_addr = src_addr + ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
291
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000292 // Divide by 4 in order to get the src_addr in unit of float
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100293 src_addr = src_addr >> 2;
294
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000295 // Compute end row address for matrix B
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100296 int end_row_mtx_b = src_addr.s1 + COLS_B;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100297
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000298 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100299 float4 c00 = 0.0f;
300 float4 c10 = 0.0f;
301 float4 c20 = 0.0f;
302 float4 c30 = 0.0f;
303
304 for(; src_addr.s1 <= (end_row_mtx_b - 8); src_addr += (int2)(8, 8))
305 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000306 // Load values from matrix A (interleaved) and matrix B (transposed)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100307 float4 a0 = vload4(0, ((__global float *)src0_ptr) + src_addr.s0);
308 float4 b0 = vload4(0, ((__global float *)src1_ptr) + src_addr.s1);
309
310 c00 += (float4)a0.s0 * b0;
311 c10 += (float4)a0.s1 * b0;
312 c20 += (float4)a0.s2 * b0;
313 c30 += (float4)a0.s3 * b0;
314
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000315 // Load values from matrix A (interleaved) and matrix B (transposed)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100316 a0 = vload4(0, ((__global float *)src0_ptr) + src_addr.s0 + 4);
317 b0 = vload4(0, ((__global float *)src1_ptr) + src_addr.s1 + 4);
318
319 c00 += (float4)a0.s0 * b0;
320 c10 += (float4)a0.s1 * b0;
321 c20 += (float4)a0.s2 * b0;
322 c30 += (float4)a0.s3 * b0;
323 }
324
325 for(; src_addr.s1 < end_row_mtx_b; src_addr += (int2)(4, 4))
326 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000327 // Load values from matrix A (interleaved) and matrix B (transposed)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100328 float4 a0 = vload4(0, ((__global float *)src0_ptr) + src_addr.s0);
329 float4 b0 = vload4(0, ((__global float *)src1_ptr) + src_addr.s1);
330
331 c00 += (float4)a0.s0 * b0;
332 c10 += (float4)a0.s1 * b0;
333 c20 += (float4)a0.s2 * b0;
334 c30 += (float4)a0.s3 * b0;
335 }
336
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000337 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100338 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
339
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000340#if defined(ALPHA)
341 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100342 c00 = c00 * (float4)ALPHA;
343 c10 = c10 * (float4)ALPHA;
344 c20 = c20 * (float4)ALPHA;
345 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000346#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100347
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000348 // Store 4x4 block
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100349 vstore4(c00, 0, (__global float *)(offset(&dst, 0, 0)));
350 vstore4(c10, 0, (__global float *)(offset(&dst, 0, 1)));
351 vstore4(c20, 0, (__global float *)(offset(&dst, 0, 2)));
352 vstore4(c30, 0, (__global float *)(offset(&dst, 0, 3)));
353}
354
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000355/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100356 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100357 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000358 * @attention The number of matrix B columns and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100359 *
360 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
361 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
362 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
363 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
364 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
365 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100366 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100367 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
368 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
369 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
370 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
371 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100372 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100373 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
374 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
375 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
376 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
377 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
378 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100379__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
380 IMAGE_DECLARATION(src1),
381 IMAGE_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100382{
383 // src_addr_a = address of matrix A
384 // src_addr_b = address of matrix B
385 __global float *src_addr_a = (__global float *)(src0_ptr + get_global_id(1) * src0_stride_y + src0_offset_first_element_in_bytes);
386 __global float *src_addr_b = (__global float *)(src1_ptr + get_global_id(0) * src1_stride_y + src1_offset_first_element_in_bytes);
387
388 // Compute end row address for matrix B
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100389 __global float *src_end_addr_b = src_addr_b + COLS_B;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100390
391 // Reset accumulators
392 float c00 = 0.0f;
393 float c01 = 0.0f;
394 float c02 = 0.0f;
395 float c03 = 0.0f;
396 float c10 = 0.0f;
397 float c11 = 0.0f;
398 float c12 = 0.0f;
399 float c13 = 0.0f;
400 float c20 = 0.0f;
401 float c21 = 0.0f;
402 float c22 = 0.0f;
403 float c23 = 0.0f;
404 float c30 = 0.0f;
405 float c31 = 0.0f;
406 float c32 = 0.0f;
407 float c33 = 0.0f;
408
409 for(; src_addr_b <= (src_end_addr_b - 16); src_addr_a += 16, src_addr_b += 16)
410 {
411 // Load values from matrix A (interleaved) and matrix B (transposed)
412 float4 a0 = vload4(0, src_addr_a);
413 float4 b0 = vload4(0, src_addr_b);
414
415 c00 = fma(a0.s0, b0.s0, c00);
416 c01 = fma(a0.s0, b0.s1, c01);
417 c02 = fma(a0.s0, b0.s2, c02);
418 c03 = fma(a0.s0, b0.s3, c03);
419
420 c10 = fma(a0.s1, b0.s0, c10);
421 c11 = fma(a0.s1, b0.s1, c11);
422 c12 = fma(a0.s1, b0.s2, c12);
423 c13 = fma(a0.s1, b0.s3, c13);
424
425 c20 = fma(a0.s2, b0.s0, c20);
426 c21 = fma(a0.s2, b0.s1, c21);
427 c22 = fma(a0.s2, b0.s2, c22);
428 c23 = fma(a0.s2, b0.s3, c23);
429
430 c30 = fma(a0.s3, b0.s0, c30);
431 c31 = fma(a0.s3, b0.s1, c31);
432 c32 = fma(a0.s3, b0.s2, c32);
433 c33 = fma(a0.s3, b0.s3, c33);
434
435 // Load values from matrix A (interleaved) and matrix B (transposed)
436 a0 = vload4(0, src_addr_a + 4);
437 b0 = vload4(0, src_addr_b + 4);
438
439 c00 = fma(a0.s0, b0.s0, c00);
440 c01 = fma(a0.s0, b0.s1, c01);
441 c02 = fma(a0.s0, b0.s2, c02);
442 c03 = fma(a0.s0, b0.s3, c03);
443
444 c10 = fma(a0.s1, b0.s0, c10);
445 c11 = fma(a0.s1, b0.s1, c11);
446 c12 = fma(a0.s1, b0.s2, c12);
447 c13 = fma(a0.s1, b0.s3, c13);
448
449 c20 = fma(a0.s2, b0.s0, c20);
450 c21 = fma(a0.s2, b0.s1, c21);
451 c22 = fma(a0.s2, b0.s2, c22);
452 c23 = fma(a0.s2, b0.s3, c23);
453
454 c30 = fma(a0.s3, b0.s0, c30);
455 c31 = fma(a0.s3, b0.s1, c31);
456 c32 = fma(a0.s3, b0.s2, c32);
457 c33 = fma(a0.s3, b0.s3, c33);
458
459 // Load values from matrix A (interleaved) and matrix B (transposed)
460 a0 = vload4(0, src_addr_a + 8);
461 b0 = vload4(0, src_addr_b + 8);
462
463 c00 = fma(a0.s0, b0.s0, c00);
464 c01 = fma(a0.s0, b0.s1, c01);
465 c02 = fma(a0.s0, b0.s2, c02);
466 c03 = fma(a0.s0, b0.s3, c03);
467
468 c10 = fma(a0.s1, b0.s0, c10);
469 c11 = fma(a0.s1, b0.s1, c11);
470 c12 = fma(a0.s1, b0.s2, c12);
471 c13 = fma(a0.s1, b0.s3, c13);
472
473 c20 = fma(a0.s2, b0.s0, c20);
474 c21 = fma(a0.s2, b0.s1, c21);
475 c22 = fma(a0.s2, b0.s2, c22);
476 c23 = fma(a0.s2, b0.s3, c23);
477
478 c30 = fma(a0.s3, b0.s0, c30);
479 c31 = fma(a0.s3, b0.s1, c31);
480 c32 = fma(a0.s3, b0.s2, c32);
481 c33 = fma(a0.s3, b0.s3, c33);
482
483 // Load values from matrix A (interleaved) and matrix B (transposed)
484 a0 = vload4(0, src_addr_a + 12);
485 b0 = vload4(0, src_addr_b + 12);
486
487 c00 = fma(a0.s0, b0.s0, c00);
488 c01 = fma(a0.s0, b0.s1, c01);
489 c02 = fma(a0.s0, b0.s2, c02);
490 c03 = fma(a0.s0, b0.s3, c03);
491
492 c10 = fma(a0.s1, b0.s0, c10);
493 c11 = fma(a0.s1, b0.s1, c11);
494 c12 = fma(a0.s1, b0.s2, c12);
495 c13 = fma(a0.s1, b0.s3, c13);
496
497 c20 = fma(a0.s2, b0.s0, c20);
498 c21 = fma(a0.s2, b0.s1, c21);
499 c22 = fma(a0.s2, b0.s2, c22);
500 c23 = fma(a0.s2, b0.s3, c23);
501
502 c30 = fma(a0.s3, b0.s0, c30);
503 c31 = fma(a0.s3, b0.s1, c31);
504 c32 = fma(a0.s3, b0.s2, c32);
505 c33 = fma(a0.s3, b0.s3, c33);
506 }
507
508 for(; src_addr_b < src_end_addr_b; src_addr_a += 4, src_addr_b += 4)
509 {
510 // Load values from matrix A (interleaved) and matrix B (transposed)
511 float4 a0 = vload4(0, src_addr_a);
512 float4 b0 = vload4(0, src_addr_b);
513
514 c00 = fma(a0.s0, b0.s0, c00);
515 c01 = fma(a0.s0, b0.s1, c01);
516 c02 = fma(a0.s0, b0.s2, c02);
517 c03 = fma(a0.s0, b0.s3, c03);
518
519 c10 = fma(a0.s1, b0.s0, c10);
520 c11 = fma(a0.s1, b0.s1, c11);
521 c12 = fma(a0.s1, b0.s2, c12);
522 c13 = fma(a0.s1, b0.s3, c13);
523
524 c20 = fma(a0.s2, b0.s0, c20);
525 c21 = fma(a0.s2, b0.s1, c21);
526 c22 = fma(a0.s2, b0.s2, c22);
527 c23 = fma(a0.s2, b0.s3, c23);
528
529 c30 = fma(a0.s3, b0.s0, c30);
530 c31 = fma(a0.s3, b0.s1, c31);
531 c32 = fma(a0.s3, b0.s2, c32);
532 c33 = fma(a0.s3, b0.s3, c33);
533 }
534
535 // Compute destination address
536 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
537
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000538#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100539 // Multiply by the weight of matrix product
540 c00 = c00 * ALPHA;
541 c01 = c01 * ALPHA;
542 c02 = c02 * ALPHA;
543 c03 = c03 * ALPHA;
544 c10 = c10 * ALPHA;
545 c11 = c11 * ALPHA;
546 c12 = c12 * ALPHA;
547 c13 = c13 * ALPHA;
548 c20 = c20 * ALPHA;
549 c21 = c21 * ALPHA;
550 c22 = c22 * ALPHA;
551 c23 = c23 * ALPHA;
552 c30 = c30 * ALPHA;
553 c31 = c31 * ALPHA;
554 c32 = c32 * ALPHA;
555 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000556#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100557
558 barrier(CLK_GLOBAL_MEM_FENCE);
559
560 // Store 4x4 block
561 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(offset(&dst, 0, 0)));
562 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(offset(&dst, 0, 1)));
563 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(offset(&dst, 0, 2)));
564 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(offset(&dst, 0, 3)));
565}
566
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel multiplies matrix A (src0) by matrix B (src1) in half precision. Each work-item accumulates and stores a block of 4 rows by 8 columns of the destination matrix
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * @attention The number of matrix B columns must be passed at compile time using -DCOLS_B. The value of alpha is optional and is passed at compile time using -DALPHA
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst))
{
    // Views of the two input buffers in units of half
    __global half *mtx_a = (__global half *)src0_ptr;
    __global half *mtx_b = (__global half *)src1_ptr;

    // pos.s0 = position in matrix A, pos.s1 = position in matrix B.
    // The global id along Y selects the row of A, the global id along X selects the row of B;
    // the offset of the first element is then added (everything still in bytes)
    int2 pos = (int2)(get_global_id(1), get_global_id(0)) * (int2)(src0_stride_y, src1_stride_y)
               + (int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes);

    // Convert the positions from bytes to half elements
    pos >>= 1;

    // Position in matrix B at which the accumulation stops
    const int last_b = pos.s1 + COLS_B;

    // One accumulator of 8 columns for each of the 4 output rows
    half8 acc_row0 = 0.0f;
    half8 acc_row1 = 0.0f;
    half8 acc_row2 = 0.0f;
    half8 acc_row3 = 0.0f;

    // Main loop: two accumulation steps (8 values of A, 16 values of B) per iteration
    while(pos.s1 <= (last_b - 16))
    {
        // First step
        half4 a = vload4(0, mtx_a + pos.s0);
        half8 b = vload8(0, mtx_b + pos.s1);

        acc_row0 += (half8)a.s0 * b;
        acc_row1 += (half8)a.s1 * b;
        acc_row2 += (half8)a.s2 * b;
        acc_row3 += (half8)a.s3 * b;

        // Second step
        a = vload4(0, mtx_a + pos.s0 + 4);
        b = vload8(0, mtx_b + pos.s1 + 8);

        acc_row0 += (half8)a.s0 * b;
        acc_row1 += (half8)a.s1 * b;
        acc_row2 += (half8)a.s2 * b;
        acc_row3 += (half8)a.s3 * b;

        pos += (int2)(8, 16);
    }

    // Left-over loop: one accumulation step (4 values of A, 8 values of B) per iteration
    while(pos.s1 < last_b)
    {
        const half4 a = vload4(0, mtx_a + pos.s0);
        const half8 b = vload8(0, mtx_b + pos.s1);

        acc_row0 += (half8)a.s0 * b;
        acc_row1 += (half8)a.s1 * b;
        acc_row2 += (half8)a.s2 * b;
        acc_row3 += (half8)a.s3 * b;

        pos += (int2)(4, 8);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    acc_row0 *= (half8)ALPHA;
    acc_row1 *= (half8)ALPHA;
    acc_row2 *= (half8)ALPHA;
    acc_row3 *= (half8)ALPHA;
#endif // defined(ALPHA)

    // Store 4x8 block
    vstore8(acc_row0, 0, (__global half *)(offset(&dst, 0, 0)));
    vstore8(acc_row1, 0, (__global half *)(offset(&dst, 0, 1)));
    vstore8(acc_row2, 0, (__global half *)(offset(&dst, 0, 2)));
    vstore8(acc_row3, 0, (__global half *)(offset(&dst, 0, 3)));
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100669
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000670#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel multiplies matrix A (src0) by matrix B (src1) in 8 bit fixed point precision. Each work-item accumulates and stores a block of 4 rows by 16 columns of the destination matrix
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_8bit and @ref gemm_transpose1x16 before running the matrix multiplication
 *
 * @attention The number of matrix B columns and the fixed point position must be passed at compile time using -DCOLS_B and -DFIXED_POINT_POSITION. The value of alpha is optional and is passed at compile time using -DALPHA
 *
 * @note: ALPHA must be passed in 8 bit fixed point format
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst))
{
    // Views of the two input buffers in units of char (1 byte, so byte offsets index them directly)
    __global char *mtx_a = (__global char *)src0_ptr;
    __global char *mtx_b = (__global char *)src1_ptr;

    // pos.s0 = position in matrix A, pos.s1 = position in matrix B.
    // The global id along Y selects the row of A, the global id along X selects the row of B;
    // the offset of the first element is then added
    int2 pos = (int2)(get_global_id(1), get_global_id(0)) * (int2)(src0_stride_y, src1_stride_y)
               + (int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes);

    // Position in matrix B at which the accumulation stops
    const int last_b = pos.s1 + COLS_B;

    // 16 bit accumulators: accN_lo collects columns 0-7 and accN_hi columns 8-15 of output row N
    short8 acc0_lo = 0;
    short8 acc1_lo = 0;
    short8 acc2_lo = 0;
    short8 acc3_lo = 0;
    short8 acc0_hi = 0;
    short8 acc1_hi = 0;
    short8 acc2_hi = 0;
    short8 acc3_hi = 0;

    // One accumulation step (4 values of A, 16 values of B) per iteration
    while(pos.s1 <= (last_b - 16))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        const char4  a = vload4(0, mtx_a + pos.s0);
        const char16 b = vload16(0, mtx_b + pos.s1);

        acc0_lo = mlal_sat_qs8x8(acc0_lo, (char8)a.s0, b.lo, FIXED_POINT_POSITION);
        acc1_lo = mlal_sat_qs8x8(acc1_lo, (char8)a.s1, b.lo, FIXED_POINT_POSITION);
        acc2_lo = mlal_sat_qs8x8(acc2_lo, (char8)a.s2, b.lo, FIXED_POINT_POSITION);
        acc3_lo = mlal_sat_qs8x8(acc3_lo, (char8)a.s3, b.lo, FIXED_POINT_POSITION);

        acc0_hi = mlal_sat_qs8x8(acc0_hi, (char8)a.s0, b.hi, FIXED_POINT_POSITION);
        acc1_hi = mlal_sat_qs8x8(acc1_hi, (char8)a.s1, b.hi, FIXED_POINT_POSITION);
        acc2_hi = mlal_sat_qs8x8(acc2_hi, (char8)a.s2, b.hi, FIXED_POINT_POSITION);
        acc3_hi = mlal_sat_qs8x8(acc3_hi, (char8)a.s3, b.hi, FIXED_POINT_POSITION);

        pos += (int2)(4, 16);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Join the two halves of each row and narrow them to 8 bit with saturation
    char16 row0_qs8 = convert_char16_sat((short16)(acc0_lo, acc0_hi));
    char16 row1_qs8 = convert_char16_sat((short16)(acc1_lo, acc1_hi));
    char16 row2_qs8 = convert_char16_sat((short16)(acc2_lo, acc2_hi));
    char16 row3_qs8 = convert_char16_sat((short16)(acc3_lo, acc3_hi));

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    row0_qs8 = mul_sat_qs8x16(row0_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    row1_qs8 = mul_sat_qs8x16(row1_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    row2_qs8 = mul_sat_qs8x16(row2_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    row3_qs8 = mul_sat_qs8x16(row3_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Store 16x4 block
    vstore16(row0_qs8, 0, (__global char *)(offset(&dst, 0, 0)));
    vstore16(row1_qs8, 0, (__global char *)(offset(&dst, 0, 1)));
    vstore16(row2_qs8, 0, (__global char *)(offset(&dst, 0, 2)));
    vstore16(row3_qs8, 0, (__global char *)(offset(&dst, 0, 3)));
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100764
/** This OpenCL kernel multiplies matrix A (src0) by matrix B (src1) in 16 bit fixed point precision. Each work-item accumulates and stores a block of 4 rows by 8 columns of the destination matrix
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * @attention The number of matrix B columns and the fixed point position must be passed at compile time using -DCOLS_B and -DFIXED_POINT_POSITION. The value of alpha is optional and is passed at compile time using -DALPHA
 *
 * @note: ALPHA must be passed in 16 bit fixed point format
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst))
{
    // Views of the two input buffers in units of short
    __global short *mtx_a = (__global short *)src0_ptr;
    __global short *mtx_b = (__global short *)src1_ptr;

    // pos.s0 = position in matrix A, pos.s1 = position in matrix B.
    // The global id along Y selects the row of A, the global id along X selects the row of B;
    // the offset of the first element is then added (everything still in bytes)
    int2 pos = (int2)(get_global_id(1), get_global_id(0)) * (int2)(src0_stride_y, src1_stride_y)
               + (int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes);

    // Convert the positions from bytes to short elements
    pos >>= 1;

    // Position in matrix B at which the accumulation stops
    const int last_b = pos.s1 + COLS_B;

    // One 32 bit accumulator of 8 columns for each of the 4 output rows
    int8 acc_row0 = 0;
    int8 acc_row1 = 0;
    int8 acc_row2 = 0;
    int8 acc_row3 = 0;

    // One accumulation step (4 values of A, 8 values of B) per iteration
    while(pos.s1 <= (last_b - 8))
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        const short4 a = vload4(0, mtx_a + pos.s0);
        const short8 b = vload8(0, mtx_b + pos.s1);

        acc_row0 = mlal_sat_qs16x8(acc_row0, (short8)a.s0, b, FIXED_POINT_POSITION);
        acc_row1 = mlal_sat_qs16x8(acc_row1, (short8)a.s1, b, FIXED_POINT_POSITION);
        acc_row2 = mlal_sat_qs16x8(acc_row2, (short8)a.s2, b, FIXED_POINT_POSITION);
        acc_row3 = mlal_sat_qs16x8(acc_row3, (short8)a.s3, b, FIXED_POINT_POSITION);

        pos += (int2)(4, 8);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Narrow the accumulators to 16 bit with saturation
    short8 row0_qs16 = convert_short8_sat(acc_row0);
    short8 row1_qs16 = convert_short8_sat(acc_row1);
    short8 row2_qs16 = convert_short8_sat(acc_row2);
    short8 row3_qs16 = convert_short8_sat(acc_row3);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    row0_qs16 = mul_sat_qs16x8(row0_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    row1_qs16 = mul_sat_qs16x8(row1_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    row2_qs16 = mul_sat_qs16x8(row2_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    row3_qs16 = mul_sat_qs16x8(row3_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Store 8x4 block
    vstore8(row0_qs16, 0, (__global short *)(offset(&dst, 0, 0)));
    vstore8(row1_qs16, 0, (__global short *)(offset(&dst, 0, 1)));
    vstore8(row2_qs16, 0, (__global short *)(offset(&dst, 0, 2)));
    vstore8(row3_qs16, 0, (__global short *)(offset(&dst, 0, 3)));
}
852#endif // defined(FIXED_POINT_POSITION)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000853#endif // defined(COLS_B)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100854
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 *       This kernel provides accumulators for at most 4 rows, so values of -DNUM_ELEMS_PROCESSED_PER_THREAD_Y greater than 4 are not handled
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
                                     IMAGE_DECLARATION(dst))
{
    // Index of the first matrix B column handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // src_addr.s0 = byte offset in matrix A, src_addr.s1 = byte offset in matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A: move to the first of the rows processed by this work-item
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B: move to the first column processed by this work-item
    src_addr.s1 += idx * sizeof(DATA_TYPE);

    // Byte offset just past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator per output row; rows beyond the first are compiled in on demand
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consumes 2 columns of matrix A and 2 rows of matrix B per iteration.
    // sizeof is cast to int so that the bound is evaluated with signed arithmetic
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2) a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2) a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2) a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2) a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B (two consecutive rows)
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: consumes 1 column of matrix A and 1 row of matrix B per iteration
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)(acc0, 0, (__global DATA_TYPE *)(offset(&dst, 0, 0)));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)(acc1, 0, (__global DATA_TYPE *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)(acc2, 0, (__global DATA_TYPE *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)(acc3, 0, (__global DATA_TYPE *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001011
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Every work-item loads 4 floats per row of matrix B and stores 4 floats per row of dst (vload4/vstore4), so the kernel is meant to be built with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
 *       -DNUM_ELEMS_PROCESSED_PER_THREAD_Y selects how many rows (1 to 4) each work-item computes.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst))
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and matrix B
    // (.s0 is the byte offset into src0, .s1 is the byte offset into src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for matrix B
    src_addr.s1 += idx * sizeof(float);

    // Address boundary for matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));

    // Initialize accumulators (accRC: R = row handled by this work-item, C = column)
    float acc00 = 0.0f;
    float acc01 = 0.0f;
    float acc02 = 0.0f;
    float acc03 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
    float acc12 = 0.0f;
    float acc13 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
    float acc22 = 0.0f;
    float acc23 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
    float acc32 = 0.0f;
    float acc33 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: two columns of A and two rows of B per iteration.
    // A and B src indices get incremented at the same time.
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
    {
        // Load values from matrix A
        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float2 a1 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float2 a2 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float2 a3 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));

        // Multiply and accumulate.
        // Every accumulator adds the product for column k of A before the one for column k + 1,
        // so all outputs are summed in the same order as in the left-over loop below.
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc02 = fma(a0.s0, b0.s2, acc02);
        acc02 = fma(a0.s1, b1.s2, acc02);
        acc03 = fma(a0.s0, b0.s3, acc03);
        acc03 = fma(a0.s1, b1.s3, acc03);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1.s0, b0.s0, acc10);
        acc11 = fma(a1.s0, b0.s1, acc11);
        acc12 = fma(a1.s0, b0.s2, acc12);
        acc13 = fma(a1.s0, b0.s3, acc13);

        acc10 = fma(a1.s1, b1.s0, acc10);
        acc11 = fma(a1.s1, b1.s1, acc11);
        acc12 = fma(a1.s1, b1.s2, acc12);
        acc13 = fma(a1.s1, b1.s3, acc13);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2.s0, b0.s0, acc20);
        acc21 = fma(a2.s0, b0.s1, acc21);
        acc22 = fma(a2.s0, b0.s2, acc22);
        acc23 = fma(a2.s0, b0.s3, acc23);

        acc20 = fma(a2.s1, b1.s0, acc20);
        acc21 = fma(a2.s1, b1.s1, acc21);
        acc22 = fma(a2.s1, b1.s2, acc22);
        acc23 = fma(a2.s1, b1.s3, acc23);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3.s0, b0.s0, acc30);
        acc31 = fma(a3.s0, b0.s1, acc31);
        acc32 = fma(a3.s0, b0.s2, acc32);
        acc33 = fma(a3.s0, b0.s3, acc33);

        acc30 = fma(a3.s1, b1.s0, acc30);
        acc31 = fma(a3.s1, b1.s1, acc31);
        acc32 = fma(a3.s1, b1.s2, acc32);
        acc33 = fma(a3.s1, b1.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: one column of A and one row of B per iteration
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
        acc02 = fma(a0, b0.s2, acc02);
        acc03 = fma(a0, b0.s3, acc03);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
        acc12 = fma(a1, b0.s2, acc12);
        acc13 = fma(a1, b0.s3, acc13);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
        acc22 = fma(a2, b0.s2, acc22);
        acc23 = fma(a2, b0.s3, acc23);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
        acc32 = fma(a3, b0.s2, acc32);
        acc33 = fma(a3, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
    acc02 = acc02 * ALPHA;
    acc03 = acc03 * ALPHA;
#endif // defined(ALPHA)

    float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
    vstore4(acc0, 0, (__global float *)(offset(&dst, 0, 0)));

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
    acc12 = acc12 * ALPHA;
    acc13 = acc13 * ALPHA;
#endif // defined(ALPHA)
    float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
    vstore4(acc1, 0, (__global float *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
    acc22 = acc22 * ALPHA;
    acc23 = acc23 * ALPHA;
#endif // defined(ALPHA)
    float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
    vstore4(acc2, 0, (__global float *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
    acc32 = acc32 * ALPHA;
    acc33 = acc33 * ALPHA;
#endif // defined(ALPHA)
    float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
    vstore4(acc3, 0, (__global float *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
1233
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Every work-item loads 2 floats per row of matrix B and stores 2 floats per row of dst (vload2/vstore2), so the kernel is meant to be built with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
 *       -DNUM_ELEMS_PROCESSED_PER_THREAD_Y selects how many rows (1 to 4) each work-item computes.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
                                                      IMAGE_DECLARATION(dst))
{
    // Each work-item produces 2 columns of dst (vstore2). The main loop reads each row of A as a float4
    // and four consecutive rows of B as float2.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (.s0 is the byte offset into src0, .s1 is the byte offset into src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

    // Address boundary for the matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));

    // Initialize accumulators (accRC: R = row handled by this work-item, C = column)
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: four columns of A and four rows of B per iteration.
    // A and B src indices get incremented at the same time.
    for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y))
    {
        // Load values from matrix A
        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y));
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y));

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);

        // For the additional rows, a0 is reloaded with the next row of A and reused
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: one column of A (one float) and one row of B per iteration
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#endif // defined(ALPHA)
    float2 acc0 = ((float2)(acc00, acc01));
    vstore2(acc0, 0, (__global float *)(offset(&dst, 0, 0)));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // defined(ALPHA)
    float2 acc1 = ((float2)(acc10, acc11));
    vstore2(acc1, 0, (__global float *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // defined(ALPHA)
    float2 acc2 = ((float2)(acc20, acc21));
    vstore2(acc2, 0, (__global float *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // defined(ALPHA)
    float2 acc3 = ((float2)(acc30, acc31));
    vstore2(acc3, 0, (__global float *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
1427
1428#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with fixed point data types QS8
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 *       Every work-item loads 16 elements per row of matrix B and stores 16 elements per row of dst (vload16/vstore16), so the kernel is meant to be built with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=16.
 *       -DNUM_ELEMS_PROCESSED_PER_THREAD_Y selects how many rows (1 to 4) each work-item computes.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A
 * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
 * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
                          IMAGE_DECLARATION(src1),
                          IMAGE_DECLARATION(dst))
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (.s0 is the byte offset into src0, .s1 is the byte offset into src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(char);

    // Address boundary for the matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));

    // 16-bit accumulators: accR0 holds output columns 0-7 of row R, accR1 holds columns 8-15
    short8 acc00 = 0;
    short8 acc01 = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    short8 acc10 = 0;
    short8 acc11 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    short8 acc20 = 0;
    short8 acc21 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    short8 acc30 = 0;
    short8 acc31 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: two columns of A and two rows of B per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
    {
        char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
        char16 b1 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));

        acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
        acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
        acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
        acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
        acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
        acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
        acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
        acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
        acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over accumulations: one column of A and one row of B per iteration
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
    {
        char a0 = *((__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        char a1 = *((__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        char a2 = *((__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        char a3 = *((__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1));

        acc00 = mlal_sat_qs8x8(acc00, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
        acc01 = mlal_sat_qs8x8(acc01, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = mlal_sat_qs8x8(acc10, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
        acc11 = mlal_sat_qs8x8(acc11, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = mlal_sat_qs8x8(acc20, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
        acc21 = mlal_sat_qs8x8(acc21, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = mlal_sat_qs8x8(acc30, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
        acc31 = mlal_sat_qs8x8(acc31, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Narrow each row to 8 bit with saturation, multiply by the weight of matrix product (if any) and store the result
    char16 acc_qs8;
    acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
#if defined(ALPHA)
    acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 0)));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
#if defined(ALPHA)
    acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
#if defined(ALPHA)
    acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
#if defined(ALPHA)
    acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001591
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with fixed point data types QS16
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns needs to be passed at compile time using -DCOLS_A
 * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
 * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA. If ALPHA is not defined the product is stored unscaled
 * @note The body always loads/stores 8 consecutive QS16 values along X (vload8/vstore8) and handles at most 4 rows along Y,
 *       so NUM_ELEMS_PROCESSED_PER_THREAD_X is presumably expected to be 8 and NUM_ELEMS_PROCESSED_PER_THREAD_Y in [1, 4] - TODO confirm against the host-side configuration
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
                           IMAGE_DECLARATION(src1),
                           IMAGE_DECLARATION(dst))
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (.s0 tracks the byte offset into A, .s1 the byte offset into B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(short);

    // Byte offset one past the last element of the first A row handled by this work-item
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));

    // One 32-bit accumulator vector per processed row of A
    int8 acc0 = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    int8 acc1 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    int8 acc2 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    int8 acc3 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: each iteration consumes 2 elements of every processed A row and 2 rows of B,
    // i.e. it performs 2 accumulations per processed row.
    // sizeof is cast to int so that the loop bound is evaluated with signed arithmetic.
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(short)); src_addr += (int2)(2 * sizeof(short), 2 * src1_stride_y))
    {
        short2 a0 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        short2 a1 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        short2 a2 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short2 a3 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
        short8 b1 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));

        // Broadcast each A element across the 8 lanes and accumulate its product with the matching B row
        acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
        acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over accumulations (one A element / one B row per iteration; runs when COLS_A is odd)
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(short), src1_stride_y))
    {
        short a0 = *((__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        short a1 = *((__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        short a2 = *((__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short a3 = *((__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1));

        acc0 = mlal_sat_qs16x8(acc0, (short8)a0, b0, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Narrow each 32-bit accumulator to QS16 with saturation, optionally scale it by ALPHA,
    // and store one row of 8 results per processed row of A
    short8 acc_qs16;
    acc_qs16 = convert_short8_sat(acc0);
#if defined(ALPHA)
    acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 0)));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc_qs16 = convert_short8_sat(acc1);
#if defined(ALPHA)
    acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc_qs16 = convert_short8_sat(acc2);
#if defined(ALPHA)
    acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc_qs16 = convert_short8_sat(acc3);
#if defined(ALPHA)
    acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001738#endif // defined(FIXED_POINT_POSITION)
1739#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001740
#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * dst = dst + BETA * src, evaluated on 4 consecutive F32 values per work-item
 *
 * @attention The beta's value need to be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It holds the A x B values on entry and is overwritten with the result. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B
    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);

    // Load values from Matrix C
    float4 c = vload4(0, (__global float *)src.ptr);

    // Computes alpha * axb + beta * c
    float4 out = alpha_ab + (float4)BETA * c;

    // Store final result in axb matrix
    vstore4(out, 0, (__global float *)dst.ptr);
}

/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * dst = dst + BETA * src, evaluated on 8 consecutive F16 values per work-item
 *
 * @attention The beta's value need to be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It holds the A x B values on entry and is overwritten with the result. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Computes alpha * axb + beta * c
    half8 out = alpha_ab + (half8)BETA * c;

    // Store final result in axb matrix
    vstore8(out, 0, (__global half *)dst.ptr);
}

#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 8 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * @attention The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 *
 * @note: BETA must be passed in 8 bit fixed point format
 * @note: Each work-item updates 16 consecutive QS8 values (vload16/vstore16)
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: QS8
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It holds the A x B values on entry and is overwritten with the result. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B
    char16 alpha_ab = vload16(0, (__global char *)dst.ptr);

    // Load values from Matrix C
    char16 c = vload16(0, (__global char *)src.ptr);

    // Computes alpha * axb + beta * c
    // (mla_sat_qs8x16 comes from fixed_point.h; presumably a saturating fixed point multiply-accumulate - see that header)
    char16 out = mla_sat_qs8x16(alpha_ab, (char16)BETA, c, FIXED_POINT_POSITION);

    // Store final result in axb matrix
    vstore16(out, 0, (__global char *)dst.ptr);
}

/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 16 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * @attention The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 *
 * @note: BETA must be passed in 16 bit fixed point format
 * @note: Each work-item updates 8 consecutive QS16 values (vload8/vstore8)
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: QS16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix. It holds the A x B values on entry and is overwritten with the result. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
                           IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B
    short8 alpha_ab = vload8(0, (__global short *)dst.ptr);

    // Load values from Matrix C
    short8 c = vload8(0, (__global short *)src.ptr);

    // Computes alpha * axb + beta * c
    // (mla_sat_qs16x8 comes from fixed_point.h; presumably a saturating fixed point multiply-accumulate - see that header)
    short8 out = mla_sat_qs16x8(alpha_ab, (short8)BETA, c, FIXED_POINT_POSITION);

    // Store final result in axb matrix
    vstore8(out, 0, (__global short *)dst.ptr);
}
#endif // defined(FIXED_POINT_POSITION)
#endif // defined(BETA)

#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
 *
 * @attention The width of A need to be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @attention The input A and matrix B must not be reshaped
 *
 * @note Row idy of A (idy = get_global_id(1)) is multiplied by the matrix found at offset idy * src1_stride_z of src1.
 *       Each work-item produces 4 consecutive F32 output values of that row.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    int idx = get_global_id(0) * 4;
    int idy = get_global_id(1);

    // Compute the address for the vector A and matrix B
    // (.s0: byte offset of row idy in A, .s1: byte offset of the idy-th Z slice of B, shifted to column idx)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
    src_addr.s1 += idx * sizeof(float);

    // Byte offset one past the last element of the current row of A
    int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: 2 elements of A and 2 rows of B per iteration.
    // sizeof is cast to int so that the loop bound is evaluated with signed arithmetic.
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
    {
        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));

        acc += b0 * (float4)a0.s0;
        acc += b1 * (float4)a0.s1;
    }

    // Left-over accumulation (runs when WIDTH_VECTOR_A is odd)
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        float  a0 = *((__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));

        acc += b0 * (float4)a0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)

/** This kernel accumulates each row with the biases vector.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 * @note If -DFIXED_POINT_POSITION is defined the sum is computed with ADD_SAT_OP_EXPAND instead of the plain + operator.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/QS8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    // Load VECTOR_SIZE elements from the accumulate tensor and from the biases vector
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
#ifdef FIXED_POINT_POSITION
    accum_value = ADD_SAT_OP_EXPAND(biases_value, accum_value, DATA_TYPE, VECTOR_SIZE);
#else  // FIXED_POINT_POSITION
    accum_value = biases_value + accum_value;
#endif // FIXED_POINT_POSITION
    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)