blob: bad09f3c427930c17a4502e0ac8c6d9c63d54133 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco Iodice368da832017-07-03 12:33:49 +010026#ifdef FIXED_POINT_POSITION
27#include "fixed_point.h"
28#endif // FIXED_POINT_POSITION
29
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// Pick the unsigned integer type that pairs with the requested vector width
// (4 -> uint, 8 -> ushort, 16 -> uchar). The kernel only moves raw bits, so the
// real element type of the matrix does not matter here.
#if TRANSPOSE_W == 4
#define DATA_TYPE uint
#elif TRANSPOSE_W == 8
#define DATA_TYPE ushort
#elif TRANSPOSE_W == 16
#define DATA_TYPE uchar
#else // TRANSPOSE_W is neither 4, 8 nor 16
#error "Transpose width not supported"
#endif // TRANSPOSE_W

/** This OpenCL kernel computes the "vector" 1xW transposition of the input matrix
 *
 * Each work-item copies one vector of TRANSPOSE_W elements from the source to the destination.
 *
 * @attention The vector width must be passed at compile time using -DTRANSPOSE_W (supported values: 4, 8, 16)
 * @attention The multiplication factor (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(IMAGE_DECLARATION(src),
                                IMAGE_DECLARATION(dst))
{
    const uint gid_x = get_global_id(0);
    const uint gid_y = get_global_id(1);

    // Source address of the vector handled by this work-item
    Image src = CONVERT_TO_IMAGE_STRUCT(src);

    // Destination address: the roles of X and Y are swapped with respect to the source.
    //  - the source row (gid_y) selects the position along the destination row
    //  - gid_x / MULT_TRANSPOSE1XW_WIDTH selects the destination row
    //  - gid_x % MULT_TRANSPOSE1XW_WIDTH selects the vector slot inside a group of MULT_TRANSPOSE1XW_WIDTH vectors
    const uint col_bytes         = gid_y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH;
    const uint row_bytes         = (gid_x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y;
    const uint slot_bytes        = (gid_x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);
    const uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + col_bytes + row_bytes + slot_bytes;

    // Copy one 1xW vector from source to destination
    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    values = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (values, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010079
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
 *
 * Each work-item reads four vectors of 4 elements (one from each of four source rows) and writes the
 * four columns of that 4x4 block to the destination.
 *
 * @attention The data type and the multiplication factor (mult_interleave4x4_height) must be passed at compile time using -DDATA_TYPE and -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_interleave4x4(IMAGE_DECLARATION(src),
                                 IMAGE_DECLARATION(dst))
{
    const uint gid_x = get_global_id(0);
    const uint gid_y = get_global_id(1);

    // Source address of the 4x4 block handled by this work-item
    Image src = CONVERT_TO_IMAGE_STRUCT(src);

    // Destination address:
    //  - gid_x selects the group of MULT_INTERLEAVE4X4_HEIGHT interleaved blocks (16 elements each) along the destination row
    //  - gid_y / MULT_INTERLEAVE4X4_HEIGHT selects the destination row
    //  - gid_y % MULT_INTERLEAVE4X4_HEIGHT selects the 4-element slot inside the group
    const uint block_bytes       = gid_x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT;
    const uint row_bytes         = (gid_y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y;
    const uint slot_bytes        = (gid_y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);
    const uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + block_bytes + row_bytes + slot_bytes;

    __global DATA_TYPE *dst_base = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

    // Load four elements from each of the four source rows
    VEC_DATA_TYPE(DATA_TYPE, 4)
    row0 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 0)));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    row1 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 1)));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    row2 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 2)));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    row3 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 3)));

    // Store the four columns of the block; consecutive columns are 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(row0.s0, row1.s0, row2.s0, row3.s0), 0, dst_base + 0 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(row0.s1, row1.s1, row2.s1, row3.s1), 0, dst_base + 4 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(row0.s2, row1.s2, row2.s2, row3.s2), 0, dst_base + 8 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(row0.s3, row1.s3, row2.s3, row3.s3), 0, dst_base + 12 * MULT_INTERLEAVE4X4_HEIGHT);
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100135
Gian Marco36a0a462018-01-12 10:21:40 +0000136#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100137/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100138 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100139 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000140 * @attention The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100141 *
142 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
143 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
144 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
145 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
146 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
147 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100148 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100149 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
150 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
151 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
152 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
153 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100154 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100155 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000156 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100157 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000158 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100159 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
160 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100161__kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0),
162 IMAGE_DECLARATION(src1),
163 IMAGE_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100164{
Gian Marco36a0a462018-01-12 10:21:40 +0000165 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
166 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100167
Gian Marco36a0a462018-01-12 10:21:40 +0000168 // Offset
169 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
170 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100171
Gian Marco36a0a462018-01-12 10:21:40 +0000172 // src_addr_a = address of matrix A
173 // src_addr_b = address of matrix B
174 __global float *src_addr_a = (__global float *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
175 __global float *src_addr_b = (__global float *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100176
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000177 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000178 __global float *src_end_addr_b = src_addr_b + COLS_B;
179
180 src_addr_a += offset_row_a;
181 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100182
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000183 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100184 float4 c00 = 0.0f;
185 float4 c10 = 0.0f;
186 float4 c20 = 0.0f;
187 float4 c30 = 0.0f;
188
Gian Marco36a0a462018-01-12 10:21:40 +0000189 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100190 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000191 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000192 float4 a0 = vload4(0, src_addr_a);
193 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100194
195 c00 += (float4)a0.s0 * b0;
196 c10 += (float4)a0.s1 * b0;
197 c20 += (float4)a0.s2 * b0;
198 c30 += (float4)a0.s3 * b0;
199
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000200 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000201 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
202 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100203
204 c00 += (float4)a0.s0 * b0;
205 c10 += (float4)a0.s1 * b0;
206 c20 += (float4)a0.s2 * b0;
207 c30 += (float4)a0.s3 * b0;
208 }
209
Gian Marco36a0a462018-01-12 10:21:40 +0000210 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100211 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000212 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000213 float4 a0 = vload4(0, src_addr_a);
214 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100215
216 c00 += (float4)a0.s0 * b0;
217 c10 += (float4)a0.s1 * b0;
218 c20 += (float4)a0.s2 * b0;
219 c30 += (float4)a0.s3 * b0;
220 }
221
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000222 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100223 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
224
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000225#if defined(ALPHA)
226 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100227 c00 = c00 * (float4)ALPHA;
228 c10 = c10 * (float4)ALPHA;
229 c20 = c20 * (float4)ALPHA;
230 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000231#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100232
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000233 // Store 4x4 block
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100234 vstore4(c00, 0, (__global float *)(offset(&dst, 0, 0)));
235 vstore4(c10, 0, (__global float *)(offset(&dst, 0, 1)));
236 vstore4(c20, 0, (__global float *)(offset(&dst, 0, 2)));
237 vstore4(c30, 0, (__global float *)(offset(&dst, 0, 3)));
238}
239
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000240/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100241 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100242 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000243 * @attention The number of matrix B columns and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100244 *
245 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
246 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
247 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
248 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
249 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
250 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100251 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100252 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
253 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
254 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
255 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
256 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100257 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100258 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000259 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100260 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000261 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100262 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
263 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100264__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
265 IMAGE_DECLARATION(src1),
266 IMAGE_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100267{
Gian Marco36a0a462018-01-12 10:21:40 +0000268 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
269 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
270
271 // Offset
272 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
273 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
274
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100275 // src_addr_a = address of matrix A
276 // src_addr_b = address of matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000277 __global float *src_addr_a = (__global float *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
278 __global float *src_addr_b = (__global float *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100279
280 // Compute end row address for matrix B
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100281 __global float *src_end_addr_b = src_addr_b + COLS_B;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100282
Gian Marco36a0a462018-01-12 10:21:40 +0000283 src_addr_a += offset_row_a;
284 src_addr_b += offset_row_b;
285
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100286 // Reset accumulators
287 float c00 = 0.0f;
288 float c01 = 0.0f;
289 float c02 = 0.0f;
290 float c03 = 0.0f;
291 float c10 = 0.0f;
292 float c11 = 0.0f;
293 float c12 = 0.0f;
294 float c13 = 0.0f;
295 float c20 = 0.0f;
296 float c21 = 0.0f;
297 float c22 = 0.0f;
298 float c23 = 0.0f;
299 float c30 = 0.0f;
300 float c31 = 0.0f;
301 float c32 = 0.0f;
302 float c33 = 0.0f;
303
Gian Marco36a0a462018-01-12 10:21:40 +0000304 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += (16 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (16 * MULT_TRANSPOSE1XW_WIDTH))
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100305 {
306 // Load values from matrix A (interleaved) and matrix B (transposed)
307 float4 a0 = vload4(0, src_addr_a);
308 float4 b0 = vload4(0, src_addr_b);
309
310 c00 = fma(a0.s0, b0.s0, c00);
311 c01 = fma(a0.s0, b0.s1, c01);
312 c02 = fma(a0.s0, b0.s2, c02);
313 c03 = fma(a0.s0, b0.s3, c03);
314
315 c10 = fma(a0.s1, b0.s0, c10);
316 c11 = fma(a0.s1, b0.s1, c11);
317 c12 = fma(a0.s1, b0.s2, c12);
318 c13 = fma(a0.s1, b0.s3, c13);
319
320 c20 = fma(a0.s2, b0.s0, c20);
321 c21 = fma(a0.s2, b0.s1, c21);
322 c22 = fma(a0.s2, b0.s2, c22);
323 c23 = fma(a0.s2, b0.s3, c23);
324
325 c30 = fma(a0.s3, b0.s0, c30);
326 c31 = fma(a0.s3, b0.s1, c31);
327 c32 = fma(a0.s3, b0.s2, c32);
328 c33 = fma(a0.s3, b0.s3, c33);
329
330 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000331 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
332 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100333
334 c00 = fma(a0.s0, b0.s0, c00);
335 c01 = fma(a0.s0, b0.s1, c01);
336 c02 = fma(a0.s0, b0.s2, c02);
337 c03 = fma(a0.s0, b0.s3, c03);
338
339 c10 = fma(a0.s1, b0.s0, c10);
340 c11 = fma(a0.s1, b0.s1, c11);
341 c12 = fma(a0.s1, b0.s2, c12);
342 c13 = fma(a0.s1, b0.s3, c13);
343
344 c20 = fma(a0.s2, b0.s0, c20);
345 c21 = fma(a0.s2, b0.s1, c21);
346 c22 = fma(a0.s2, b0.s2, c22);
347 c23 = fma(a0.s2, b0.s3, c23);
348
349 c30 = fma(a0.s3, b0.s0, c30);
350 c31 = fma(a0.s3, b0.s1, c31);
351 c32 = fma(a0.s3, b0.s2, c32);
352 c33 = fma(a0.s3, b0.s3, c33);
353
354 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000355 a0 = vload4(0, src_addr_a + 8 * MULT_INTERLEAVE4X4_HEIGHT);
356 b0 = vload4(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100357
358 c00 = fma(a0.s0, b0.s0, c00);
359 c01 = fma(a0.s0, b0.s1, c01);
360 c02 = fma(a0.s0, b0.s2, c02);
361 c03 = fma(a0.s0, b0.s3, c03);
362
363 c10 = fma(a0.s1, b0.s0, c10);
364 c11 = fma(a0.s1, b0.s1, c11);
365 c12 = fma(a0.s1, b0.s2, c12);
366 c13 = fma(a0.s1, b0.s3, c13);
367
368 c20 = fma(a0.s2, b0.s0, c20);
369 c21 = fma(a0.s2, b0.s1, c21);
370 c22 = fma(a0.s2, b0.s2, c22);
371 c23 = fma(a0.s2, b0.s3, c23);
372
373 c30 = fma(a0.s3, b0.s0, c30);
374 c31 = fma(a0.s3, b0.s1, c31);
375 c32 = fma(a0.s3, b0.s2, c32);
376 c33 = fma(a0.s3, b0.s3, c33);
377
378 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000379 a0 = vload4(0, src_addr_a + 12 * MULT_INTERLEAVE4X4_HEIGHT);
380 b0 = vload4(0, src_addr_b + 12 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100381
382 c00 = fma(a0.s0, b0.s0, c00);
383 c01 = fma(a0.s0, b0.s1, c01);
384 c02 = fma(a0.s0, b0.s2, c02);
385 c03 = fma(a0.s0, b0.s3, c03);
386
387 c10 = fma(a0.s1, b0.s0, c10);
388 c11 = fma(a0.s1, b0.s1, c11);
389 c12 = fma(a0.s1, b0.s2, c12);
390 c13 = fma(a0.s1, b0.s3, c13);
391
392 c20 = fma(a0.s2, b0.s0, c20);
393 c21 = fma(a0.s2, b0.s1, c21);
394 c22 = fma(a0.s2, b0.s2, c22);
395 c23 = fma(a0.s2, b0.s3, c23);
396
397 c30 = fma(a0.s3, b0.s0, c30);
398 c31 = fma(a0.s3, b0.s1, c31);
399 c32 = fma(a0.s3, b0.s2, c32);
400 c33 = fma(a0.s3, b0.s3, c33);
401 }
402
Gian Marco36a0a462018-01-12 10:21:40 +0000403 for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * MULT_TRANSPOSE1XW_WIDTH))
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100404 {
405 // Load values from matrix A (interleaved) and matrix B (transposed)
406 float4 a0 = vload4(0, src_addr_a);
407 float4 b0 = vload4(0, src_addr_b);
408
409 c00 = fma(a0.s0, b0.s0, c00);
410 c01 = fma(a0.s0, b0.s1, c01);
411 c02 = fma(a0.s0, b0.s2, c02);
412 c03 = fma(a0.s0, b0.s3, c03);
413
414 c10 = fma(a0.s1, b0.s0, c10);
415 c11 = fma(a0.s1, b0.s1, c11);
416 c12 = fma(a0.s1, b0.s2, c12);
417 c13 = fma(a0.s1, b0.s3, c13);
418
419 c20 = fma(a0.s2, b0.s0, c20);
420 c21 = fma(a0.s2, b0.s1, c21);
421 c22 = fma(a0.s2, b0.s2, c22);
422 c23 = fma(a0.s2, b0.s3, c23);
423
424 c30 = fma(a0.s3, b0.s0, c30);
425 c31 = fma(a0.s3, b0.s1, c31);
426 c32 = fma(a0.s3, b0.s2, c32);
427 c33 = fma(a0.s3, b0.s3, c33);
428 }
429
430 // Compute destination address
431 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
432
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000433#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100434 // Multiply by the weight of matrix product
435 c00 = c00 * ALPHA;
436 c01 = c01 * ALPHA;
437 c02 = c02 * ALPHA;
438 c03 = c03 * ALPHA;
439 c10 = c10 * ALPHA;
440 c11 = c11 * ALPHA;
441 c12 = c12 * ALPHA;
442 c13 = c13 * ALPHA;
443 c20 = c20 * ALPHA;
444 c21 = c21 * ALPHA;
445 c22 = c22 * ALPHA;
446 c23 = c23 * ALPHA;
447 c30 = c30 * ALPHA;
448 c31 = c31 * ALPHA;
449 c32 = c32 * ALPHA;
450 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000451#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100452
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100453 // Store 4x4 block
454 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(offset(&dst, 0, 0)));
455 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(offset(&dst, 0, 1)));
456 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(offset(&dst, 0, 2)));
457 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(offset(&dst, 0, 3)));
458}
459
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100460#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100461/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100462 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100463 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000464 * @attention The number of matrix B columns and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100465 *
466 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
467 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
468 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
469 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
470 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
471 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100472 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100473 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
474 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
475 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
476 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
477 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100478 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100479 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000480 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100481 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000482 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100483 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
484 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst))
{
    // Global work-item coordinates
    const size_t gid_x = get_global_id(0);
    const size_t gid_y = get_global_id(1);

    // Row of reshaped matrix B / reshaped matrix A addressed by this work-item
    const int b_row = gid_x / MULT_TRANSPOSE1XW_WIDTH;
    const int a_row = gid_y / MULT_INTERLEAVE4X4_HEIGHT;

    // Element offset of this work-item's sub-block inside the selected row
    const int a_block_offset = (gid_y % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_block_offset = (gid_x % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Pointer advance (in half elements) for one and for two accumulation steps
    const int a_step_x1 = 4 * MULT_INTERLEAVE4X4_HEIGHT;
    const int a_step_x2 = 8 * MULT_INTERLEAVE4X4_HEIGHT;
    const int b_step_x1 = 8 * MULT_TRANSPOSE1XW_WIDTH;
    const int b_step_x2 = (int)(16 * MULT_TRANSPOSE1XW_WIDTH);

    // Start of the selected rows of matrix A and matrix B
    __global half *a_ptr = (__global half *)(src0_ptr + a_row * src0_stride_y + src0_offset_first_element_in_bytes);
    __global half *b_ptr = (__global half *)(src1_ptr + b_row * src1_stride_y + src1_offset_first_element_in_bytes);

    // End of the matrix B row; taken before the sub-block offset is applied
    __global half *const b_end = b_ptr + COLS_B;

    // Last position of matrix B from which two accumulation steps still fit
    __global half *const b_unrolled_limit = b_end - b_step_x2;

    a_ptr += a_block_offset;
    b_ptr += b_block_offset;

    // One half8 accumulator per output row of the 4x8 block
    half8 acc_row0 = 0.0f;
    half8 acc_row1 = 0.0f;
    half8 acc_row2 = 0.0f;
    half8 acc_row3 = 0.0f;

    // Main loop: two accumulation steps per iteration
    while(b_ptr <= b_unrolled_limit)
    {
        // Load two consecutive groups from matrix A (interleaved) and matrix B (transposed)
        const half4 a_first  = vload4(0, a_ptr);
        const half8 b_first  = vload8(0, b_ptr);
        const half4 a_second = vload4(0, a_ptr + a_step_x1);
        const half8 b_second = vload8(0, b_ptr + b_step_x1);

        // First step
        acc_row0 += (half8)a_first.s0 * b_first;
        acc_row1 += (half8)a_first.s1 * b_first;
        acc_row2 += (half8)a_first.s2 * b_first;
        acc_row3 += (half8)a_first.s3 * b_first;

        // Second step
        acc_row0 += (half8)a_second.s0 * b_second;
        acc_row1 += (half8)a_second.s1 * b_second;
        acc_row2 += (half8)a_second.s2 * b_second;
        acc_row3 += (half8)a_second.s3 * b_second;

        a_ptr += a_step_x2;
        b_ptr += b_step_x2;
    }

    // Tail loop: one accumulation step per iteration
    while(b_ptr < b_end)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        const half4 a_vals = vload4(0, a_ptr);
        const half8 b_vals = vload8(0, b_ptr);

        acc_row0 += (half8)a_vals.s0 * b_vals;
        acc_row1 += (half8)a_vals.s1 * b_vals;
        acc_row2 += (half8)a_vals.s2 * b_vals;
        acc_row3 += (half8)a_vals.s3 * b_vals;

        a_ptr += a_step_x1;
        b_ptr += b_step_x1;
    }

    // Compute destination address
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    acc_row0 *= (half8)ALPHA;
    acc_row1 *= (half8)ALPHA;
    acc_row2 *= (half8)ALPHA;
    acc_row3 *= (half8)ALPHA;
#endif // defined(ALPHA)

    // Store 4x8 block
    vstore8(acc_row0, 0, (__global half *)(offset(&dst_img, 0, 0)));
    vstore8(acc_row1, 0, (__global half *)(offset(&dst_img, 0, 1)));
    vstore8(acc_row2, 0, (__global half *)(offset(&dst_img, 0, 2)));
    vstore8(acc_row3, 0, (__global half *)(offset(&dst_img, 0, 3)));
}
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100563#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100564
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000565#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_8bit and @ref gemm_transpose1x16 before running the matrix multiplication
 *
 * @attention The number of matrix B columns, the optional alpha's value and fixed point position need to be passed at compile time using -DCOLS_B -DALPHA and -DFIXED_POINT_POSITION
 *
 * @note ALPHA must be passed in 8 bit fixed point format
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst))
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;

    // Offset of this work-item's sub-block inside the selected reshaped row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    // The casts are explicit, as in the f16 and qs16 kernels, instead of relying on an implicit pointer conversion
    __global char *src_addr_a = (__global char *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
    __global char *src_addr_b = (__global char *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);

    // Compute end row address for matrix B (taken before the sub-block offset is applied)
    __global char *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cX0 covers output columns 0-7 of row X, cX1 covers columns 8-15
    short8 c00 = (short8)(0);
    short8 c10 = (short8)(0);
    short8 c20 = (short8)(0);
    short8 c30 = (short8)(0);
    short8 c01 = (short8)(0);
    short8 c11 = (short8)(0);
    short8 c21 = (short8)(0);
    short8 c31 = (short8)(0);

    // This for loop performs 1 accumulation for each iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        char4  a0 = vload4(0, src_addr_a);
        char16 b0 = vload16(0, src_addr_b);

        // Accumulate into the low half (columns 0-7) of each row.
        // mlal_sat_qs8x8 is defined outside this chunk; presumably a saturating multiply-accumulate (to be confirmed in fixed_point.h)
        c00 = mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
        c10 = mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
        c20 = mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
        c30 = mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);

        // Accumulate into the high half (columns 8-15) of each row
        c01 = mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c11 = mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c21 = mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c31 = mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Join the two halves of each row and narrow the 16-bit accumulators to 8 bit with saturation
    char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
    char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
    char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
    char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00_qs8 = mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c10_qs8 = mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c20_qs8 = mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Store 16x4 block
    vstore16(c00_qs8, 0, (__global char *)(offset(&dst, 0, 0)));
    vstore16(c10_qs8, 0, (__global char *)(offset(&dst, 0, 1)));
    vstore16(c20_qs8, 0, (__global char *)(offset(&dst, 0, 2)));
    vstore16(c30_qs8, 0, (__global char *)(offset(&dst, 0, 3)));
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +0100664
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * @attention The number of matrix B columns, the optional alpha's value and fixed point position need to be passed at compile time using -DCOLS_B -DALPHA and -DFIXED_POINT_POSITION
 *
 * @note ALPHA must be passed in 16 bit fixed point format
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst))
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;

    // Offset (in elements) of this work-item's sub-block inside the selected reshaped row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    __global short *src_addr_a = (__global short *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
    __global short *src_addr_b = (__global short *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);

    // Compute end row address for matrix B (taken before the sub-block offset is applied)
    __global short *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: one 32-bit int8 accumulator per output row
    int8 c00 = (int8)(0);
    int8 c10 = (int8)(0);
    int8 c20 = (int8)(0);
    int8 c30 = (int8)(0);

    // This for loop performs 1 accumulation for each iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        short4 a0 = vload4(0, src_addr_a);
        short8 b0 = vload8(0, src_addr_b);

        // mlal_sat_qs16x8 is defined outside this chunk; presumably a saturating multiply-accumulate (to be confirmed in fixed_point.h)
        c00 = mlal_sat_qs16x8(c00, (short8)a0.s0, b0, FIXED_POINT_POSITION);
        c10 = mlal_sat_qs16x8(c10, (short8)a0.s1, b0, FIXED_POINT_POSITION);
        c20 = mlal_sat_qs16x8(c20, (short8)a0.s2, b0, FIXED_POINT_POSITION);
        c30 = mlal_sat_qs16x8(c30, (short8)a0.s3, b0, FIXED_POINT_POSITION);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Narrow the 32-bit accumulators to 16 bit with saturation
    short8 c00_qs16 = convert_short8_sat(c00);
    short8 c10_qs16 = convert_short8_sat(c10);
    short8 c20_qs16 = convert_short8_sat(c20);
    short8 c30_qs16 = convert_short8_sat(c30);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00_qs16 = mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c10_qs16 = mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c20_qs16 = mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Store 8x4 block
    vstore8(c00_qs16, 0, (__global short *)(offset(&dst, 0, 0)));
    vstore8(c10_qs16, 0, (__global short *)(offset(&dst, 0, 1)));
    vstore8(c20_qs16, 0, (__global short *)(offset(&dst, 0, 2)));
    vstore8(c30_qs16, 0, (__global short *)(offset(&dst, 0, 3)));
}
754#endif // defined(FIXED_POINT_POSITION)
Gian Marco36a0a462018-01-12 10:21:40 +0000755#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100756
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100757#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
                                     IMAGE_DECLARATION(dst))
{
    // First matrix B column handled by this work-item
    const int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Byte offsets into matrix A and matrix B, starting at the first element of each
    int addr_a = (int)src0_offset_first_element_in_bytes;
    int addr_b = (int)src1_offset_first_element_in_bytes;

    // Move matrix A to the first of the rows processed by this work-item
    addr_a += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Move matrix B to the first column processed by this work-item
    addr_b += idx * sizeof(DATA_TYPE);

    // Byte offset one past the last element of the matrix A row
    const int end_row_vec_a = addr_a + (COLS_A * sizeof(DATA_TYPE));

    // Last matrix A offset from which two elements can still be read
    const int unrolled_end = end_row_vec_a - 2 * (int)sizeof(DATA_TYPE);

    // One accumulator per processed output row
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consumes two columns of matrix A and two rows of matrix B per iteration
    while(addr_a <= unrolled_end)
    {
        // Load two consecutive values from each processed row of matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + addr_a + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + addr_a + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + addr_a + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + addr_a + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Load the two matching rows of matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + addr_b));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + addr_b + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance two elements along the matrix A row and two rows down matrix B
        addr_a += (int)(2 * sizeof(DATA_TYPE));
        addr_b += (int)(2 * src1_stride_y);
    }

    // Tail loop: consumes one column of matrix A and one row of matrix B per iteration
    while(addr_a < end_row_vec_a)
    {
        // Load one value from each processed row of matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + addr_a + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + addr_a + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + addr_a + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + addr_a + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Load the matching row of matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + addr_b));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance one element along the matrix A row and one row down matrix B
        addr_a += (int)sizeof(DATA_TYPE);
        addr_b += (int)src1_stride_y;
    }

    // Compute destination address
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix-matrix product
    acc0 *= (VECTOR_TYPE)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 *= (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 *= (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 *= (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    // Store one vector per processed output row
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc0, 0, (__global DATA_TYPE *)(offset(&dst_img, 0, 0)));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc1, 0, (__global DATA_TYPE *)(offset(&dst_img, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc2, 0, (__global DATA_TYPE *)(offset(&dst_img, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc3, 0, (__global DATA_TYPE *)(offset(&dst_img, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100913
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
915 *
916 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
917 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
918 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
919 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
920 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
921 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
923 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
924 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
925 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
926 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
927 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
928 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
929 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
930 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
931 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
932 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
933 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
934 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
935 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
937 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
939 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
940 */
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst))
{
    // Every work-item owns a destination tile that is 4 columns wide and up to 4 rows tall
    // (NUM_ELEMS_PROCESSED_PER_THREAD_Y rows). Accumulator cRC belongs to tile row R, tile column C.
    float c00 = 0.0f, c01 = 0.0f, c02 = 0.0f, c03 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float c10 = 0.0f, c11 = 0.0f, c12 = 0.0f, c13 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float c20 = 0.0f, c21 = 0.0f, c22 = 0.0f, c23 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float c30 = 0.0f, c31 = 0.0f, c32 = 0.0f, c33 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // First matrix B column handled by this work-item
    const int first_col = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Byte offset of the first matrix A row handled by this work-item
    int lhs_off = src0_offset_first_element_in_bytes;
    lhs_off += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Byte offset of the first matrix B element handled by this work-item
    int rhs_off = src1_offset_first_element_in_bytes;
    rhs_off += first_col * sizeof(float);

    // Byte offset just past the COLS_A floats of that matrix A row
    const int lhs_end = lhs_off + (COLS_A * sizeof(float));

    // Main loop: each iteration consumes two floats of every matrix A row and two rows of matrix B.
    // Both offsets advance together at the bottom of the loop.
    while(lhs_off <= (lhs_end - 2 * (int)sizeof(float)))
    {
        // Two consecutive values from each processed row of matrix A
        const float2 lhs0 = vload2(0, (__global float *)(src0_ptr + lhs_off + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        const float2 lhs1 = vload2(0, (__global float *)(src0_ptr + lhs_off + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        const float2 lhs2 = vload2(0, (__global float *)(src0_ptr + lhs_off + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        const float2 lhs3 = vload2(0, (__global float *)(src0_ptr + lhs_off + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Four consecutive values from two consecutive rows of matrix B
        const float4 rhs_k0 = vload4(0, (__global float *)(src1_ptr + rhs_off + 0 * src1_stride_y));
        const float4 rhs_k1 = vload4(0, (__global float *)(src1_ptr + rhs_off + 1 * src1_stride_y));

        // In every nested expression the inner fma is applied to the accumulator first, the outer one second.
        c00 = fma(lhs0.s1, rhs_k1.s0, fma(lhs0.s0, rhs_k0.s0, c00));
        c01 = fma(lhs0.s1, rhs_k1.s1, fma(lhs0.s0, rhs_k0.s1, c01));
        c02 = fma(lhs0.s1, rhs_k1.s2, fma(lhs0.s0, rhs_k0.s2, c02));
        // c03 adds the rhs_k1 term before the rhs_k0 term, unlike every other accumulator;
        // the order of the two roundings is kept as is.
        c03 = fma(lhs0.s0, rhs_k0.s3, fma(lhs0.s1, rhs_k1.s3, c03));

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        c10 = fma(lhs1.s1, rhs_k1.s0, fma(lhs1.s0, rhs_k0.s0, c10));
        c11 = fma(lhs1.s1, rhs_k1.s1, fma(lhs1.s0, rhs_k0.s1, c11));
        c12 = fma(lhs1.s1, rhs_k1.s2, fma(lhs1.s0, rhs_k0.s2, c12));
        c13 = fma(lhs1.s1, rhs_k1.s3, fma(lhs1.s0, rhs_k0.s3, c13));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        c20 = fma(lhs2.s1, rhs_k1.s0, fma(lhs2.s0, rhs_k0.s0, c20));
        c21 = fma(lhs2.s1, rhs_k1.s1, fma(lhs2.s0, rhs_k0.s1, c21));
        c22 = fma(lhs2.s1, rhs_k1.s2, fma(lhs2.s0, rhs_k0.s2, c22));
        c23 = fma(lhs2.s1, rhs_k1.s3, fma(lhs2.s0, rhs_k0.s3, c23));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        c30 = fma(lhs3.s1, rhs_k1.s0, fma(lhs3.s0, rhs_k0.s0, c30));
        c31 = fma(lhs3.s1, rhs_k1.s1, fma(lhs3.s0, rhs_k0.s1, c31));
        c32 = fma(lhs3.s1, rhs_k1.s2, fma(lhs3.s0, rhs_k0.s2, c32));
        c33 = fma(lhs3.s1, rhs_k1.s3, fma(lhs3.s0, rhs_k0.s3, c33));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Two floats further along matrix A, two rows further down matrix B
        lhs_off += (int)(2 * sizeof(float));
        rhs_off += (int)(2 * src1_stride_y);
    }

    // Left-over loop: one float of every matrix A row and one row of matrix B per iteration
    while(lhs_off < lhs_end)
    {
        // One value from each processed row of matrix A
        const float lhs0 = *((__global float *)(src0_ptr + lhs_off + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        const float lhs1 = *((__global float *)(src0_ptr + lhs_off + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        const float lhs2 = *((__global float *)(src0_ptr + lhs_off + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        const float lhs3 = *((__global float *)(src0_ptr + lhs_off + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Four consecutive values from the current row of matrix B
        const float4 rhs = vload4(0, (__global float *)(src1_ptr + rhs_off));

        c00 = fma(lhs0, rhs.s0, c00);
        c01 = fma(lhs0, rhs.s1, c01);
        c02 = fma(lhs0, rhs.s2, c02);
        c03 = fma(lhs0, rhs.s3, c03);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        c10 = fma(lhs1, rhs.s0, c10);
        c11 = fma(lhs1, rhs.s1, c11);
        c12 = fma(lhs1, rhs.s2, c12);
        c13 = fma(lhs1, rhs.s3, c13);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        c20 = fma(lhs2, rhs.s0, c20);
        c21 = fma(lhs2, rhs.s1, c21);
        c22 = fma(lhs2, rhs.s2, c22);
        c23 = fma(lhs2, rhs.s3, c23);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        c30 = fma(lhs3, rhs.s0, c30);
        c31 = fma(lhs3, rhs.s1, c31);
        c32 = fma(lhs3, rhs.s2, c32);
        c33 = fma(lhs3, rhs.s3, c33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // One float further along matrix A, one row further down matrix B
        lhs_off += (int)sizeof(float);
        rhs_off += (int)src1_stride_y;
    }

    // Destination tile of this work-item
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Tile row 0: apply the optional alpha weight, then write the 4 columns with a single store
#if defined(ALPHA)
    c00 = c00 * ALPHA;
    c01 = c01 * ALPHA;
    c02 = c02 * ALPHA;
    c03 = c03 * ALPHA;
#endif // defined(ALPHA)
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(offset(&dst, 0, 0)));

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    // Tile row 1
#if defined(ALPHA)
    c10 = c10 * ALPHA;
    c11 = c11 * ALPHA;
    c12 = c12 * ALPHA;
    c13 = c13 * ALPHA;
#endif // defined(ALPHA)
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    // Tile row 2
#if defined(ALPHA)
    c20 = c20 * ALPHA;
    c21 = c21 * ALPHA;
    c22 = c22 * ALPHA;
    c23 = c23 * ALPHA;
#endif // defined(ALPHA)
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    // Tile row 3
#if defined(ALPHA)
    c30 = c30 * ALPHA;
    c31 = c31 * ALPHA;
    c32 = c32 * ALPHA;
    c33 = c33 * ALPHA;
#endif // defined(ALPHA)
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
1135
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note Each work-item always loads and stores 2 consecutive columns and handles at most 4 rows, whatever the values of
 *       NUM_ELEMS_PROCESSED_PER_THREAD_X and NUM_ELEMS_PROCESSED_PER_THREAD_Y are.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
                                                      IMAGE_DECLARATION(dst))
{
    // Per work-item tile: NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x 2 columns of the destination.
    // Matrix A is read 4 floats at a time (vload4), matrix B as four float2 rows, and every
    // destination row is written as one float2.
    // NOTE(review): the tile is always 2 columns wide, so NUM_ELEMS_PROCESSED_PER_THREAD_X is
    // presumably expected to be 2 - confirm against the host-side kernel configuration.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and matrix B (byte offsets: .s0 -> A, .s1 -> B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

    // Address boundary for the matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));

    // Initialize accumulators (accRC: tile row R, tile column C)
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time:
    // 4 floats along the row of A, 4 rows down B per iteration.
    for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y))
    {
        // Load values from matrix A
        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y));
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y));

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        // a0 is reused to hold the values of the next row of matrix A
        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over accumulations: one float along the row of A, one row down B per iteration
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#endif // defined(ALPHA)
    float2 acc0 = ((float2)(acc00, acc01));
    vstore2(acc0, 0, (__global float *)(offset(&dst, 0, 0)));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // defined(ALPHA)
    float2 acc1 = ((float2)(acc10, acc11));
    vstore2(acc1, 0, (__global float *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // defined(ALPHA)
    float2 acc2 = ((float2)(acc20, acc21));
    vstore2(acc2, 0, (__global float *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // defined(ALPHA)
    float2 acc3 = ((float2)(acc30, acc31));
    vstore2(acc3, 0, (__global float *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
1329
1330#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the fixed point data type QS8
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A
 * @note The fixed point position must be passed at compile time using -DFIXED_POINT_POSITION
 * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
 * @note Each work-item always loads and stores 16 consecutive columns and handles at most 4 rows
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
                          IMAGE_DECLARATION(src1),
                          IMAGE_DECLARATION(dst))
{
    // 16-bit accumulators, two per tile row R: sumR_lo covers tile columns 0-7, sumR_hi columns 8-15
    short8 sum0_lo = 0;
    short8 sum0_hi = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    short8 sum1_lo = 0;
    short8 sum1_hi = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    short8 sum2_lo = 0;
    short8 sum2_hi = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    short8 sum3_lo = 0;
    short8 sum3_hi = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // First matrix B column handled by this work-item
    const int first_col = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Byte offset of the first matrix A row handled by this work-item
    int lhs_off = src0_offset_first_element_in_bytes;
    lhs_off += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Byte offset of the first matrix B element handled by this work-item
    int rhs_off = src1_offset_first_element_in_bytes;
    rhs_off += first_col * sizeof(char);

    // Byte offset just past the COLS_A elements of that matrix A row
    const int lhs_end = lhs_off + (COLS_A * sizeof(char));

    // Main loop: each iteration consumes two elements of every matrix A row and two rows of matrix B
    while(lhs_off <= (lhs_end - 2))
    {
        // Two consecutive values from each processed row of matrix A
        const char2 lhs0 = vload2(0, (__global char *)(src0_ptr + lhs_off + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        const char2 lhs1 = vload2(0, (__global char *)(src0_ptr + lhs_off + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        const char2 lhs2 = vload2(0, (__global char *)(src0_ptr + lhs_off + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        const char2 lhs3 = vload2(0, (__global char *)(src0_ptr + lhs_off + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Sixteen consecutive values from two consecutive rows of matrix B
        const char16 rhs_k0 = vload16(0, (__global char *)(src1_ptr + rhs_off + 0 * src1_stride_y));
        const char16 rhs_k1 = vload16(0, (__global char *)(src1_ptr + rhs_off + 1 * src1_stride_y));

        // For every accumulator the rhs_k0 contribution is added before the rhs_k1 one
        sum0_lo = mlal_sat_qs8x8(sum0_lo, (char8)lhs0.s0, rhs_k0.s01234567, FIXED_POINT_POSITION);
        sum0_hi = mlal_sat_qs8x8(sum0_hi, (char8)lhs0.s0, rhs_k0.s89ABCDEF, FIXED_POINT_POSITION);
        sum0_lo = mlal_sat_qs8x8(sum0_lo, (char8)lhs0.s1, rhs_k1.s01234567, FIXED_POINT_POSITION);
        sum0_hi = mlal_sat_qs8x8(sum0_hi, (char8)lhs0.s1, rhs_k1.s89ABCDEF, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        sum1_lo = mlal_sat_qs8x8(sum1_lo, (char8)lhs1.s0, rhs_k0.s01234567, FIXED_POINT_POSITION);
        sum1_hi = mlal_sat_qs8x8(sum1_hi, (char8)lhs1.s0, rhs_k0.s89ABCDEF, FIXED_POINT_POSITION);
        sum1_lo = mlal_sat_qs8x8(sum1_lo, (char8)lhs1.s1, rhs_k1.s01234567, FIXED_POINT_POSITION);
        sum1_hi = mlal_sat_qs8x8(sum1_hi, (char8)lhs1.s1, rhs_k1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        sum2_lo = mlal_sat_qs8x8(sum2_lo, (char8)lhs2.s0, rhs_k0.s01234567, FIXED_POINT_POSITION);
        sum2_hi = mlal_sat_qs8x8(sum2_hi, (char8)lhs2.s0, rhs_k0.s89ABCDEF, FIXED_POINT_POSITION);
        sum2_lo = mlal_sat_qs8x8(sum2_lo, (char8)lhs2.s1, rhs_k1.s01234567, FIXED_POINT_POSITION);
        sum2_hi = mlal_sat_qs8x8(sum2_hi, (char8)lhs2.s1, rhs_k1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        sum3_lo = mlal_sat_qs8x8(sum3_lo, (char8)lhs3.s0, rhs_k0.s01234567, FIXED_POINT_POSITION);
        sum3_hi = mlal_sat_qs8x8(sum3_hi, (char8)lhs3.s0, rhs_k0.s89ABCDEF, FIXED_POINT_POSITION);
        sum3_lo = mlal_sat_qs8x8(sum3_lo, (char8)lhs3.s1, rhs_k1.s01234567, FIXED_POINT_POSITION);
        sum3_hi = mlal_sat_qs8x8(sum3_hi, (char8)lhs3.s1, rhs_k1.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Two elements further along matrix A, two rows further down matrix B
        lhs_off += 2;
        rhs_off += (int)(2 * src1_stride_y);
    }

    // Left-over loop: one element of every matrix A row and one row of matrix B per iteration
    while(lhs_off < lhs_end)
    {
        // One value from each processed row of matrix A
        const char lhs0 = *((__global char *)(src0_ptr + lhs_off + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        const char lhs1 = *((__global char *)(src0_ptr + lhs_off + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        const char lhs2 = *((__global char *)(src0_ptr + lhs_off + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        const char lhs3 = *((__global char *)(src0_ptr + lhs_off + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Sixteen consecutive values from the current row of matrix B
        const char16 rhs = vload16(0, (__global char *)(src1_ptr + rhs_off));

        sum0_lo = mlal_sat_qs8x8(sum0_lo, (char8)lhs0, rhs.s01234567, FIXED_POINT_POSITION);
        sum0_hi = mlal_sat_qs8x8(sum0_hi, (char8)lhs0, rhs.s89ABCDEF, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        sum1_lo = mlal_sat_qs8x8(sum1_lo, (char8)lhs1, rhs.s01234567, FIXED_POINT_POSITION);
        sum1_hi = mlal_sat_qs8x8(sum1_hi, (char8)lhs1, rhs.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        sum2_lo = mlal_sat_qs8x8(sum2_lo, (char8)lhs2, rhs.s01234567, FIXED_POINT_POSITION);
        sum2_hi = mlal_sat_qs8x8(sum2_hi, (char8)lhs2, rhs.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        sum3_lo = mlal_sat_qs8x8(sum3_lo, (char8)lhs3, rhs.s01234567, FIXED_POINT_POSITION);
        sum3_hi = mlal_sat_qs8x8(sum3_hi, (char8)lhs3, rhs.s89ABCDEF, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // One element further along matrix A, one row further down matrix B
        lhs_off += 1;
        rhs_off += (int)src1_stride_y;
    }

    // Destination tile of this work-item
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Per tile row: narrow the accumulators to 8 bit with saturation, apply the optional
    // alpha weight, then write the 16 columns with a single store
    char16 row0 = convert_char16_sat((short16)(sum0_lo, sum0_hi));
#if defined(ALPHA)
    row0 = mul_sat_qs8x16(row0, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(row0, 0, (__global char *)(offset(&dst, 0, 0)));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    char16 row1 = convert_char16_sat((short16)(sum1_lo, sum1_hi));
#if defined(ALPHA)
    row1 = mul_sat_qs8x16(row1, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(row1, 0, (__global char *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    char16 row2 = convert_char16_sat((short16)(sum2_lo, sum2_hi));
#if defined(ALPHA)
    row2 = mul_sat_qs8x16(row2, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(row2, 0, (__global char *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    char16 row3 = convert_char16_sat((short16)(sum3_lo, sum3_hi));
#if defined(ALPHA)
    row3 = mul_sat_qs8x16(row3, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore16(row3, 0, (__global char *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001493
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the fixed point data type QS16
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
 * @note The fixed point position needs to be passed at compile time using -DFIXED_POINT_POSITION
 * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
                           IMAGE_DECLARATION(src1),
                           IMAGE_DECLARATION(dst))
{
    // First column of matrix B read by this work-item
    const int first_col = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Byte offsets into matrix A (.s0) and matrix B (.s1), starting from the first element of each
    int2 addr = (int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes);

    // Matrix A: move to the first of the NUM_ELEMS_PROCESSED_PER_THREAD_Y rows read by this work-item
    addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Matrix B: move to the first column read by this work-item
    addr.s1 += first_col * sizeof(short);

    // Byte offset one past the last element of the current row of matrix A
    const int a_row_end = addr.s0 + (COLS_A * sizeof(short));

    // One 8-lane 32-bit accumulator for each processed row of matrix A
    int8 acc0 = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    int8 acc1 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    int8 acc2 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    int8 acc3 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: each iteration consumes two consecutive elements of every processed row of A
    // together with two consecutive rows of B (two multiply-accumulates per accumulator)
    while(addr.s0 <= (a_row_end - 2 * (int)sizeof(short)))
    {
        // Two elements from each processed row of matrix A
        short2 a0 = vload2(0, (__global short *)(src0_ptr + addr.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        short2 a1 = vload2(0, (__global short *)(src0_ptr + addr.s0 + src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        short2 a2 = vload2(0, (__global short *)(src0_ptr + addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short2 a3 = vload2(0, (__global short *)(src0_ptr + addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Eight elements from each of two consecutive rows of matrix B
        short8 b0 = vload8(0, (__global short *)(src1_ptr + addr.s1));
        short8 b1 = vload8(0, (__global short *)(src1_ptr + addr.s1 + src1_stride_y));

        // Saturating multiply-accumulate: the element of A is broadcast across the 8 lanes
        acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
        acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance two elements along the row of A and two rows down B
        addr += (int2)(2 * sizeof(short), 2 * src1_stride_y);
    }

    // Tail loop: handles the element left over when the main loop stops one short of the row end
    while(addr.s0 < a_row_end)
    {
        short a0 = *((__global short *)(src0_ptr + addr.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        short a1 = *((__global short *)(src0_ptr + addr.s0 + src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        short a2 = *((__global short *)(src0_ptr + addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        short a3 = *((__global short *)(src0_ptr + addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        short8 b0 = vload8(0, (__global short *)(src1_ptr + addr.s1));

        acc0 = mlal_sat_qs16x8(acc0, (short8)a0, b0, FIXED_POINT_POSITION);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = mlal_sat_qs16x8(acc1, (short8)a1, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = mlal_sat_qs16x8(acc2, (short8)a2, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = mlal_sat_qs16x8(acc3, (short8)a3, b0, FIXED_POINT_POSITION);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance one element along the row of A and one row down B
        addr += (int2)(sizeof(short), src1_stride_y);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // For each processed row: narrow the accumulator to 16 bit with saturation,
    // scale by ALPHA when it is defined, then store 8 output elements
    short8 row0 = convert_short8_sat(acc0);
#if defined(ALPHA)
    row0 = mul_sat_qs16x8(row0, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(row0, 0, (__global short *)(offset(&dst, 0, 0)));

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    short8 row1 = convert_short8_sat(acc1);
#if defined(ALPHA)
    row1 = mul_sat_qs16x8(row1, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(row1, 0, (__global short *)(offset(&dst, 0, 1)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    short8 row2 = convert_short8_sat(acc2);
#if defined(ALPHA)
    row2 = mul_sat_qs16x8(row2, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(row2, 0, (__global short *)(offset(&dst, 0, 2)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    short8 row3 = convert_short8_sat(acc3);
#if defined(ALPHA)
    row3 = mul_sat_qs16x8(row3, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)
    vstore8(row3, 0, (__global short *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001640#endif // defined(FIXED_POINT_POSITION)
1641#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001642
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001643#if defined(BETA)
/** In-place matrix addition for F32: dst = dst + BETA * src.
 *
 * The destination matrix is expected to already hold the (alpha-weighted) product A x B; the source
 * matrix C is weighted by the scalar beta and added to it. Each work-item processes 4 elements.
 *
 * @attention The value of beta must be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, read and updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Resolve the per-work-item positions in both matrices
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global float *c_ptr  = (__global float *)src.ptr;
    __global float *ab_ptr = (__global float *)dst.ptr;

    // Current destination content (A x B) and the matching values of matrix C
    float4 ab     = vload4(0, ab_ptr);
    float4 c_vals = vload4(0, c_ptr);

    // ab + beta * c, kept as a single expression
    float4 result = ab + (float4)BETA * c_vals;

    // Write the sum back over the A x B values
    vstore4(result, 0, ab_ptr);
}
1680
/** In-place matrix addition for F16: dst = dst + BETA * src.
 *
 * The destination matrix is expected to already hold the (alpha-weighted) product A x B; the source
 * matrix C is weighted by the scalar beta and added to it. Each work-item processes 8 elements.
 *
 * @attention The value of beta must be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, read and updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Resolve the per-work-item positions in both matrices
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global half *c_ptr  = (__global half *)src.ptr;
    __global half *ab_ptr = (__global half *)dst.ptr;

    // Current destination content (A x B) and the matching values of matrix C
    half8 ab     = vload8(0, ab_ptr);
    half8 c_vals = vload8(0, c_ptr);

    // ab + beta * c, kept as a single expression
    half8 result = ab + (half8)BETA * c_vals;

    // Write the sum back over the A x B values
    vstore8(result, 0, ab_ptr);
}
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001717
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001718#if defined(FIXED_POINT_POSITION)
/** In-place matrix addition in 8 bit fixed point: dst = sat(dst + BETA * src).
 *
 * The destination matrix is expected to already hold the (alpha-weighted) product A x B; the source
 * matrix C is weighted by the scalar beta and accumulated onto it with a saturating multiply-accumulate.
 * Each work-item processes 16 elements.
 *
 * @attention The value of beta and the fixed point position must be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 *
 * @note BETA must be passed in 8 bit fixed point format
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: QS8
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, read and updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Resolve the per-work-item positions in both matrices
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global char *c_ptr  = (__global char *)src.ptr;
    __global char *ab_ptr = (__global char *)dst.ptr;

    // Current destination content (A x B) and the matching values of matrix C
    char16 ab     = vload16(0, ab_ptr);
    char16 c_vals = vload16(0, c_ptr);

    // Saturating fixed point multiply-accumulate: ab + beta * c
    char16 result = mla_sat_qs8x16(ab, (char16)BETA, c_vals, FIXED_POINT_POSITION);

    // Write the result back over the A x B values
    vstore16(result, 0, ab_ptr);
}
Gian Marco Iodice8a383692017-07-03 17:41:47 +01001757
/** In-place matrix addition in 16 bit fixed point: dst = sat(dst + BETA * src).
 *
 * The destination matrix is expected to already hold the (alpha-weighted) product A x B; the source
 * matrix C is weighted by the scalar beta and accumulated onto it with a saturating multiply-accumulate.
 * Each work-item processes 8 elements.
 *
 * @attention The value of beta and the fixed point position must be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 *
 * @note BETA must be passed in 16 bit fixed point format
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: QS16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, read and updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
                           IMAGE_DECLARATION(dst))
{
    // Resolve the per-work-item positions in both matrices
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global short *c_ptr  = (__global short *)src.ptr;
    __global short *ab_ptr = (__global short *)dst.ptr;

    // Current destination content (A x B) and the matching values of matrix C
    short8 ab     = vload8(0, ab_ptr);
    short8 c_vals = vload8(0, c_ptr);

    // Saturating fixed point multiply-accumulate: ab + beta * c
    short8 result = mla_sat_qs16x8(ab, (short8)BETA, c_vals, FIXED_POINT_POSITION);

    // Write the result back over the A x B values
    vstore8(result, 0, ab_ptr);
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001796#endif // defined(FIXED_POINT_POSITION)
1797#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001798
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001799#if defined(WIDTH_VECTOR_A)
/** Vector by matrix multiplication between each row of A (src0) and a matrix taken from B (src1), used for the locally connected layer.
 *
 * Work-item (x, y) reads row y of A and the Z slice y of B, and produces 4 consecutive output elements
 * starting at column 4 * x.
 *
 * @attention The width of A must be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @attention The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source tensor B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    int col = get_global_id(0) * 4;
    int row = get_global_id(1);

    // Byte offsets into A (.s0: start of row `row`) and B (.s1: start of Z slice `row`)
    int2 addr = (int2)(src0_offset_first_element_in_bytes + src0_stride_y * row,
                       src1_offset_first_element_in_bytes + src1_stride_z * row);

    // B: move to the first of the 4 columns read by this work-item
    addr.s1 += col * sizeof(float);

    // Byte offset one past the last element of the row of A
    int a_row_end = addr.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: two elements of A and two rows of B per iteration
    while(addr.s0 <= (a_row_end - 2 * (int)sizeof(float)))
    {
        float2 a0 = vload2(0, (__global float *)(src0_ptr + addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + addr.s1));
        float4 b1 = vload4(0, (__global float *)(src1_ptr + addr.s1 + src1_stride_y));

        acc += b0 * (float4)a0.s0;
        acc += b1 * (float4)a0.s1;

        addr += (int2)(2 * sizeof(float), 2 * src1_stride_y);
    }

    // Tail loop: handles the element left over when the main loop stops one short of the row end
    while(addr.s0 < a_row_end)
    {
        float  a0 = *((__global float *)(src0_ptr + addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + addr.s1));

        acc += b0 * (float4)a0;

        addr += (int2)(sizeof(float), src1_stride_y);
    }

    // Compute destination address and store the 4 accumulated values
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001865#endif // defined(WIDTH_VECTOR_A)
1866
/** This kernel adds the biases vector to each row of the accumulate tensor, in place.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 * @note When FIXED_POINT_POSITION is defined the addition is performed through ADD_SAT_OP_EXPAND, otherwise with the plain + operator.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/QS8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Resolve the per-work-item positions in the accumulate tensor and the biases vector
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *accum_data  = (__global DATA_TYPE *)accum.ptr;
    __global DATA_TYPE *biases_data = (__global DATA_TYPE *)biases.ptr;

    // Load VECTOR_SIZE elements from each operand
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    acc_vec = VLOAD(VECTOR_SIZE)(0, accum_data);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    bias_vec = VLOAD(VECTOR_SIZE)(0, biases_data);

    // Add the biases to the accumulated values
#ifdef FIXED_POINT_POSITION
    acc_vec = ADD_SAT_OP_EXPAND(bias_vec, acc_vec, DATA_TYPE, VECTOR_SIZE);
#else  // FIXED_POINT_POSITION
    acc_vec = bias_vec + acc_vec;
#endif // FIXED_POINT_POSITION

    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (acc_vec, 0, accum_data);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)