blob: dcfbb1308124886124c76c30454b47438e48500a [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
25
26#include "arm_compute/core/AccessWindowTranspose.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/NEFixedPoint.h"
32#include "arm_compute/core/TensorInfo.h"
33#include "arm_compute/core/Types.h"
34#include "arm_compute/core/Utils.h"
35#include "arm_compute/core/Validate.h"
36#include "arm_compute/core/Window.h"
37
38#include <arm_neon.h>
39#include <cstddef>
40#include <cstdint>
41#include <tuple>
42
43using namespace arm_compute;
44
45namespace arm_compute
46{
47class Coordinates;
48} // namespace arm_compute
49
50namespace
51{
/** Vector-by-matrix product for F32 data: output = [alpha *] (A * B).
 *
 * Every invocation of the window-loop body accumulates and stores 16 consecutive output values.
 * The X range is split between threads through window.thread_id() and window.num_threads():
 * a thread starts at column 16 * thread_id and advances by 16 * num_threads.
 *
 * @tparam multiply_alpha When true the accumulated values are multiplied by @p alpha before being stored.
 *
 * @param[in]  input0 Vector A. Its window is pinned to 0 along X and Y, so the elements are read
 *                    sequentially from the start of the row.
 * @param[in]  input1 Matrix B. Rows are addressed through the stride along Y (expressed in elements).
 * @param[out] output Destination tensor. Its dimension 0 is used as the width of matrix B.
 * @param[in]  window Execution window. Only the dimensions above Y are taken from it; X and Y are overridden.
 * @param[in]  alpha  Weight of the matrix product. Ignored when @p multiply_alpha is false.
 */
template <bool multiply_alpha>
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * window.thread_id();
    const int window_step_x  = 16 * window.num_threads();
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        // Valid columns are [0, width_matrix_b): a block starting at width_matrix_b is already out of range
        if(id.x() >= width_matrix_b)
        {
            return;
        }

        float32x4_t acc0 = vdupq_n_f32(0.f);
        float32x4_t acc1 = vdupq_n_f32(0.f);
        float32x4_t acc2 = vdupq_n_f32(0.f);
        float32x4_t acc3 = vdupq_n_f32(0.f);

        auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float *>(inb.ptr());

#if __arm__
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif

        // An element counter is used instead of comparing against (end - 4) so that no pointer
        // before the start of the vector is formed when A holds fewer than 4 elements
        int k = 0;

        // Main loop: 4 elements of A (and 4 rows of B) per iteration, processed as two pairs
        for(; k <= (num_elems_vec_a - 4); k += 4)
        {
            float32x2_t a0l = vld1_f32(vec_a);

            float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif

            // First pair: rows 0 and 1 of B weighted by lanes 0 and 1 of A
            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;

            // Second pair
            a0l = vld1_f32(vec_a);

            b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

        // Leftover loop: one element of A (one row of B) per iteration
        for(; k < num_elems_vec_a; ++k)
        {
            const float a0 = *vec_a;

            const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            acc0 = vmlaq_n_f32(acc0, b00, a0);
            acc1 = vmlaq_n_f32(acc1, b01, a0);
            acc2 = vmlaq_n_f32(acc2, b02, a0);
            acc3 = vmlaq_n_f32(acc3, b03, a0);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
            acc0                        = vmulq_f32(acc0, alpha_f32);
            acc1                        = vmulq_f32(acc1, alpha_f32);
            acc2                        = vmulq_f32(acc2, alpha_f32);
            acc3                        = vmulq_f32(acc3, alpha_f32);
        }

        const auto vec_out = reinterpret_cast<float *>(out.ptr());

        // Store the 16 output values
        vst1q_f32(vec_out + 0, acc0);
        vst1q_f32(vec_out + 4, acc1);
        vst1q_f32(vec_out + 8, acc2);
        vst1q_f32(vec_out + 12, acc3);
    },
    ina, inb, out);
}
207
/** Vector-by-matrix product for QS8 (8-bit fixed point) data: output = [alpha *] (A * B).
 *
 * Every invocation of the window-loop body accumulates 32 consecutive output values in 16-bit
 * fixed point, narrows them with saturation to 8 bit and stores them.
 * The X range is split between threads through window.thread_id() and window.num_threads():
 * a thread starts at column 32 * thread_id and advances by 32 * num_threads.
 *
 * @tparam multiply_alpha When true the narrowed values are multiplied by @p alpha (converted to QS8) before being stored.
 *
 * @param[in]  input0 Vector A. Its window is pinned to 0 along X and Y. Its fixed point position is used for all the operations.
 * @param[in]  input1 Matrix B. Rows are addressed through the stride along Y (expressed in elements).
 * @param[out] output Destination tensor. Its dimension 0 is used as the width of matrix B.
 * @param[in]  window Execution window. Only the dimensions above Y are taken from it; X and Y are overridden.
 * @param[in]  alpha  Weight of the matrix product. Ignored when @p multiply_alpha is false.
 */
