/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/AccessWindowTranspose.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>
#include <tuple>

using namespace arm_compute;

namespace arm_compute
{
class Coordinates;
} // namespace arm_compute

namespace
{
template <bool multiply_alpha>
void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
#ifdef ARM_COMPUTE_AARCH64_V8_2
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
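
    // Worked example of the partitioning above (illustrative values, not taken from the
    // kernel): with num_threads = 4, thread_id = 1 and width_matrix_b = 200:
    //   window_start_x = 32 * 1 = 32
    //   window_step_x  = 32 * 4 = 128
    //   window_end_x   = ceil_to_multiple(200 - 32, 128) + 32 = 288
    // so this thread handles the 32-wide column chunks starting at x = 32 and x = 160,
    // the other threads cover the interleaved chunks, and any chunk starting past the
    // matrix width is skipped by the id.x() guard inside the loop below.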

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
    ARM_COMPUTE_UNUSED(alpha_f16);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        float16x8_t acc0 = vdupq_n_f16(0.f);
        float16x8_t acc1 = vdupq_n_f16(0.f);
        float16x8_t acc2 = vdupq_n_f16(0.f);
        float16x8_t acc3 = vdupq_n_f16(0.f);

        auto vec_a    = reinterpret_cast<const float16_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr());

        const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 4);)
        {
            const float16x4_t a0l = vld1_f16(vec_a);

            float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

            matrix_b += 2 * in_b_stride;

            b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

            vec_a += 4;
            matrix_b += 2 * in_b_stride;
        }

        for(; vec_a < vec_a_end_addr;)
        {
            const float16_t   a0  = *vec_a;
            const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
            acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
            acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
            acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            acc0 = vmulq_f16(acc0, alpha_f16);
            acc1 = vmulq_f16(acc1, alpha_f16);
            acc2 = vmulq_f16(acc2, alpha_f16);
            acc3 = vmulq_f16(acc3, alpha_f16);
        }

        const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());

        vst1q_f16(vec_out + 0, acc0);
        vst1q_f16(vec_out + 8, acc1);
        vst1q_f16(vec_out + 16, acc2);
        vst1q_f16(vec_out + 24, acc3);
    },
    ina, inb, out);
#else  /* ARM_COMPUTE_AARCH64_V8_2 */
    ARM_COMPUTE_UNUSED(input0);
    ARM_COMPUTE_UNUSED(input1);
    ARM_COMPUTE_UNUSED(output);
    ARM_COMPUTE_UNUSED(window);
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_ERROR("Not implemented");
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
}

template <bool multiply_alpha>
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
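
    // Reference semantics, as a scalar sketch of what the vectorized loop below computes
    // (illustrative only, not part of the kernel): for each output column x assigned to
    // this thread,
    //
    //   float sum = 0.f;
    //   for(int k = 0; k < num_elems_vec_a; ++k)
    //   {
    //       sum += vec_a[k] * matrix_b[k * in_b_stride + x];
    //   }
    //   vec_out[x] = multiply_alpha ? alpha * sum : sum;
    //
    // The NEON code below consumes four elements of vec_a and produces 16 output
    // columns (four float32x4_t accumulators) per main-loop iteration.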

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        float32x4_t acc0 = vdupq_n_f32(0.f);
        float32x4_t acc1 = vdupq_n_f32(0.f);
        float32x4_t acc2 = vdupq_n_f32(0.f);
        float32x4_t acc3 = vdupq_n_f32(0.f);

        auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float *>(inb.ptr());

#if __arm__
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

        auto vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 4);)
        {
            float32x2_t a0l = vld1_f32(vec_a);

            float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;

            a0l = vld1_f32(vec_a);

            b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

        for(; vec_a < vec_a_end_addr;)
        {
            const float a0 = *vec_a;

            const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            acc0 = vmlaq_n_f32(acc0, b00, a0);
            acc1 = vmlaq_n_f32(acc1, b01, a0);
            acc2 = vmlaq_n_f32(acc2, b02, a0);
            acc3 = vmlaq_n_f32(acc3, b03, a0);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
            acc0 = vmulq_f32(acc0, alpha_f32);
            acc1 = vmulq_f32(acc1, alpha_f32);
            acc2 = vmulq_f32(acc2, alpha_f32);
            acc3 = vmulq_f32(acc3, alpha_f32);
        }

        const auto vec_out = reinterpret_cast<float *>(out.ptr());

        vst1q_f32(vec_out + 0, acc0);
        vst1q_f32(vec_out + 4, acc1);
        vst1q_f32(vec_out + 8, acc2);
        vst1q_f32(vec_out + 12, acc3);
    },
    ina, inb, out);
}

template <bool multiply_alpha>
void vector_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b       = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride          = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a      = static_cast<int>(input0->info()->dimension(0));
    const int  fixed_point_position = input0->info()->fixed_point_position();
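
    // Note on the fixed-point format (illustrative, with assumed example values): a QS8
    // value with fixed_point_position = n encodes raw_value / 2^n, so with n = 3 the raw
    // byte 20 represents 20 / 8 = 2.5. The vqmlal_qs8() calls below perform a saturating
    // widening multiply-accumulate that rescales each product by n bits, keeping the
    // 16-bit accumulators in the same Q format as the inputs.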

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        // Reset accumulators
        qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc03_qs16 = vdupq_n_qs16(0);

        auto vec_a    = reinterpret_cast<const qint8_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const qint8_t *>(inb.ptr());

        auto vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 2);)
        {
            const qint8x8_t a0 = vld1_dup_qs8(vec_a + 0);
            const qint8x8_t a1 = vld1_dup_qs8(vec_a + 1);

            const qint8x8_t b00 = vld1_qs8(matrix_b + 0 + 0 * in_b_stride);
            const qint8x8_t b01 = vld1_qs8(matrix_b + 8 + 0 * in_b_stride);
            const qint8x8_t b02 = vld1_qs8(matrix_b + 16 + 0 * in_b_stride);
            const qint8x8_t b03 = vld1_qs8(matrix_b + 24 + 0 * in_b_stride);
            const qint8x8_t b10 = vld1_qs8(matrix_b + 0 + 1 * in_b_stride);
            const qint8x8_t b11 = vld1_qs8(matrix_b + 8 + 1 * in_b_stride);
            const qint8x8_t b12 = vld1_qs8(matrix_b + 16 + 1 * in_b_stride);
            const qint8x8_t b13 = vld1_qs8(matrix_b + 24 + 1 * in_b_stride);

            // First accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);

            // Second accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b10, a1, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b11, a1, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a1, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a1, fixed_point_position);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

        for(; vec_a < vec_a_end_addr;)
        {
            const qint8x8_t a0 = vld1_dup_qs8(vec_a);

            const qint8x8_t b00 = vld1_qs8(matrix_b + 0);
            const qint8x8_t b01 = vld1_qs8(matrix_b + 8);
            const qint8x8_t b02 = vld1_qs8(matrix_b + 16);
            const qint8x8_t b03 = vld1_qs8(matrix_b + 24);

            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Convert back to qint8x8_t and saturate
        qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
        qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
        qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
        qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            const qint8x8_t alpha_qs8 = vdup_n_qs8(sqcvt_qs8_f32(alpha, fixed_point_position));
            acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
            acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
            acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
            acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());

        // Store 8x4 output elements
        vst1_qs8(mtx_out0 + 0, acc00_qs8);
        vst1_qs8(mtx_out0 + 8, acc01_qs8);
        vst1_qs8(mtx_out0 + 16, acc02_qs8);
        vst1_qs8(mtx_out0 + 24, acc03_qs8);
    },
    ina, inb, out);
}

template <bool multiply_alpha>
void vector_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b       = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride          = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a      = static_cast<int>(input0->info()->dimension(0));
    const int  fixed_point_position = input0->info()->fixed_point_position();

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        // Reset accumulators
        qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc02_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc03_qs32 = vdupq_n_qs32(0);

        auto vec_a    = reinterpret_cast<const qint16_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const qint16_t *>(inb.ptr());

        auto vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 2);)
        {
            const qint16x4_t a0 = vld1_dup_qs16(vec_a + 0);
            const qint16x4_t a1 = vld1_dup_qs16(vec_a + 1);

            const qint16x4_t b00 = vld1_qs16(matrix_b + 0 + 0 * in_b_stride);
            const qint16x4_t b01 = vld1_qs16(matrix_b + 4 + 0 * in_b_stride);
            const qint16x4_t b02 = vld1_qs16(matrix_b + 8 + 0 * in_b_stride);
            const qint16x4_t b03 = vld1_qs16(matrix_b + 12 + 0 * in_b_stride);
            const qint16x4_t b10 = vld1_qs16(matrix_b + 0 + 1 * in_b_stride);
            const qint16x4_t b11 = vld1_qs16(matrix_b + 4 + 1 * in_b_stride);
            const qint16x4_t b12 = vld1_qs16(matrix_b + 8 + 1 * in_b_stride);
            const qint16x4_t b13 = vld1_qs16(matrix_b + 12 + 1 * in_b_stride);

            // First accumulation
            acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
            acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
            acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
            acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);

            // Second accumulation
            acc00_qs32 = vqmlal_qs16(acc00_qs32, b10, a1, fixed_point_position);
            acc01_qs32 = vqmlal_qs16(acc01_qs32, b11, a1, fixed_point_position);
            acc02_qs32 = vqmlal_qs16(acc02_qs32, b12, a1, fixed_point_position);
            acc03_qs32 = vqmlal_qs16(acc03_qs32, b13, a1, fixed_point_position);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

        for(; vec_a < vec_a_end_addr;)
        {
            const qint16x4_t a0 = vld1_dup_qs16(vec_a);

            const qint16x4_t b00 = vld1_qs16(matrix_b + 0);
            const qint16x4_t b01 = vld1_qs16(matrix_b + 4);
            const qint16x4_t b02 = vld1_qs16(matrix_b + 8);
            const qint16x4_t b03 = vld1_qs16(matrix_b + 12);

            acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
            acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
            acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
            acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Convert back to qint16x4_t and saturate
        qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
        qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
        qint16x4_t acc02_qs16 = vqmovn_qs32(acc02_qs32);
        qint16x4_t acc03_qs16 = vqmovn_qs32(acc03_qs32);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            const qint16x4_t alpha_qs16 = vdup_n_qs16(sqcvt_qs16_f32(alpha, fixed_point_position));
            acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
            acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
            acc02_qs16 = vqmul_qs16(acc02_qs16, alpha_qs16, fixed_point_position);
            acc03_qs16 = vqmul_qs16(acc03_qs16, alpha_qs16, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());

        // Store 16x4 output elements
        vst1_qs16(mtx_out0 + 0, acc00_qs16);
        vst1_qs16(mtx_out0 + 4, acc01_qs16);
        vst1_qs16(mtx_out0 + 8, acc02_qs16);
        vst1_qs16(mtx_out0 + 12, acc03_qs16);
    },
    ina, inb, out);
}

template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride1          = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const size_t out_stride2          = out_stride1 * 2;
    const size_t out_stride3          = out_stride1 * 3;
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the input interleaved matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 4 as the input transposed matrix B has 4 times fewer columns than the output matrix
    // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 8x4 elements per iteration
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
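
    // A scalar sketch of the traversal below (illustrative only, assuming the
    // NEGEMMInterleave4x4 / NEGEMMTranspose1xW layouts described above): each group of
    // four consecutive A values holds one column of a 4-row strip, and each group of
    // four consecutive B values holds one row of a 4-column strip, so one k-step of a
    // 4x4 block is
    //
    //   for(int r = 0; r < 4; ++r)
    //   {
    //       for(int c = 0; c < 4; ++c)
    //       {
    //           acc[r][c] += mtx_a0[4 * k + r] * mtx_b0[4 * k + c];
    //       }
    //   }
    //
    // The NEON code keeps each acc row in a float32x4_t, computes two such blocks side
    // by side (mtx_b0 and mtx_b1) and unrolls k by 8 in the main loop.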
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        float32x4_t acc00 = vdupq_n_f32(0.f);
        float32x4_t acc10 = vdupq_n_f32(0.f);
        float32x4_t acc20 = vdupq_n_f32(0.f);
        float32x4_t acc30 = vdupq_n_f32(0.f);

        float32x4_t acc01 = vdupq_n_f32(0.f);
        float32x4_t acc11 = vdupq_n_f32(0.f);
        float32x4_t acc21 = vdupq_n_f32(0.f);
        float32x4_t acc31 = vdupq_n_f32(0.f);

#if __arm__
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

        auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);

            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);
            float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
            float32x4_t b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
            float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
            float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
            float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);

            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }

        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            float32x4_t a0  = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1  = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2  = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3  = vld1q_dup_f32(mtx_a0 + 3);
            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);

#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            mtx_a0 += 4;
            mtx_b0 += 4;
            mtx_b1 += 4;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
            acc00 = vmulq_f32(acc00, alpha_f32);
            acc10 = vmulq_f32(acc10, alpha_f32);
            acc20 = vmulq_f32(acc20, alpha_f32);
            acc30 = vmulq_f32(acc30, alpha_f32);
            acc01 = vmulq_f32(acc01, alpha_f32);
            acc11 = vmulq_f32(acc11, alpha_f32);
            acc21 = vmulq_f32(acc21, alpha_f32);
            acc31 = vmulq_f32(acc31, alpha_f32);
        }

        const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
        const auto mtx_out1 = mtx_out0 + 4;

        // Store the 4 blocks
        vst1q_f32(mtx_out0, acc00);
        vst1q_f32(mtx_out1, acc01);
        vst1q_f32(mtx_out0 + out_stride1, acc10);
        vst1q_f32(mtx_out1 + out_stride1, acc11);
        vst1q_f32(mtx_out0 + out_stride2, acc20);
        vst1q_f32(mtx_out1 + out_stride2, acc21);
        vst1q_f32(mtx_out0 + out_stride3, acc30);
        vst1q_f32(mtx_out1 + out_stride3, acc31);
    },
    ina, inb, out);
}

template <bool multiply_alpha>
void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
#ifdef ARM_COMPUTE_AARCH64_V8_2
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride           = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the input interleaved matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 8 as the input transposed matrix B has 8 times fewer columns than the output matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 0));
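
    // The different X scale factors used across these kernels follow from the
    // NEGEMMTranspose1xW reshape, which packs rows of W = 16 / element_size elements:
    // 1x4 blocks for F32, 1x8 for F16 and QS16, and 1x16 for QS8. That is why this
    // kernel divides the X range by 8 while the F32 kernel above divides it by 4.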

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        /*
        This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
             |a00 a01 a02 a03 | a04 a05 a06 a07|
             |a10 a11 a12 a13 | a14 a15 a16 a17|
             |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
             |a30 a31 a32 a33 | a34 a35 a36 a37|   | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
             |a40 a41 a42 a43 | a44 a45 a46 a47|
             |a50 a51 a52 a53 | a54 a55 a56 a57|
             |a60 a61 a62 a63 | a64 a65 a66 a67|
             |a70 a71 a72 a73 | a74 a75 a76 a77|

             After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]

        B Matrix has been transposed as shown below

           |b00 b01 b02 b03 b04 b05 b06 b07|
           |b10 b11 b12 b13 b14 b15 b16 b17|
           |b20 b21 b22 b23 b24 b25 b26 b27|
           |b30 b31 b32 b33 b34 b35 b36 b37|
          ------------------->

           |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|

            c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
            c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31

        The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
        */
        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }

        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }

        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }

        vst1q_f16(mtx_out + 0 * out_stride, c.val[0]);
        vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
        vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
    },
    ina, inb, out);
#else  /* ARM_COMPUTE_AARCH64_V8_2 */
    ARM_COMPUTE_UNUSED(input0);
    ARM_COMPUTE_UNUSED(input1);
    ARM_COMPUTE_UNUSED(output);
    ARM_COMPUTE_UNUSED(window);
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_ERROR("Not implemented");
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
}

template <bool multiply_alpha>
void matrix_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const size_t    in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t    out_stride1          = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const size_t    out_stride2          = out_stride1 * 2;
    const size_t    out_stride3          = out_stride1 * 3;
    const int       num_elems_matrix_b_x = input1->info()->dimension(0);
    const int       fixed_point_position = input0->info()->fixed_point_position();
    const qint8x8_t alpha_qs8            = vdup_n_qs8(sqcvt_qs8_f32(alpha, fixed_point_position));
    ARM_COMPUTE_UNUSED(alpha_qs8);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the input interleaved matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 16 as the input transposed matrix B has 16 times fewer columns than the output matrix
    // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 16x4
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 32x4 elements per iteration
    // All the values needed for computing a single 32x4 block will be read from consecutive memory positions
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const qint8_t *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const qint8_t *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc10_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc20_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc30_qs16 = vdupq_n_qs16(0);

        qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc11_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc21_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc31_qs16 = vdupq_n_qs16(0);

        qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc12_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc22_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc32_qs16 = vdupq_n_qs16(0);

        qint16x8_t acc03_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc13_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc23_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc33_qs16 = vdupq_n_qs16(0);

        int k = 0;
        // This for loop performs 2 accumulations
        for(; k <= (num_elems_matrix_b_x - 32); k += 32)
        {
            const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
            const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
            const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
            const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);
            const qint8x8_t a4 = vld1_dup_qs8(mtx_a0 + 4);
            const qint8x8_t a5 = vld1_dup_qs8(mtx_a0 + 5);
            const qint8x8_t a6 = vld1_dup_qs8(mtx_a0 + 6);
            const qint8x8_t a7 = vld1_dup_qs8(mtx_a0 + 7);

            const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
            const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
            const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
            const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);

            // First accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
            acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
            acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
            acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
            acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
            acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);

            const qint8x8_t b02 = vld1_qs8(mtx_b0 + 16);
            const qint8x8_t b03 = vld1_qs8(mtx_b0 + 24);
            const qint8x8_t b12 = vld1_qs8(mtx_b1 + 16);
            const qint8x8_t b13 = vld1_qs8(mtx_b1 + 24);

            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
            acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
            acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
            acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
            acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
            acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);

#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // Second accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b02, a4, fixed_point_position);
            acc10_qs16 = vqmlal_qs8(acc10_qs16, b02, a5, fixed_point_position);
            acc20_qs16 = vqmlal_qs8(acc20_qs16, b02, a6, fixed_point_position);
            acc30_qs16 = vqmlal_qs8(acc30_qs16, b02, a7, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b03, a4, fixed_point_position);
            acc11_qs16 = vqmlal_qs8(acc11_qs16, b03, a5, fixed_point_position);
            acc21_qs16 = vqmlal_qs8(acc21_qs16, b03, a6, fixed_point_position);
            acc31_qs16 = vqmlal_qs8(acc31_qs16, b03, a7, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a4, fixed_point_position);
            acc12_qs16 = vqmlal_qs8(acc12_qs16, b12, a5, fixed_point_position);
            acc22_qs16 = vqmlal_qs8(acc22_qs16, b12, a6, fixed_point_position);
            acc32_qs16 = vqmlal_qs8(acc32_qs16, b12, a7, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a4, fixed_point_position);
            acc13_qs16 = vqmlal_qs8(acc13_qs16, b13, a5, fixed_point_position);
            acc23_qs16 = vqmlal_qs8(acc23_qs16, b13, a6, fixed_point_position);
            acc33_qs16 = vqmlal_qs8(acc33_qs16, b13, a7, fixed_point_position);

            mtx_a0 += 8;
            mtx_b0 += 32;
            mtx_b1 += 32;
        }

        // This for loop performs the left over accumulations
        for(; k < num_elems_matrix_b_x; k += 16)
        {
            const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
            const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
            const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
            const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);

            const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
            const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
            const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
            const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);

            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
            acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
            acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
            acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
            acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
            acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
            acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
            acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
            acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
            acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
            acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);

            mtx_a0 += 4;
            mtx_b0 += 16;
            mtx_b1 += 16;
        }

        // Convert back to qint8x8_t and saturate
        qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
        qint8x8_t acc10_qs8 = vqmovn_qs16(acc10_qs16);
        qint8x8_t acc20_qs8 = vqmovn_qs16(acc20_qs16);
        qint8x8_t acc30_qs8 = vqmovn_qs16(acc30_qs16);

        qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
        qint8x8_t acc11_qs8 = vqmovn_qs16(acc11_qs16);
        qint8x8_t acc21_qs8 = vqmovn_qs16(acc21_qs16);
        qint8x8_t acc31_qs8 = vqmovn_qs16(acc31_qs16);

        qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
        qint8x8_t acc12_qs8 = vqmovn_qs16(acc12_qs16);
        qint8x8_t acc22_qs8 = vqmovn_qs16(acc22_qs16);
        qint8x8_t acc32_qs8 = vqmovn_qs16(acc32_qs16);

        qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);
        qint8x8_t acc13_qs8 = vqmovn_qs16(acc13_qs16);
        qint8x8_t acc23_qs8 = vqmovn_qs16(acc23_qs16);
        qint8x8_t acc33_qs8 = vqmovn_qs16(acc33_qs16);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
            acc10_qs8 = vqmul_qs8(acc10_qs8, alpha_qs8, fixed_point_position);
            acc20_qs8 = vqmul_qs8(acc20_qs8, alpha_qs8, fixed_point_position);
            acc30_qs8 = vqmul_qs8(acc30_qs8, alpha_qs8, fixed_point_position);
            acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
            acc11_qs8 = vqmul_qs8(acc11_qs8, alpha_qs8, fixed_point_position);
            acc21_qs8 = vqmul_qs8(acc21_qs8, alpha_qs8, fixed_point_position);
            acc31_qs8 = vqmul_qs8(acc31_qs8, alpha_qs8, fixed_point_position);
            acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
            acc12_qs8 = vqmul_qs8(acc12_qs8, alpha_qs8, fixed_point_position);
            acc22_qs8 = vqmul_qs8(acc22_qs8, alpha_qs8, fixed_point_position);
            acc32_qs8 = vqmul_qs8(acc32_qs8, alpha_qs8, fixed_point_position);
            acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
            acc13_qs8 = vqmul_qs8(acc13_qs8, alpha_qs8, fixed_point_position);
            acc23_qs8 = vqmul_qs8(acc23_qs8, alpha_qs8, fixed_point_position);
            acc33_qs8 = vqmul_qs8(acc33_qs8, alpha_qs8, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());

        // Store 32x4 output elements
        vst1_qs8(mtx_out0 + 0, acc00_qs8);
        vst1_qs8(mtx_out0 + 8, acc01_qs8);
        vst1_qs8(mtx_out0 + 16, acc02_qs8);
        vst1_qs8(mtx_out0 + 24, acc03_qs8);
        vst1_qs8(mtx_out0 + out_stride1 + 0, acc10_qs8);
        vst1_qs8(mtx_out0 + out_stride1 + 8, acc11_qs8);
        vst1_qs8(mtx_out0 + out_stride1 + 16, acc12_qs8);
        vst1_qs8(mtx_out0 + out_stride1 + 24, acc13_qs8);
        vst1_qs8(mtx_out0 + out_stride2 + 0, acc20_qs8);
        vst1_qs8(mtx_out0 + out_stride2 + 8, acc21_qs8);
        vst1_qs8(mtx_out0 + out_stride2 + 16, acc22_qs8);
        vst1_qs8(mtx_out0 + out_stride2 + 24, acc23_qs8);
        vst1_qs8(mtx_out0 + out_stride3 + 0, acc30_qs8);
        vst1_qs8(mtx_out0 + out_stride3 + 8, acc31_qs8);
        vst1_qs8(mtx_out0 + out_stride3 + 16, acc32_qs8);
        vst1_qs8(mtx_out0 + out_stride3 + 24, acc33_qs8);
    },
    ina, inb, out);
}
1296
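// The QS16 variant below follows the same structure as the QS8 kernel above, but
// computes an 8x4 output block per window step and accumulates in 32 bits before
// saturating back down to 16 bits.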
template <bool multiply_alpha>
void matrix_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const size_t out_stride2 = out_stride1 * 2;
    const size_t out_stride3 = out_stride1 * 3;
    const int num_elems_matrix_b_x = input1->info()->dimension(0);
    const int fixed_point_position = input0->info()->fixed_point_position();
    const qint16x4_t alpha_qs16 = vdup_n_qs16(sqcvt_qs16_f32(alpha, fixed_point_position));
    ARM_COMPUTE_UNUSED(alpha_qs16);
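    // alpha_qs16 is only read when multiply_alpha is true; the macro above keeps
    // the multiply_alpha == false instantiation free of unused-variable warnings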

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 8 as the transposed input matrix B has 8 times fewer columns than the output matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
    // The reshaping of the matrices helps to have a cache friendly implementation and to avoid the data re-arrangements needed for computing 8x4 elements per iteration
    // All the values needed for computing a single 8x4 block will be read from consecutive memory positions
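    // As an illustrative sketch (not part of the computation), each iteration of
    // the inner loop below is the vectorised form of the following scalar code,
    // where a[0..3] is one interleaved column of A, b[0..7] one transposed row of B,
    // and the multiply-accumulate is the saturating fixed-point one:
    //
    //   for(int col = 0; col < 8; ++col)
    //   {
    //       for(int row = 0; row < 4; ++row)
    //       {
    //           acc[row][col] += a[row] * b[col];
    //       }
    //   }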
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const qint16_t *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const qint16_t *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;
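        // Note: mtx_b1 is kept in step with mtx_b0 but is never read in this 8x4
        // variant; only the first transposed row of B is consumed per iteration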

        qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc10_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc20_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc30_qs32 = vdupq_n_qs32(0);

        qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc11_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc21_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc31_qs32 = vdupq_n_qs32(0);

        // Each iteration of this loop performs one accumulation step (a single k index) over the whole 8x4 output block
        for(int k = 0; k <= (num_elems_matrix_b_x - 8); k += 8)
        {
            const qint16x4_t a0 = vld1_dup_qs16(mtx_a0 + 0);
            const qint16x4_t a1 = vld1_dup_qs16(mtx_a0 + 1);
            const qint16x4_t a2 = vld1_dup_qs16(mtx_a0 + 2);
            const qint16x4_t a3 = vld1_dup_qs16(mtx_a0 + 3);

            const qint16x4_t b00 = vld1_qs16(mtx_b0 + 0);
            const qint16x4_t b01 = vld1_qs16(mtx_b0 + 4);

            acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
            acc10_qs32 = vqmlal_qs16(acc10_qs32, b00, a1, fixed_point_position);
            acc20_qs32 = vqmlal_qs16(acc20_qs32, b00, a2, fixed_point_position);
            acc30_qs32 = vqmlal_qs16(acc30_qs32, b00, a3, fixed_point_position);
            acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
            acc11_qs32 = vqmlal_qs16(acc11_qs32, b01, a1, fixed_point_position);
            acc21_qs32 = vqmlal_qs16(acc21_qs32, b01, a2, fixed_point_position);
            acc31_qs32 = vqmlal_qs16(acc31_qs32, b01, a3, fixed_point_position);

            mtx_a0 += 4;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }

        // Convert back to qint16x4_t and saturate
        qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
        qint16x4_t acc10_qs16 = vqmovn_qs32(acc10_qs32);
        qint16x4_t acc20_qs16 = vqmovn_qs32(acc20_qs32);
        qint16x4_t acc30_qs16 = vqmovn_qs32(acc30_qs32);

        qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
        qint16x4_t acc11_qs16 = vqmovn_qs32(acc11_qs32);
        qint16x4_t acc21_qs16 = vqmovn_qs32(acc21_qs32);
        qint16x4_t acc31_qs16 = vqmovn_qs32(acc31_qs32);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
            acc10_qs16 = vqmul_qs16(acc10_qs16, alpha_qs16, fixed_point_position);
            acc20_qs16 = vqmul_qs16(acc20_qs16, alpha_qs16, fixed_point_position);
            acc30_qs16 = vqmul_qs16(acc30_qs16, alpha_qs16, fixed_point_position);
            acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
            acc11_qs16 = vqmul_qs16(acc11_qs16, alpha_qs16, fixed_point_position);
            acc21_qs16 = vqmul_qs16(acc21_qs16, alpha_qs16, fixed_point_position);
            acc31_qs16 = vqmul_qs16(acc31_qs16, alpha_qs16, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());

        // Store 8x4 output elements
        vst1_qs16(mtx_out0 + 0, acc00_qs16);
        vst1_qs16(mtx_out0 + 4, acc01_qs16);
        vst1_qs16(mtx_out0 + out_stride1 + 0, acc10_qs16);
        vst1_qs16(mtx_out0 + out_stride1 + 4, acc11_qs16);
        vst1_qs16(mtx_out0 + out_stride2 + 0, acc20_qs16);
        vst1_qs16(mtx_out0 + out_stride2 + 4, acc21_qs16);
        vst1_qs16(mtx_out0 + out_stride3 + 0, acc30_qs16);
        vst1_qs16(mtx_out0 + out_stride3 + 4, acc31_qs16);
    },
    ina, inb, out);
}
} // namespace

NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
{
}

void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);

    if(output->info()->dimension(1) == 1)
    {
        ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
    }

    _input0 = input0;
    _input1 = input1;
    _output = output;
    _alpha  = alpha;

    unsigned int num_elems_processed_per_iteration_x = 0;
    const unsigned int num_elems_processed_per_iteration_y = 4;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if(output->info()->dimension(1) == 1)
    {
        switch(input0->info()->data_type())
        {
            case DataType::F32:
            {
                num_elems_processed_per_iteration_x = 16;
                break;
            }
            case DataType::QS8:
            {
                num_elems_processed_per_iteration_x = 32;
                break;
            }
            case DataType::QS16:
            {
                num_elems_processed_per_iteration_x = 16;
                break;
            }
#ifdef ARM_COMPUTE_AARCH64_V8_2
            case DataType::F16:
            {
                num_elems_processed_per_iteration_x = 32;
                break;
            }
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }

        // Configure kernel window
        Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));

        AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration_x);

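        // The vector-matrix kernels consume the whole input vector A for every
        // output strip, hence the static access window over its full width below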
        update_window_and_padding(win,
                                  AccessWindowStatic(input0->info(), 0, 0, input0->info()->tensor_shape().x(), 1),
                                  AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration_x),
                                  output_access);

        Coordinates coord;
        coord.set_num_dimensions(output->info()->num_dimensions());
        output_access.set_valid_region(win, ValidRegion(coord, output->info()->tensor_shape()));

        INEKernel::configure(win);
    }
    else
    {
        switch(input0->info()->data_type())
        {
            case DataType::F32:
            {
                num_elems_processed_per_iteration_x = 8;
                break;
            }
            case DataType::QS8:
            {
                num_elems_processed_per_iteration_x = 32;
                break;
            }
            case DataType::QS16:
            {
                num_elems_processed_per_iteration_x = 8;
                break;
            }
#ifdef ARM_COMPUTE_AARCH64_V8_2
            case DataType::F16:
            {
                num_elems_processed_per_iteration_x = 8;
                break;
            }
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }

        // Configure kernel window
        Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));

        AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);

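        // Matrix A has been interleaved 4x4 (4 times fewer rows, 4 times more
        // columns), which is why its access window below is 4 elements wide and
        // scales the Y coordinate by 0.25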
        update_window_and_padding(win,
                                  AccessWindowRectangle(input0->info(), 0, 0, 4, 1, 1.f, 0.25f),
                                  AccessWindowStatic(input1->info(), 0, 0, input1->info()->tensor_shape().x(), ceil_to_multiple(input1->info()->tensor_shape().y(), 4)),
                                  output_access);

        output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));

        INEKernel::configure(win);
    }
}
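// A minimal usage sketch (illustrative only; the tensor names are placeholders,
// and allocation, reshaping with NEGEMMInterleave4x4Kernel / NEGEMMTranspose1xWKernel
// and error handling are omitted):
//
//   NEGEMMMatrixMultiplyKernel mm_kernel;
//   mm_kernel.configure(&interleaved_a, &transposed_b, &output, alpha);
//   NEScheduler::get().schedule(&mm_kernel, Window::DimY);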

void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

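    // Scale by alpha only when it differs from 1 by more than 1e-5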
    const bool multiply_alpha = std::abs(1.0f - _alpha) > 0.00001f;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if(_output->info()->dimension(1) == 1)
    {
        switch(_input0->info()->data_type())
        {
            case DataType::F32:
            {
                multiply_alpha ? vector_matrix_multiply_f32<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
            case DataType::QS8:
            {
                multiply_alpha ? vector_matrix_multiply_qs8<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_qs8<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
            case DataType::QS16:
            {
                multiply_alpha ? vector_matrix_multiply_qs16<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_qs16<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
#ifdef ARM_COMPUTE_AARCH64_V8_2
            case DataType::F16:
            {
                multiply_alpha ? vector_matrix_multiply_f16<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_f16<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }
    }
    else
    {
        switch(_input0->info()->data_type())
        {
            case DataType::F32:
            {
                multiply_alpha ? matrix_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
            case DataType::QS8:
            {
                multiply_alpha ? matrix_matrix_multiply_qs8<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
            case DataType::QS16:
            {
                multiply_alpha ? matrix_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
#ifdef ARM_COMPUTE_AARCH64_V8_2
            case DataType::F16:
            {
                multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }
    }
}