/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/AccessWindowTranspose.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>
#include <tuple>

using namespace arm_compute;

namespace arm_compute
{
class Coordinates;
} // namespace arm_compute

namespace
{
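// Computes one output row of alpha * (vector A x matrix B) in F16.
// The X dimension of the output is split across threads in 32-element chunks
// (window_start_x = 32 * thread_id, window_step_x = 32 * num_threads).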
template <bool multiply_alpha>
void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
#ifdef ARM_COMPUTE_ENABLE_FP16
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, "(window_end_x - window_start_x) must be a multiple of window_step_x");

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
    ARM_COMPUTE_UNUSED(alpha_f16);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        float16x8_t acc0 = vdupq_n_f16(0.f);
        float16x8_t acc1 = vdupq_n_f16(0.f);
        float16x8_t acc2 = vdupq_n_f16(0.f);
        float16x8_t acc3 = vdupq_n_f16(0.f);

        auto vec_a    = reinterpret_cast<const float16_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr());

        const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
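        // Main loop: consume 4 elements of vector A per iteration, each lane broadcast against a 32-wide row chunk of matrix B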
        for(; vec_a <= (vec_a_end_addr - 4);)
        {
            const float16x4_t a0l = vld1_f16(vec_a);

            float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

            matrix_b += 2 * in_b_stride;

            b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

            vec_a += 4;
            matrix_b += 2 * in_b_stride;
        }
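        // Leftover loop: process the remaining elements of vector A one at a time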
        for(; vec_a < vec_a_end_addr;)
        {
            const float16_t   a0  = *vec_a;
            const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
            acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
            acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
            acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            acc0 = vmulq_f16(acc0, alpha_f16);
            acc1 = vmulq_f16(acc1, alpha_f16);
            acc2 = vmulq_f16(acc2, alpha_f16);
            acc3 = vmulq_f16(acc3, alpha_f16);
        }

        const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());

        vst1q_f16(vec_out + 0, acc0);
        vst1q_f16(vec_out + 8, acc1);
        vst1q_f16(vec_out + 16, acc2);
        vst1q_f16(vec_out + 24, acc3);
    },
    ina, inb, out);
#else  /* ARM_COMPUTE_ENABLE_FP16 */
    ARM_COMPUTE_ERROR("Not implemented");
#endif /* ARM_COMPUTE_ENABLE_FP16 */
}

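// Computes one output row of alpha * (vector A x matrix B) in F32.
// The X dimension of the output is split across threads in 16-element chunks.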
template <bool multiply_alpha>
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        float32x4_t acc0 = vdupq_n_f32(0.f);
        float32x4_t acc1 = vdupq_n_f32(0.f);
        float32x4_t acc2 = vdupq_n_f32(0.f);
        float32x4_t acc3 = vdupq_n_f32(0.f);

        auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float *>(inb.ptr());

#if __arm__
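        // PLD is the AArch32 cache-preload hint: warm up the lines holding the upcoming bytes of vector A and matrix B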
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

        auto vec_a_end_addr = vec_a + num_elems_vec_a;
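        // Main loop: unrolled x2, each pass consumes 4 elements of vector A against 16-wide row chunks of matrix B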
        for(; vec_a <= (vec_a_end_addr - 4);)
        {
            float32x2_t a0l = vld1_f32(vec_a);

            float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;

            a0l = vld1_f32(vec_a);

            b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

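        // Leftover loop: process the remaining elements of vector A one at a time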
        for(; vec_a < vec_a_end_addr;)
        {
            const float a0 = *vec_a;

            const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            acc0 = vmlaq_n_f32(acc0, b00, a0);
            acc1 = vmlaq_n_f32(acc1, b01, a0);
            acc2 = vmlaq_n_f32(acc2, b02, a0);
            acc3 = vmlaq_n_f32(acc3, b03, a0);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
            acc0 = vmulq_f32(acc0, alpha_f32);
            acc1 = vmulq_f32(acc1, alpha_f32);
            acc2 = vmulq_f32(acc2, alpha_f32);
            acc3 = vmulq_f32(acc3, alpha_f32);
        }

        const auto vec_out = reinterpret_cast<float *>(out.ptr());

        vst1q_f32(vec_out + 0, acc0);
        vst1q_f32(vec_out + 4, acc1);
        vst1q_f32(vec_out + 8, acc2);
        vst1q_f32(vec_out + 12, acc3);
    },
    ina, inb, out);
}

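// QS8 fixed-point variant of the vector-by-matrix product. vqmlal_qs8 (declared in NEFixedPoint.h)
// is a saturating multiply-accumulate that widens the Q8 products into the QS16 accumulators,
// rescaling by fixed_point_position.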
template <bool multiply_alpha>
void vector_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b       = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride          = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a      = static_cast<int>(input0->info()->dimension(0));
    const int  fixed_point_position = input0->info()->fixed_point_position();

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        // Reset accumulators
        qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc03_qs16 = vdupq_n_qs16(0);

        auto vec_a    = reinterpret_cast<const qint8_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const qint8_t *>(inb.ptr());

        auto vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 2);)
        {
            const qint8x8_t a0 = vld1_dup_qs8(vec_a + 0);
            const qint8x8_t a1 = vld1_dup_qs8(vec_a + 1);

            const qint8x8_t b00 = vld1_qs8(matrix_b + 0 + 0 * in_b_stride);
            const qint8x8_t b01 = vld1_qs8(matrix_b + 8 + 0 * in_b_stride);
            const qint8x8_t b02 = vld1_qs8(matrix_b + 16 + 0 * in_b_stride);
            const qint8x8_t b03 = vld1_qs8(matrix_b + 24 + 0 * in_b_stride);
            const qint8x8_t b10 = vld1_qs8(matrix_b + 0 + 1 * in_b_stride);
            const qint8x8_t b11 = vld1_qs8(matrix_b + 8 + 1 * in_b_stride);
            const qint8x8_t b12 = vld1_qs8(matrix_b + 16 + 1 * in_b_stride);
            const qint8x8_t b13 = vld1_qs8(matrix_b + 24 + 1 * in_b_stride);

            // First accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);

            // Second accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b10, a1, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b11, a1, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a1, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a1, fixed_point_position);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

        for(; vec_a < vec_a_end_addr;)
        {
            const qint8x8_t a0 = vld1_dup_qs8(vec_a);

            const qint8x8_t b00 = vld1_qs8(matrix_b + 0);
            const qint8x8_t b01 = vld1_qs8(matrix_b + 8);
            const qint8x8_t b02 = vld1_qs8(matrix_b + 16);
            const qint8x8_t b03 = vld1_qs8(matrix_b + 24);

            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Convert back to qint8x8_t and saturate
        qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
        qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
        qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
        qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            const qint8x8_t alpha_qs8 = vdup_n_qs8(sqcvt_qs8_f32(alpha, fixed_point_position));
            acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
            acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
            acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
            acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());

        // Store 32x1 output elements
        vst1_qs8(mtx_out0 + 0, acc00_qs8);
        vst1_qs8(mtx_out0 + 8, acc01_qs8);
        vst1_qs8(mtx_out0 + 16, acc02_qs8);
        vst1_qs8(mtx_out0 + 24, acc03_qs8);
    },
    ina, inb, out);
}

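// QS16 fixed-point variant of the vector-by-matrix product: accumulates in QS32 and
// saturates back to QS16 on store. Threads cooperate in 16-element chunks.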
template <bool multiply_alpha>
void vector_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b       = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride          = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a      = static_cast<int>(input0->info()->dimension(0));
    const int  fixed_point_position = input0->info()->fixed_point_position();

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, "(window_end_x - window_start_x) must be a multiple of window_step_x");

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        // Reset accumulators
        qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc02_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc03_qs32 = vdupq_n_qs32(0);

        auto vec_a    = reinterpret_cast<const qint16_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const qint16_t *>(inb.ptr());

        auto vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 2);)
        {
            const qint16x4_t a0 = vld1_dup_qs16(vec_a + 0);
            const qint16x4_t a1 = vld1_dup_qs16(vec_a + 1);

            const qint16x4_t b00 = vld1_qs16(matrix_b + 0 + 0 * in_b_stride);
            const qint16x4_t b01 = vld1_qs16(matrix_b + 4 + 0 * in_b_stride);
            const qint16x4_t b02 = vld1_qs16(matrix_b + 8 + 0 * in_b_stride);
            const qint16x4_t b03 = vld1_qs16(matrix_b + 12 + 0 * in_b_stride);
            const qint16x4_t b10 = vld1_qs16(matrix_b + 0 + 1 * in_b_stride);
            const qint16x4_t b11 = vld1_qs16(matrix_b + 4 + 1 * in_b_stride);
            const qint16x4_t b12 = vld1_qs16(matrix_b + 8 + 1 * in_b_stride);
            const qint16x4_t b13 = vld1_qs16(matrix_b + 12 + 1 * in_b_stride);

            // First accumulation
            acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
            acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
            acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
            acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);

            // Second accumulation
            acc00_qs32 = vqmlal_qs16(acc00_qs32, b10, a1, fixed_point_position);
            acc01_qs32 = vqmlal_qs16(acc01_qs32, b11, a1, fixed_point_position);
            acc02_qs32 = vqmlal_qs16(acc02_qs32, b12, a1, fixed_point_position);
            acc03_qs32 = vqmlal_qs16(acc03_qs32, b13, a1, fixed_point_position);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

        for(; vec_a < vec_a_end_addr;)
        {
            const qint16x4_t a0 = vld1_dup_qs16(vec_a);

            const qint16x4_t b00 = vld1_qs16(matrix_b + 0);
            const qint16x4_t b01 = vld1_qs16(matrix_b + 4);
            const qint16x4_t b02 = vld1_qs16(matrix_b + 8);
            const qint16x4_t b03 = vld1_qs16(matrix_b + 12);

            acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
            acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
            acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
            acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Convert back to qint16x4_t and saturate
        qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
        qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
        qint16x4_t acc02_qs16 = vqmovn_qs32(acc02_qs32);
        qint16x4_t acc03_qs16 = vqmovn_qs32(acc03_qs32);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            const qint16x4_t alpha_qs16 = vdup_n_qs16(sqcvt_qs16_f32(alpha, fixed_point_position));
            acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
            acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
            acc02_qs16 = vqmul_qs16(acc02_qs16, alpha_qs16, fixed_point_position);
            acc03_qs16 = vqmul_qs16(acc03_qs16, alpha_qs16, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());

        // Store 16x1 output elements
        vst1_qs16(mtx_out0 + 0, acc00_qs16);
        vst1_qs16(mtx_out0 + 4, acc01_qs16);
        vst1_qs16(mtx_out0 + 8, acc02_qs16);
        vst1_qs16(mtx_out0 + 12, acc03_qs16);
    },
    ina, inb, out);
}

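// Computes output = alpha * (A * B) in F32 on reshaped inputs: A interleaved 4x4 by
// NEGEMMInterleave4x4 and B transposed 1x4 by NEGEMMTranspose1xW. Each window iteration
// produces two adjacent 4x4 blocks of the output.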
template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride1          = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const size_t out_stride2          = out_stride1 * 2;
    const size_t out_stride3          = out_stride1 * 3;
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 4 as the transposed input matrix B has 4 times fewer columns than the output matrix
    // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        float32x4_t acc00 = vdupq_n_f32(0.f);
        float32x4_t acc10 = vdupq_n_f32(0.f);
        float32x4_t acc20 = vdupq_n_f32(0.f);
        float32x4_t acc30 = vdupq_n_f32(0.f);

        float32x4_t acc01 = vdupq_n_f32(0.f);
        float32x4_t acc11 = vdupq_n_f32(0.f);
        float32x4_t acc21 = vdupq_n_f32(0.f);
        float32x4_t acc31 = vdupq_n_f32(0.f);

#if __arm__
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

        auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
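        // Main loop: unrolled x4; each pass consumes 32 elements of the reshaped A and B blocks and updates both 4x4 output blocks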
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);

            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);
            float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
            float32x4_t b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
            float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
            float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
            float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);

            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }

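        // Leftover loop: one 4-element step of the reshaped blocks per iteration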
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            float32x4_t a0  = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1  = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2  = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3  = vld1q_dup_f32(mtx_a0 + 3);
            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);

#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            mtx_a0 += 4;
            mtx_b0 += 4;
            mtx_b1 += 4;
        }

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
            acc00 = vmulq_f32(acc00, alpha_f32);
            acc10 = vmulq_f32(acc10, alpha_f32);
            acc20 = vmulq_f32(acc20, alpha_f32);
            acc30 = vmulq_f32(acc30, alpha_f32);
            acc01 = vmulq_f32(acc01, alpha_f32);
            acc11 = vmulq_f32(acc11, alpha_f32);
            acc21 = vmulq_f32(acc21, alpha_f32);
            acc31 = vmulq_f32(acc31, alpha_f32);
        }

        const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
        const auto mtx_out1 = mtx_out0 + 4;

        // Store the 4 blocks
        vst1q_f32(mtx_out0, acc00);
        vst1q_f32(mtx_out1, acc01);
        vst1q_f32(mtx_out0 + out_stride1, acc10);
        vst1q_f32(mtx_out1 + out_stride1, acc11);
        vst1q_f32(mtx_out0 + out_stride2, acc20);
        vst1q_f32(mtx_out1 + out_stride2, acc21);
        vst1q_f32(mtx_out0 + out_stride3, acc30);
        vst1q_f32(mtx_out1 + out_stride3, acc31);
    },
    ina, inb, out);
}

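// F16 variant on reshaped inputs (A interleaved 4x4, B transposed 1x8); each window
// iteration produces a 4-row by 8-column block of the output.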
template <bool multiply_alpha>
void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
#ifdef ARM_COMPUTE_ENABLE_FP16
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride           = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 8 as the transposed input matrix B has 8 times fewer columns than the output matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        /*
        This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
             |a00 a01 a02 a03 | a04 a05 a06 a07|
             |a10 a11 a12 a13 | a14 a15 a16 a17|
             |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
             |a30 a31 a32 a33 | a34 a35 a36 a37|   | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
             |a40 a41 a42 a43 | a44 a45 a46 a47|
             |a50 a51 a52 a53 | a54 a55 a56 a57|
             |a60 a61 a62 a63 | a64 a65 a66 a67|
             |a70 a71 a72 a73 | a74 a75 a76 a77|

        After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]

        B Matrix has been transposed as shown below

           |b00 b01 b02 b03 b04 b05 b06 b07|
           |b10 b11 b12 b13 b14 b15 b16 b17|
           |b20 b21 b22 b23 b24 b25 b26 b27|
           |b30 b31 b32 b33 b34 b35 b36 b37|
          ------------------->

           |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|

            c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
            c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31

        The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
        */
        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }

        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }

        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }

        vst1q_f16(mtx_out + 0 * out_stride, c.val[0]);
        vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
        vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
    },
    ina, inb, out);
#else  /* ARM_COMPUTE_ENABLE_FP16 */
    ARM_COMPUTE_ERROR("Not implemented");
#endif /* ARM_COMPUTE_ENABLE_FP16 */
}

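// QS8 fixed-point variant on reshaped inputs (A interleaved 4x4, B transposed 1x16 for
// 8-bit data); each window iteration produces a 32-column by 4-row block of the output.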
template <bool multiply_alpha>
void matrix_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const size_t    in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t    out_stride1          = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const size_t    out_stride2          = out_stride1 * 2;
    const size_t    out_stride3          = out_stride1 * 3;
    const int       num_elems_matrix_b_x = input1->info()->dimension(0);
    const int       fixed_point_position = input0->info()->fixed_point_position();
    const qint8x8_t alpha_qs8            = vdup_n_qs8(sqcvt_qs8_f32(alpha, fixed_point_position));
    ARM_COMPUTE_UNUSED(alpha_qs8);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 16 as the transposed input matrix B has 16 times fewer columns than the output matrix
    // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 16x4
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 32x4 elements per iteration
    // All the values needed for computing a single 32x4 block will be read from consecutive memory positions
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const qint8_t *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const qint8_t *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc10_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc20_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc30_qs16 = vdupq_n_qs16(0);

        qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc11_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc21_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc31_qs16 = vdupq_n_qs16(0);

        qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc12_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc22_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc32_qs16 = vdupq_n_qs16(0);

        qint16x8_t acc03_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc13_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc23_qs16 = vdupq_n_qs16(0);
        qint16x8_t acc33_qs16 = vdupq_n_qs16(0);

        int k = 0;
        // This for loop performs 2 accumulations
        for(; k <= (num_elems_matrix_b_x - 32); k += 32)
        {
            const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
            const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
            const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
            const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);
            const qint8x8_t a4 = vld1_dup_qs8(mtx_a0 + 4);
            const qint8x8_t a5 = vld1_dup_qs8(mtx_a0 + 5);
            const qint8x8_t a6 = vld1_dup_qs8(mtx_a0 + 6);
            const qint8x8_t a7 = vld1_dup_qs8(mtx_a0 + 7);

            const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
            const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
            const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
            const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);

            // First accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
            acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
            acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
            acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
            acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
            acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);

            const qint8x8_t b02 = vld1_qs8(mtx_b0 + 16);
            const qint8x8_t b03 = vld1_qs8(mtx_b0 + 24);
            const qint8x8_t b12 = vld1_qs8(mtx_b1 + 16);
            const qint8x8_t b13 = vld1_qs8(mtx_b1 + 24);

            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
            acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
            acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
            acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
            acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
            acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);

#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // Second accumulation
            acc00_qs16 = vqmlal_qs8(acc00_qs16, b02, a4, fixed_point_position);
            acc10_qs16 = vqmlal_qs8(acc10_qs16, b02, a5, fixed_point_position);
            acc20_qs16 = vqmlal_qs8(acc20_qs16, b02, a6, fixed_point_position);
            acc30_qs16 = vqmlal_qs8(acc30_qs16, b02, a7, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b03, a4, fixed_point_position);
            acc11_qs16 = vqmlal_qs8(acc11_qs16, b03, a5, fixed_point_position);
            acc21_qs16 = vqmlal_qs8(acc21_qs16, b03, a6, fixed_point_position);
            acc31_qs16 = vqmlal_qs8(acc31_qs16, b03, a7, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a4, fixed_point_position);
            acc12_qs16 = vqmlal_qs8(acc12_qs16, b12, a5, fixed_point_position);
            acc22_qs16 = vqmlal_qs8(acc22_qs16, b12, a6, fixed_point_position);
            acc32_qs16 = vqmlal_qs8(acc32_qs16, b12, a7, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a4, fixed_point_position);
            acc13_qs16 = vqmlal_qs8(acc13_qs16, b13, a5, fixed_point_position);
            acc23_qs16 = vqmlal_qs8(acc23_qs16, b13, a6, fixed_point_position);
            acc33_qs16 = vqmlal_qs8(acc33_qs16, b13, a7, fixed_point_position);

            mtx_a0 += 8;
            mtx_b0 += 32;
            mtx_b1 += 32;
        }

        // This for loop performs the left over accumulations
        for(; k < num_elems_matrix_b_x; k += 16)
        {
            const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
            const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
            const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
            const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);

            const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
            const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
            const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
            const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);

            acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
            acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
            acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
            acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
            acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
            acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
            acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
            acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
            acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
            acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
            acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
            acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);
            acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
            acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
            acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
            acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);

            mtx_a0 += 4;
            mtx_b0 += 16;
            mtx_b1 += 16;
        }

        // Convert back to qint8x8_t and saturate
        qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
        qint8x8_t acc10_qs8 = vqmovn_qs16(acc10_qs16);
        qint8x8_t acc20_qs8 = vqmovn_qs16(acc20_qs16);
        qint8x8_t acc30_qs8 = vqmovn_qs16(acc30_qs16);

        qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
        qint8x8_t acc11_qs8 = vqmovn_qs16(acc11_qs16);
        qint8x8_t acc21_qs8 = vqmovn_qs16(acc21_qs16);
        qint8x8_t acc31_qs8 = vqmovn_qs16(acc31_qs16);

        qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
        qint8x8_t acc12_qs8 = vqmovn_qs16(acc12_qs16);
        qint8x8_t acc22_qs8 = vqmovn_qs16(acc22_qs16);
        qint8x8_t acc32_qs8 = vqmovn_qs16(acc32_qs16);

        qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);
        qint8x8_t acc13_qs8 = vqmovn_qs16(acc13_qs16);
        qint8x8_t acc23_qs8 = vqmovn_qs16(acc23_qs16);
        qint8x8_t acc33_qs8 = vqmovn_qs16(acc33_qs16);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
            acc10_qs8 = vqmul_qs8(acc10_qs8, alpha_qs8, fixed_point_position);
            acc20_qs8 = vqmul_qs8(acc20_qs8, alpha_qs8, fixed_point_position);
            acc30_qs8 = vqmul_qs8(acc30_qs8, alpha_qs8, fixed_point_position);
            acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
            acc11_qs8 = vqmul_qs8(acc11_qs8, alpha_qs8, fixed_point_position);
            acc21_qs8 = vqmul_qs8(acc21_qs8, alpha_qs8, fixed_point_position);
            acc31_qs8 = vqmul_qs8(acc31_qs8, alpha_qs8, fixed_point_position);
            acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
            acc12_qs8 = vqmul_qs8(acc12_qs8, alpha_qs8, fixed_point_position);
            acc22_qs8 = vqmul_qs8(acc22_qs8, alpha_qs8, fixed_point_position);
            acc32_qs8 = vqmul_qs8(acc32_qs8, alpha_qs8, fixed_point_position);
            acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
            acc13_qs8 = vqmul_qs8(acc13_qs8, alpha_qs8, fixed_point_position);
            acc23_qs8 = vqmul_qs8(acc23_qs8, alpha_qs8, fixed_point_position);
            acc33_qs8 = vqmul_qs8(acc33_qs8, alpha_qs8, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());

        // Store 32x4 output elements
        vst1_qs8(mtx_out0 + 0, acc00_qs8);
        vst1_qs8(mtx_out0 + 8, acc01_qs8);
        vst1_qs8(mtx_out0 + 16, acc02_qs8);
        vst1_qs8(mtx_out0 + 24, acc03_qs8);
        vst1_qs8(mtx_out0 + out_stride1 + 0, acc10_qs8);
        vst1_qs8(mtx_out0 + out_stride1 + 8, acc11_qs8);
        vst1_qs8(mtx_out0 + out_stride1 + 16, acc12_qs8);
        vst1_qs8(mtx_out0 + out_stride1 + 24, acc13_qs8);
        vst1_qs8(mtx_out0 + out_stride2 + 0, acc20_qs8);
        vst1_qs8(mtx_out0 + out_stride2 + 8, acc21_qs8);
        vst1_qs8(mtx_out0 + out_stride2 + 16, acc22_qs8);
        vst1_qs8(mtx_out0 + out_stride2 + 24, acc23_qs8);
        vst1_qs8(mtx_out0 + out_stride3 + 0, acc30_qs8);
        vst1_qs8(mtx_out0 + out_stride3 + 8, acc31_qs8);
        vst1_qs8(mtx_out0 + out_stride3 + 16, acc32_qs8);
        vst1_qs8(mtx_out0 + out_stride3 + 24, acc33_qs8);
    },
    ina, inb, out);
}

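// QS16 fixed-point variant on reshaped inputs (A interleaved 4x4, B transposed 1x8 for
// 16-bit data); each window iteration produces an 8-column by 4-row block of the output.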
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001286template <bool multiply_alpha>
1287void matrix_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
1288{
1289 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
1290 const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
1291 const size_t out_stride2 = out_stride1 * 2;
1292 const size_t out_stride3 = out_stride1 * 3;
1293 const int num_elems_matrix_b_x = input1->info()->dimension(0);
1294 const int fixed_point_position = input0->info()->fixed_point_position();
Georgios Pinitas21efeb42017-07-04 12:47:17 +01001295 const qint16x4_t alpha_qs16 = vdup_n_qs16(sqcvt_qs16_f32(alpha, fixed_point_position));
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001296 ARM_COMPUTE_UNUSED(alpha_qs16);
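    // alpha_qs16 is only read when multiply_alpha is true, so the instantiation with
    // multiply_alpha == false would otherwise trigger an unused-variable warning;
    // ARM_COMPUTE_UNUSED silences it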

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions while matrix A has more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 8, as the transposed input matrix B has 8 times fewer columns than the output matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
    // The reshaping makes the implementation cache friendly and avoids the data re-arrangements that computing 8x4 elements per iteration would otherwise need
    // All the values needed for computing a single 8x4 block are read from consecutive memory positions
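    // Concretely, per iteration the kernel reads one interleaved column of A, i.e. the four
    // values a(y+0, k), a(y+1, k), a(y+2, k), a(y+3, k) stored contiguously, and one 1x8
    // transposed block of B holding b(k, x+0) ... b(k, x+7), also stored contiguously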
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const qint16_t *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const qint16_t *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc10_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc20_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc30_qs32 = vdupq_n_qs32(0);

        qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc11_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc21_qs32 = vdupq_n_qs32(0);
        qint32x4_t acc31_qs32 = vdupq_n_qs32(0);

        // Each iteration of this loop performs one accumulation step: one column of A against one 1x8 block of B
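        // Each A value is broadcast to a whole vector (vld1_dup_qs16), so a single
        // vqmlal_qs16 updates four output columns of one output row at a time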
        for(int k = 0; k <= (num_elems_matrix_b_x - 8); k += 8)
        {
            const qint16x4_t a0 = vld1_dup_qs16(mtx_a0 + 0);
            const qint16x4_t a1 = vld1_dup_qs16(mtx_a0 + 1);
            const qint16x4_t a2 = vld1_dup_qs16(mtx_a0 + 2);
            const qint16x4_t a3 = vld1_dup_qs16(mtx_a0 + 3);

            const qint16x4_t b00 = vld1_qs16(mtx_b0 + 0);
            const qint16x4_t b01 = vld1_qs16(mtx_b0 + 4);

            acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
            acc10_qs32 = vqmlal_qs16(acc10_qs32, b00, a1, fixed_point_position);
            acc20_qs32 = vqmlal_qs16(acc20_qs32, b00, a2, fixed_point_position);
            acc30_qs32 = vqmlal_qs16(acc30_qs32, b00, a3, fixed_point_position);
            acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
            acc11_qs32 = vqmlal_qs16(acc11_qs32, b01, a1, fixed_point_position);
            acc21_qs32 = vqmlal_qs16(acc21_qs32, b01, a2, fixed_point_position);
            acc31_qs32 = vqmlal_qs16(acc31_qs32, b01, a3, fixed_point_position);

            mtx_a0 += 4;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }

        // Convert back to qint16x4_t and saturate
        qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
        qint16x4_t acc10_qs16 = vqmovn_qs32(acc10_qs32);
        qint16x4_t acc20_qs16 = vqmovn_qs32(acc20_qs32);
        qint16x4_t acc30_qs16 = vqmovn_qs32(acc30_qs32);

        qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
        qint16x4_t acc11_qs16 = vqmovn_qs32(acc11_qs32);
        qint16x4_t acc21_qs16 = vqmovn_qs32(acc21_qs32);
        qint16x4_t acc31_qs16 = vqmovn_qs32(acc31_qs32);

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
            acc10_qs16 = vqmul_qs16(acc10_qs16, alpha_qs16, fixed_point_position);
            acc20_qs16 = vqmul_qs16(acc20_qs16, alpha_qs16, fixed_point_position);
            acc30_qs16 = vqmul_qs16(acc30_qs16, alpha_qs16, fixed_point_position);
            acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
            acc11_qs16 = vqmul_qs16(acc11_qs16, alpha_qs16, fixed_point_position);
            acc21_qs16 = vqmul_qs16(acc21_qs16, alpha_qs16, fixed_point_position);
            acc31_qs16 = vqmul_qs16(acc31_qs16, alpha_qs16, fixed_point_position);
        }

        const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());

        // Store 8x4 output elements
        vst1_qs16(mtx_out0 + 0, acc00_qs16);
        vst1_qs16(mtx_out0 + 4, acc01_qs16);
        vst1_qs16(mtx_out0 + out_stride1 + 0, acc10_qs16);
        vst1_qs16(mtx_out0 + out_stride1 + 4, acc11_qs16);
        vst1_qs16(mtx_out0 + out_stride2 + 0, acc20_qs16);
        vst1_qs16(mtx_out0 + out_stride2 + 4, acc21_qs16);
        vst1_qs16(mtx_out0 + out_stride3 + 0, acc30_qs16);
        vst1_qs16(mtx_out0 + out_stride3 + 4, acc31_qs16);
    },
    ina, inb, out);
}
} // namespace

NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
{
}

void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);

    if(output->info()->dimension(1) == 1)
    {
        ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
    }
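    // In the vector-matrix case the inputs are not reshaped, so the length of vector A
    // must directly match the number of rows of matrix B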

    _input0 = input0;
    _input1 = input1;
    _output = output;
    _alpha  = alpha;

    unsigned int       num_elems_processed_per_iteration_x = 0;
    const unsigned int num_elems_processed_per_iteration_y = 4;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if(output->info()->dimension(1) == 1)
    {
        switch(input0->info()->data_type())
        {
            case DataType::F32:
            {
                num_elems_processed_per_iteration_x = 16;
                break;
            }
            case DataType::QS8:
            {
                num_elems_processed_per_iteration_x = 32;
                break;
            }
            case DataType::QS16:
            {
                num_elems_processed_per_iteration_x = 16;
                break;
            }
#ifdef ARM_COMPUTE_ENABLE_FP16
            case DataType::F16:
            {
                num_elems_processed_per_iteration_x = 32;
                break;
            }
#endif /* ARM_COMPUTE_ENABLE_FP16 */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }

        // Configure kernel window
        Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));

        AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration_x);

        update_window_and_padding(win,
                                  AccessWindowStatic(input0->info(), 0, 0, input0->info()->tensor_shape().x(), 1),
                                  AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration_x),
                                  output_access);
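        // The static access window spans the whole of vector A: every iteration re-reads
        // the full vector while stepping across blocks of matrix B columns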

        Coordinates coord;
        coord.set_num_dimensions(output->info()->num_dimensions());
        output_access.set_valid_region(win, ValidRegion(coord, output->info()->tensor_shape()));

        INEKernel::configure(win);
    }
    else
    {
        switch(input0->info()->data_type())
        {
            case DataType::F32:
            {
                num_elems_processed_per_iteration_x = 8;
                break;
            }
            case DataType::QS8:
            {
                num_elems_processed_per_iteration_x = 32;
                break;
            }
            case DataType::QS16:
            {
                num_elems_processed_per_iteration_x = 8;
                break;
            }
#ifdef ARM_COMPUTE_ENABLE_FP16
            case DataType::F16:
            {
                num_elems_processed_per_iteration_x = 8;
                break;
            }
#endif /* ARM_COMPUTE_ENABLE_FP16 */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }

        // Configure kernel window
        Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));

        AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);

        update_window_and_padding(win,
                                  AccessWindowRectangle(input0->info(), 0, 0, 4, 1, 1.f, 0.25f),
                                  AccessWindowStatic(input1->info(), 0, 0, input1->info()->tensor_shape().x(), ceil_to_multiple(input1->info()->tensor_shape().y(), 4)),
                                  output_access);
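        // The 0.25f Y scale on the matrix A access window reflects the 4x4 interleave:
        // the reshaped A holds four output rows per row, i.e. a quarter of the output's rows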

        output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));

        INEKernel::configure(win);
    }
}

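// Typical driving code (a minimal sketch; in practice this kernel is scheduled by the
// NEGEMM function after NEGEMMInterleave4x4 / NEGEMMTranspose1xW have produced the
// reshaped inputs; tensor names here are illustrative):
//
//   NEGEMMMatrixMultiplyKernel mm_kernel;
//   mm_kernel.configure(&a_interleaved, &b_transposed, &dst, alpha);
//   NEScheduler::get().schedule(&mm_kernel, Window::DimY);
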
void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    const bool multiply_alpha = std::abs(1.0f - _alpha) > 0.00001f;
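    // Alpha scaling is skipped when alpha is numerically 1 (within a small tolerance),
    // selecting the multiply_alpha == false template instantiation below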

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if(_output->info()->dimension(1) == 1)
    {
        switch(_input0->info()->data_type())
        {
            case DataType::F32:
            {
                multiply_alpha ? vector_matrix_multiply_f32<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
            case DataType::QS8:
            {
                multiply_alpha ? vector_matrix_multiply_qs8<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_qs8<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
            case DataType::QS16:
            {
                multiply_alpha ? vector_matrix_multiply_qs16<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_qs16<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
#ifdef ARM_COMPUTE_ENABLE_FP16
            case DataType::F16:
            {
                multiply_alpha ? vector_matrix_multiply_f16<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_f16<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
#endif /* ARM_COMPUTE_ENABLE_FP16 */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }
    }
    else
    {
        switch(_input0->info()->data_type())
        {
            case DataType::F32:
            {
                multiply_alpha ? matrix_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
            case DataType::QS8:
            {
                multiply_alpha ? matrix_matrix_multiply_qs8<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
            case DataType::QS16:
            {
                multiply_alpha ? matrix_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
#ifdef ARM_COMPUTE_ENABLE_FP16
            case DataType::F16:
            {
                multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
#endif /* ARM_COMPUTE_ENABLE_FP16 */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }
    }
}