template <bool multiply_alpha>
void vector_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const auto width_matrix_b       = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride          = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a      = static_cast<int>(input0->info()->dimension(0));
    const int  fixed_point_position = input0->info()->fixed_point_position();

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * window.thread_id();
    const int window_step_x  = 32 * window.num_threads();
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        // Valid columns are [0, width_matrix_b): a block starting at width_matrix_b is already out of range
        if(id.x() >= width_matrix_b)
        {
            return;
        }

        // Reset accumulators
        qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc03_qs16 = vdupq_n_qs16(0);

        auto vec_a    = reinterpret_cast<const qint8_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const qint8_t *>(inb.ptr());

        // An element counter is used instead of comparing against (end - 2) so that no pointer
        // before the start of the vector is formed when A holds a single element
        int k = 0;

        // Main loop: 2 elements of A (and 2 rows of B) per iteration
        for(; k <= (num_elems_vec_a - 2); k += 2)
        {
            const qint8x8_t a0 = vld1_dup_qs8(vec_a + 0);
            const qint8x8_t a1 = vld1_dup_qs8(vec_a + 1);

            const qint8x8_t b00 = vld1_qs8(matrix_b + 0 + 0 * in_b_stride);
            const qint8x8_t b01 = vld1_qs8(matrix_b + 8 + 0 * in_b_stride);
            const qint8x8_t b02 = vld1_qs8(matrix_b + 16 + 0 * in_b_stride);
            const qint8x8_t b03 = vld1_qs8(matrix_b + 24 + 0 * in_b_stride);
            const qint8x8_t b10 = vld1_qs8(matrix_b + 0 + 1 * in_b_stride);
            const qint8x8_t b11 = vld1_qs8(matrix_b + 8 + 1 * in_b_stride);
            const qint8x8_t b12 = vld1_qs8(matrix_b + 16 + 1 * in_b_stride);
            const qint8x8_t b13 = vld1_qs8(matrix_b + 24 + 1 * in_b_stride);

            // First accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);

            // Second accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b10, a1, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b11, a1, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a1, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a1, fixed_point_position);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

        // Leftover loop: one element of A (one row of B) per iteration
        for(; k < num_elems_vec_a; ++k)
        {
            const qint8x8_t a0 = vld1_dup_qs8(vec_a);

            const qint8x8_t b00 = vld1_qs8(matrix_b + 0);
            const qint8x8_t b01 = vld1_qs8(matrix_b + 8);
            const qint8x8_t b02 = vld1_qs8(matrix_b + 16);
            const qint8x8_t b03 = vld1_qs8(matrix_b + 24);

            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Convert back to qint8x8_t and saturate
        qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
        qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
        qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
        qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            const qint8x8_t alpha_qs8 = vdup_n_qs8(scvt_qs8_f32(alpha, fixed_point_position));
            acc00_qs8                 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
            acc01_qs8                 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
            acc02_qs8                 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
            acc03_qs8                 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());

        // Store the 32 output values (4 vectors of 8 elements)
        vst1_qs8(mtx_out0 + 0, acc00_qs8);
        vst1_qs8(mtx_out0 + 8, acc01_qs8);
        vst1_qs8(mtx_out0 + 16, acc02_qs8);
        vst1_qs8(mtx_out0 + 24, acc03_qs8);
    },
    ina, inb, out);
}
335
/** Matrix-by-matrix product for F32 data: output = [alpha *] (A * B).
 *
 * The implementation assumes that matrix A and matrix B have been reshaped respectively with
 * NEGEMMInterleave4x4 and NEGEMMTranspose1xW, so that all the values needed for a 4x4 output block
 * are read from consecutive memory positions.
 * Every invocation of the window-loop body produces two adjacent 4x4 output blocks (4 rows x 8 columns).
 *
 * @tparam multiply_alpha When true the accumulated values are multiplied by @p alpha before being stored.
 *
 * @param[in]  input0 Reshaped matrix A. Its Y range is the output Y range divided by 4.
 * @param[in]  input1 Reshaped matrix B. Two consecutive rows are consumed per iteration.
 * @param[out] output Destination tensor, iterated with @p window.
 * @param[in]  window Execution window of the output.
 * @param[in]  alpha  Weight of the matrix product. Ignored when @p multiply_alpha is false.
 */
template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    // Row strides expressed in elements
    const size_t row_stride_b   = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t row_stride_out = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    // Number of elements in one row of the reshaped matrix B
    const int row_length_b = input1->info()->dimension(0);

    // Window of matrix A: X is not iterated; the Y range is divided by 4 because the interleaved
    // matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    // Window of matrix B: matrix B is sliced along the higher dimensions only if it has at least 3 dimensions.
    // A 2D matrix B combined with a matrix A with more dimensions happens when the matrix
    // multiplication is used to perform a convolution operation
    Window win_b;
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // The X range is divided by 4 because the transposed matrix B has 4 times fewer columns than the output matrix.
    // The step along X is 2 * row_stride_b because each iteration computes 2 blocks of size 4x4
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * row_stride_b));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        const float *a_ptr  = reinterpret_cast<const float *>(ina.ptr());
        const float *b0_ptr = reinterpret_cast<const float *>(inb.ptr());
        const float *b1_ptr = b0_ptr + row_stride_b;

        // accRC: row R of the output block, C = 0 for the block fed by b0_ptr, C = 1 for the one fed by b1_ptr
        float32x4_t acc00 = vdupq_n_f32(0.f);
        float32x4_t acc10 = vdupq_n_f32(0.f);
        float32x4_t acc20 = vdupq_n_f32(0.f);
        float32x4_t acc30 = vdupq_n_f32(0.f);

        float32x4_t acc01 = vdupq_n_f32(0.f);
        float32x4_t acc11 = vdupq_n_f32(0.f);
        float32x4_t acc21 = vdupq_n_f32(0.f);
        float32x4_t acc31 = vdupq_n_f32(0.f);

#if __arm__
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(a_ptr)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(b0_ptr)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(b1_ptr)));
#endif

        // One accumulation step: 4 values of A (one per output row) are broadcast and multiplied by
        // 4 values of each of the two rows of B, then the three pointers advance by 4 elements
        const auto accumulate_step = [&]()
        {
            const float32x4_t a_row0 = vld1q_dup_f32(a_ptr + 0);
            const float32x4_t a_row1 = vld1q_dup_f32(a_ptr + 1);
            const float32x4_t a_row2 = vld1q_dup_f32(a_ptr + 2);
            const float32x4_t a_row3 = vld1q_dup_f32(a_ptr + 3);

            const float32x4_t b_blk0 = vld1q_f32(b0_ptr);
            const float32x4_t b_blk1 = vld1q_f32(b1_ptr);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b_blk0, a_row0);
            acc10 = vmlaq_f32(acc10, b_blk0, a_row1);
            acc20 = vmlaq_f32(acc20, b_blk0, a_row2);
            acc30 = vmlaq_f32(acc30, b_blk0, a_row3);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b_blk1, a_row0);
            acc11 = vmlaq_f32(acc11, b_blk1, a_row1);
            acc21 = vmlaq_f32(acc21, b_blk1, a_row2);
            acc31 = vmlaq_f32(acc31, b_blk1, a_row3);

            a_ptr += 4;
            b0_ptr += 4;
            b1_ptr += 4;
        };

        const float *const b0_end = b0_ptr + row_length_b;

        // Main loop: 32 elements of each row of B per iteration, handled as 2 passes of 4 steps.
        // The prefetch hints are issued once at the beginning of every pass
        while(b0_ptr <= (b0_end - 32))
        {
            for(int pass = 0; pass < 2; ++pass)
            {
#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(a_ptr)));
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(b0_ptr)));
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(b1_ptr)));
#endif
                accumulate_step();
                accumulate_step();
                accumulate_step();
                accumulate_step();
            }
        }

        // Leftover loop: a single step per iteration
        while(b0_ptr < b0_end)
        {
#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(a_ptr)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(b0_ptr)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(b1_ptr)));
#endif
            accumulate_step();
        }

        // Scale by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

            acc00 = vmulq_f32(acc00, alpha_f32);
            acc10 = vmulq_f32(acc10, alpha_f32);
            acc20 = vmulq_f32(acc20, alpha_f32);
            acc30 = vmulq_f32(acc30, alpha_f32);
            acc01 = vmulq_f32(acc01, alpha_f32);
            acc11 = vmulq_f32(acc11, alpha_f32);
            acc21 = vmulq_f32(acc21, alpha_f32);
            acc31 = vmulq_f32(acc31, alpha_f32);
        }

        // Start of each of the 4 output rows
        float *const out_row0 = reinterpret_cast<float *>(out.ptr());
        float *const out_row1 = out_row0 + row_stride_out;
        float *const out_row2 = out_row0 + row_stride_out * 2;
        float *const out_row3 = out_row0 + row_stride_out * 3;

        // Store the two 4x4 blocks: block 0 in columns [0, 4), block 1 in columns [4, 8)
        vst1q_f32(out_row0, acc00);
        vst1q_f32(out_row0 + 4, acc01);
        vst1q_f32(out_row1, acc10);
        vst1q_f32(out_row1 + 4, acc11);
        vst1q_f32(out_row2, acc20);
        vst1q_f32(out_row2 + 4, acc21);
        vst1q_f32(out_row3, acc30);
        vst1q_f32(out_row3 + 4, acc31);
    },
    ina, inb, out);
}
637
/** Matrix-by-matrix product for F16 data: output = [alpha *] (A * B).
 *
 * Only available when ARM_COMPUTE_ENABLE_FP16 is defined; otherwise an error is raised.
 * Every invocation of the window-loop body produces a block of 4 rows x 8 columns of the output.
 *
 * @note Assumes matrix A was reshaped with a 4x4 interleave and matrix B with a 1x8 transposition,
 *       as pictured in the comment inside the loop body.
 *
 * @tparam multiply_alpha When true the accumulated values are multiplied by @p alpha before being stored.
 *
 * @param[in]  input0 Reshaped matrix A. Its Y range is the output Y range divided by 4.
 * @param[in]  input1 Reshaped matrix B. Its X range is the output X range divided by 8.
 * @param[out] output Destination tensor, iterated with @p window.
 * @param[in]  window Execution window of the output.
 * @param[in]  alpha  Weight of the matrix product. Ignored when @p multiply_alpha is false.
 */
template <bool multiply_alpha>
void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
#ifdef ARM_COMPUTE_ENABLE_FP16
    const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride  = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());

    // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix B has 8 times less the cols of the output matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    // One accumulation step consumes 8 values of B and 4 values of A
    const size_t num_steps = (input1->info()->dimension(0)) >> 3;
    // The main loop performs 4 accumulation steps per iteration: num_it = (width_mtx_b / 8) / 4
    const size_t num_it = num_steps >> 2;
    // Steps that do not fill a whole main-loop iteration. Without them the tail of the
    // accumulation would be dropped whenever the number of steps is not a multiple of 4
    const size_t num_left = num_steps & 3;

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto   *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto   *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto         *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        /*
        This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
             |a00 a01 a02 a03 | a04 a05 a06 a07|
             |a10 a11 a12 a13 | a14 a15 a16 a17|
             |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
             |a30 a31 a32 a33 | a34 a35 a36 a37|   | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
             |a40 a41 a42 a43 | a44 a45 a46 a47|
             |a50 a51 a52 a53 | a54 a55 a56 a57|
             |a60 a61 a62 a63 | a64 a65 a66 a67|
             |a70 a71 a72 a73 | a74 a75 a76 a77|

             After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]

        B Matrix has been transposed as shown below

           |b00 b01 b02 b03 b04 b05 b06 b07|
           |b10 b11 b12 b13 b14 b15 b16 b17|
           |b20 b21 b22 b23 b24 b25 b26 b27|
           |b30 b31 b32 b33 b34 b35 b36 b37|
          ------------------->

           |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|

            c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
            c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31

        The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
        */

        // Main loop: 4 accumulation steps per iteration (16 values of A, 32 values of B)
        for(size_t k = num_it; k > 0; mtx_a0 += 16, mtx_b0 += 32, --k)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);
            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));
        }

        // Leftover loop: a single accumulation step per iteration (4 values of A, 8 values of B)
        for(size_t k = num_left; k > 0; mtx_a0 += 4, mtx_b0 += 8, --k)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }

        // Store the 4 rows of 8 output values
        vst1q_f16(mtx_out + 0 * out_stride, c.val[0]);
        vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
        vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
    },
    ina, inb, out);
#else
    // None of the arguments is used when FP16 support is compiled out
    ARM_COMPUTE_UNUSED(input0);
    ARM_COMPUTE_UNUSED(input1);
    ARM_COMPUTE_UNUSED(output);
    ARM_COMPUTE_UNUSED(window);
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_ERROR("Not implemented");
#endif
}
761
762template <bool multiply_alpha>
763void matrix_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
764{
765 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
766 const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
767 const size_t out_stride2 = out_stride1 * 2;
768 const size_t out_stride3 = out_stride1 * 3;
769 const int num_elems_matrix_b_x = input1->info()->dimension(0);
770 const int fixed_point_position = input0->info()->fixed_point_position();
771 const qint8x8_t alpha_qs8 = vdup_n_qs8(scvt_qs8_f32(alpha, fixed_point_position));
772 ARM_COMPUTE_UNUSED(alpha_qs8);
773
774 // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
775 Window win_a(window);
776 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
777 win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
778
779 Window win_b;
780 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
781 // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
782 if(input1->info()->num_dimensions() >= 3)
783 {
784 win_b = window;
785 }
786 // Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the cols of the output matrix
787 // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 16x4
788 win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, 2 * in_b_stride));
789 win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
790
791 Iterator ina(input0, win_a);
792 Iterator inb(input1, win_b);
793 Iterator out(output, window);
794
795 // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
796 // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
797 // All the values needed for computing a single 32x4 block will be read from consecutive memory positions
798 execute_window_loop(window, [&](const Coordinates & id)
799 {
800 auto mtx_a0 = reinterpret_cast<const qint8_t *>(ina.ptr());
801 auto mtx_b0 = reinterpret_cast<const qint8_t *>(inb.ptr());
802 auto mtx_b1 = mtx_b0 + in_b_stride;
803
804 qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
805 qint16x8_t acc10_qs16 = vdupq_n_qs16(0);
806 qint16x8_t acc20_qs16 = vdupq_n_qs16(0);
807 qint16x8_t acc30_qs16 = vdupq_n_qs16(0);
808
809 qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
810 qint16x8_t acc11_qs16 = vdupq_n_qs16(0);
811 qint16x8_t acc21_qs16 = vdupq_n_qs16(0);
812 qint16x8_t acc31_qs16 = vdupq_n_qs16(0);
813
814 qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
815 qint16x8_t acc12_qs16 = vdupq_n_qs16(0);
816 qint16x8_t acc22_qs16 = vdupq_n_qs16(0);
817 qint16x8_t acc32_qs16 = vdupq_n_qs16(0);
818
819 qint16x8_t acc03_qs16 = vdupq_n_qs16(0);
820 qint16x8_t acc13_qs16 = vdupq_n_qs16(0);
821 qint16x8_t acc23_qs16 = vdupq_n_qs16(0);
822 qint16x8_t acc33_qs16 = vdupq_n_qs16(0);
823
824 int k = 0;
825 // This for loop performs 2 accumulations
826 for(; k <= (num_elems_matrix_b_x - 32); k += 32)
827 {
828 const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
829 const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
830 const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
831 const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);
832 const qint8x8_t a4 = vld1_dup_qs8(mtx_a0 + 4);
833 const qint8x8_t a5 = vld1_dup_qs8(mtx_a0 + 5);
834 const qint8x8_t a6 = vld1_dup_qs8(mtx_a0 + 6);
835 const qint8x8_t a7 = vld1_dup_qs8(mtx_a0 + 7);
836
837 const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
838 const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
839 const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
840 const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);
841
842 // First accumulation
843 acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
844 acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
845 acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
846 acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
847 acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
848 acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
849 acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
850 acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);
851
852 const qint8x8_t b02 = vld1_qs8(mtx_b0 + 16);
853 const qint8x8_t b03 = vld1_qs8(mtx_b0 + 24);
854 const qint8x8_t b12 = vld1_qs8(mtx_b1 + 16);
855 const qint8x8_t b13 = vld1_qs8(mtx_b1 + 24);
856
857 acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
858 acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
859 acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
860 acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
861 acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
862 acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
863 acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
864 acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);
865
866#if __arm__
867 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
868 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
869 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
870#endif
871
872 // Second accumulation
873 acc00_qs16 = vqmlal_qs8(acc00_qs16, b02, a4, fixed_point_position);
874 acc10_qs16 = vqmlal_qs8(acc10_qs16, b02, a5, fixed_point_position);
875 acc20_qs16 = vqmlal_qs8(acc20_qs16, b02, a6, fixed_point_position);
876 acc30_qs16 = vqmlal_qs8(acc30_qs16, b02, a7, fixed_point_position);
877 acc01_qs16 = vqmlal_qs8(acc01_qs16, b03, a4, fixed_point_position);
878 acc11_qs16 = vqmlal_qs8(acc11_qs16, b03, a5, fixed_point_position);
879 acc21_qs16 = vqmlal_qs8(acc21_qs16, b03, a6, fixed_point_position);
880 acc31_qs16 = vqmlal_qs8(acc31_qs16, b03, a7, fixed_point_position);
881 acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a4, fixed_point_position);
882 acc12_qs16 = vqmlal_qs8(acc12_qs16, b12, a5, fixed_point_position);
883 acc22_qs16 = vqmlal_qs8(acc22_qs16, b12, a6, fixed_point_position);
884 acc32_qs16 = vqmlal_qs8(acc32_qs16, b12, a7, fixed_point_position);
885 acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a4, fixed_point_position);
886 acc13_qs16 = vqmlal_qs8(acc13_qs16, b13, a5, fixed_point_position);
887 acc23_qs16 = vqmlal_qs8(acc23_qs16, b13, a6, fixed_point_position);
888 acc33_qs16 = vqmlal_qs8(acc33_qs16, b13, a7, fixed_point_position);
889
890 mtx_a0 += 8;
891 mtx_b0 += 32;
892 mtx_b1 += 32;
893 }
894
895 // This for loop performs the left over accumulations
896 for(; k < num_elems_matrix_b_x; k += 16)
897 {
898 const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
899 const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
900 const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
901 const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);
902
903 const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
904 const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
905 const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
906 const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);
907
908 acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
909 acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
910 acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
911 acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
912 acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
913 acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
914 acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
915 acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
916 acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
917 acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
918 acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
919 acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);
920 acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
921 acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
922 acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
923 acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);
924
925 mtx_a0 += 4;
926 mtx_b0 += 16;
927 mtx_b1 += 16;
928 }
929
930 // Convert back to qint8x8_t and saturate
931 qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
932 qint8x8_t acc10_qs8 = vqmovn_qs16(acc10_qs16);
933 qint8x8_t acc20_qs8 = vqmovn_qs16(acc20_qs16);
934 qint8x8_t acc30_qs8 = vqmovn_qs16(acc30_qs16);
935
936 qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
937 qint8x8_t acc11_qs8 = vqmovn_qs16(acc11_qs16);
938 qint8x8_t acc21_qs8 = vqmovn_qs16(acc21_qs16);
939 qint8x8_t acc31_qs8 = vqmovn_qs16(acc31_qs16);
940
941 qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
942 qint8x8_t acc12_qs8 = vqmovn_qs16(acc12_qs16);
943 qint8x8_t acc22_qs8 = vqmovn_qs16(acc22_qs16);
944 qint8x8_t acc32_qs8 = vqmovn_qs16(acc32_qs16);
945
946 qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);
947 qint8x8_t acc13_qs8 = vqmovn_qs16(acc13_qs16);
948 qint8x8_t acc23_qs8 = vqmovn_qs16(acc23_qs16);
949 qint8x8_t acc33_qs8 = vqmovn_qs16(acc33_qs16);
950
951 // Multiply by the weight of the matrix product (alpha)
952 if(multiply_alpha)
953 {
954 acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
955 acc10_qs8 = vqmul_qs8(acc10_qs8, alpha_qs8, fixed_point_position);
956 acc20_qs8 = vqmul_qs8(acc20_qs8, alpha_qs8, fixed_point_position);
957 acc30_qs8 = vqmul_qs8(acc30_qs8, alpha_qs8, fixed_point_position);
958 acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
959 acc11_qs8 = vqmul_qs8(acc11_qs8, alpha_qs8, fixed_point_position);
960 acc21_qs8 = vqmul_qs8(acc21_qs8, alpha_qs8, fixed_point_position);
961 acc31_qs8 = vqmul_qs8(acc31_qs8, alpha_qs8, fixed_point_position);
962 acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
963 acc12_qs8 = vqmul_qs8(acc12_qs8, alpha_qs8, fixed_point_position);
964 acc22_qs8 = vqmul_qs8(acc22_qs8, alpha_qs8, fixed_point_position);
965 acc32_qs8 = vqmul_qs8(acc32_qs8, alpha_qs8, fixed_point_position);
966 acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
967 acc13_qs8 = vqmul_qs8(acc13_qs8, alpha_qs8, fixed_point_position);
968 acc23_qs8 = vqmul_qs8(acc23_qs8, alpha_qs8, fixed_point_position);
969 acc33_qs8 = vqmul_qs8(acc33_qs8, alpha_qs8, fixed_point_position);
970 }
971
972 const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());
973
974 // Store 32x4 output elements
975 vst1_qs8(mtx_out0 + 0, acc00_qs8);
976 vst1_qs8(mtx_out0 + 8, acc01_qs8);
977 vst1_qs8(mtx_out0 + 16, acc02_qs8);
978 vst1_qs8(mtx_out0 + 24, acc03_qs8);
979 vst1_qs8(mtx_out0 + out_stride1 + 0, acc10_qs8);
980 vst1_qs8(mtx_out0 + out_stride1 + 8, acc11_qs8);
981 vst1_qs8(mtx_out0 + out_stride1 + 16, acc12_qs8);
982 vst1_qs8(mtx_out0 + out_stride1 + 24, acc13_qs8);
983 vst1_qs8(mtx_out0 + out_stride2 + 0, acc20_qs8);
984 vst1_qs8(mtx_out0 + out_stride2 + 8, acc21_qs8);
985 vst1_qs8(mtx_out0 + out_stride2 + 16, acc22_qs8);
986 vst1_qs8(mtx_out0 + out_stride2 + 24, acc23_qs8);
987 vst1_qs8(mtx_out0 + out_stride3 + 0, acc30_qs8);
988 vst1_qs8(mtx_out0 + out_stride3 + 8, acc31_qs8);
989 vst1_qs8(mtx_out0 + out_stride3 + 16, acc32_qs8);
990 vst1_qs8(mtx_out0 + out_stride3 + 24, acc33_qs8);
991 },
992 ina, inb, out);
993}
994
995} // namespace
996
997NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
998 : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
999{
1000}
1001
1002void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha)
1003{
1004 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8);
1005 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32, DataType::QS8);
1006 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8);
1007 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8);
1008 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
1009 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
1010
1011 if(output->info()->dimension(1) == 1)
1012 {
1013 ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
1014 }
1015
1016 _input0 = input0;
1017 _input1 = input1;
1018 _output = output;
1019 _alpha = alpha;
1020
1021 unsigned int num_elems_processed_per_iteration_x = 0;
1022 const unsigned int num_elems_processed_per_iteration_y = 4;
1023
1024 // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication
1025 if((output->info()->dimension(1) == 1))
1026 {
1027 switch(input0->info()->data_type())
1028 {
1029 case DataType::F32:
1030 {
1031 num_elems_processed_per_iteration_x = 16;
1032 break;
1033 }
1034 case DataType::QS8:
1035 {
1036 num_elems_processed_per_iteration_x = 32;
1037 break;
1038 }
1039 default:
1040 {
1041 ARM_COMPUTE_ERROR("Data type not supported");
1042 break;
1043 }
1044 }
1045
1046 // Configure kernel window
1047 Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
1048
1049 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration_x);
1050
1051 update_window_and_padding(win,
1052 AccessWindowHorizontal(input0->info(), 0, num_elems_processed_per_iteration_x),
1053 AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration_x),
1054 output_access);
1055
1056 Coordinates coord;
1057 coord.set_num_dimensions(output->info()->num_dimensions());
1058 output_access.set_valid_region(win, ValidRegion(coord, output->info()->tensor_shape()));
1059
1060 INEKernel::configure(win);
1061 }
1062 else
1063 {
1064 switch(input0->info()->data_type())
1065 {
1066 case DataType::F32:
1067 {
1068 num_elems_processed_per_iteration_x = 8;
1069 break;
1070 }
1071 case DataType::QS8:
1072 {
1073 num_elems_processed_per_iteration_x = 32;
1074 break;
1075 }
1076 case DataType::F16:
1077 {
1078#ifdef ARM_COMPUTE_ENABLE_FP16
1079 num_elems_processed_per_iteration_x = 8;
1080 break;
1081#endif
1082 }
1083 default:
1084 {
1085 ARM_COMPUTE_ERROR("Data type not supported");
1086 break;
1087 }
1088 }
1089
1090 // Configure kernel window
1091 Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
1092
1093 AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
1094
1095 update_window_and_padding(win,
1096 AccessWindowRectangle(input0->info(), 0, 0, 4, 1, 1.f, 0.25f),
1097 AccessWindowTranspose(input1->info(), 0, 0, 4, 1, 0.f, 0.25f),
1098 output_access);
1099
1100 output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));
1101
1102 INEKernel::configure(win);
1103 }
1104}
1105
1106void NEGEMMMatrixMultiplyKernel::run(const Window &window)
1107{
1108 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1109 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1110
1111 bool multiply_alpha = std::abs(1.0f - _alpha) > 0.00001f;
1112
1113 // Check if the output tensor is a vector and the data type is F32. If so,the kernel runs the vector-matrix multiplication
1114 if((_output->info()->dimension(1) == 1))
1115 {
1116 switch(_input0->info()->data_type())
1117 {
1118 case DataType::F32:
1119 {
1120 multiply_alpha ? vector_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
1121 vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
1122 break;
1123 }
1124 case DataType::QS8:
1125 {
1126 multiply_alpha ? vector_matrix_multiply_qs8<true>(_input0, _input1, _output, window, _alpha) :
1127 vector_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
1128 break;
1129 }
1130 default:
1131 {
1132 ARM_COMPUTE_ERROR("Data type not supported");
1133 break;
1134 }
1135 }
1136 }
1137 else
1138 {
1139 switch(_input0->info()->data_type())
1140 {
1141 case DataType::F32:
1142 {
1143 multiply_alpha ? matrix_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
1144 matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
1145 break;
1146 }
1147 case DataType::QS8:
1148 {
1149 multiply_alpha ? matrix_matrix_multiply_qs8<true>(_input0, _input1, _output, window, _alpha) :
1150 matrix_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
1151 break;
1152 }
1153 case DataType::F16:
1154 {
1155#ifdef ARM_COMPUTE_ENABLE_FP16
1156 multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
1157 matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
1158 break;
1159#endif
1160 }
1161 default:
1162 {
1163 ARM_COMPUTE_ERROR("Data type not supported");
1164 break;
1165 }
1166 }
1167 }
1168}