1/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
25
26#include "arm_compute/core/AccessWindowTranspose.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/IAccessWindow.h"
30#include "arm_compute/core/ITensor.h"
31#include "arm_compute/core/NEON/NEFixedPoint.h"
32#include "arm_compute/core/TensorInfo.h"
33#include "arm_compute/core/Types.h"
34#include "arm_compute/core/Utils.h"
35#include "arm_compute/core/Validate.h"
36#include "arm_compute/core/Window.h"
37
38#include <arm_neon.h>
39#include <cstddef>
40#include <cstdint>
41#include <tuple>
42
43using namespace arm_compute;
44
45namespace arm_compute
46{
47class Coordinates;
48} // namespace arm_compute
49
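// Overview: the anonymous namespace below implements the data-type-specific GEMM kernels used by
// NEGEMMMatrixMultiplyKernel. The vector_matrix_multiply_* functions cover the case where matrix A
// degenerates to a single row (a vector), while the matrix_matrix_multiply_* functions expect A and B
// to have been reshaped beforehand with NEGEMMInterleave4x4 and NEGEMMTranspose1xW. Variants are
// provided for F16 (guarded by ARM_COMPUTE_ENABLE_FP16), F32, QS8 and QS16 data, each optionally
// scaling the result by alpha through the multiply_alpha template parameter.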
50namespace
51{
52template <bool multiply_alpha>
53void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
54{
55#ifdef ARM_COMPUTE_ENABLE_FP16
56 const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
57 const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
58 const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
59
60 // The implementation computes 32 elements per iteration
61 const int window_start_x = 32 * window.thread_id();
62 const int window_step_x = 32 * window.num_threads();
63 const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
64 ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
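    // Worked example of the split above (illustrative numbers, not taken from the library): with
    // width_matrix_b = 100 and 2 threads, thread 0 starts at x = 0 and thread 1 at x = 32, both
    // stepping by 64, so thread 0 processes the 32-wide column blocks at x = 0 and 64 and thread 1
    // those at x = 32 and 96. ceil_to_multiple() rounds each thread's range up to a whole number of
    // steps; iterations that would start past the end of matrix B bail out through the id.x() check
    // at the top of the loop body further down.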
65
66 Window win_out(window);
67 win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
68 win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
69
70 Window win_a(window);
71 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
72 win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
73
74 Window win_b;
75 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
76 // This scenario can happen when the matrix multiplication is used to perform a convolution operation
77 if(input1->info()->num_dimensions() >= 3)
78 {
79 win_b = window;
80 }
81 win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
82 win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
83
84 Iterator ina(input0, win_a);
85 Iterator inb(input1, win_b);
86 Iterator out(output, win_out);
87
88 const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
89 ARM_COMPUTE_UNUSED(alpha_f16);
90
91 execute_window_loop(win_out, [&](const Coordinates & id)
92 {
93 if(id.x() > width_matrix_b)
94 {
95 return;
96 }
97
98 float16x8_t acc0 = vdupq_n_f16(0.f);
99 float16x8_t acc1 = vdupq_n_f16(0.f);
100 float16x8_t acc2 = vdupq_n_f16(0.f);
101 float16x8_t acc3 = vdupq_n_f16(0.f);
102
103 auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
104 auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr());
105
106 const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
107 for(; vec_a <= (vec_a_end_addr - 4);)
108 {
109 const float16x4_t a0l = vld1_f16(vec_a);
110
111 float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
112 float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
113 float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
114 float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
115 float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
116 float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
117 float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
118 float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
119
120 acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
121 acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
122 acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
123 acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
124 acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
125 acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
126 acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
127 acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
128
129 matrix_b += 2 * in_b_stride;
130
131 b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
132 b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
133 b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
134 b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
135 b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
136 b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
137 b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
138 b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
139
140 acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
141 acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
142 acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
143 acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
144 acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
145 acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
146 acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
147 acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
148
149 vec_a += 4;
150 matrix_b += 2 * in_b_stride;
151 }
152
153 for(; vec_a < vec_a_end_addr;)
154 {
155 const float16_t a0 = *vec_a;
156 const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
157 const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
158 const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
159 const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
160
161 acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
162 acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
163 acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
164 acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
165
166 vec_a += 1;
167 matrix_b += in_b_stride;
168 }
169
170 // Multiply by the weight of matrix product (alpha)
171 if(multiply_alpha)
172 {
173 acc0 = vmulq_f16(acc0, alpha_f16);
174 acc1 = vmulq_f16(acc1, alpha_f16);
175 acc2 = vmulq_f16(acc2, alpha_f16);
176 acc3 = vmulq_f16(acc3, alpha_f16);
177 }
178
179 const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());
180
181 vst1q_f16(vec_out + 0, acc0);
182 vst1q_f16(vec_out + 8, acc1);
183 vst1q_f16(vec_out + 16, acc2);
184 vst1q_f16(vec_out + 24, acc3);
185
186 },
187 ina, inb, out);
188#else /* ARM_COMPUTE_ENABLE_FP16 */
189 ARM_COMPUTE_ERROR("Not implemented");
190#endif /* ARM_COMPUTE_ENABLE_FP16 */
191}
192
193template <bool multiply_alpha>
194void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
195{
196 const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
197 const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
198 const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
199
200 // The implementation computes 16 elements per iteration
201 const int window_start_x = 16 * window.thread_id();
202 const int window_step_x = 16 * window.num_threads();
203 // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
204 const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
205
206 Window win_out(window);
207 win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
208 win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
209
210 Window win_a(window);
211 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
212 win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
213
214 Window win_b;
215 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
216 // This scenario can happen when the matrix multiplication is used to perform a convolution operation
217 if(input1->info()->num_dimensions() >= 3)
218 {
219 win_b = window;
220 }
221 win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
222 win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
223
224 Iterator ina(input0, win_a);
225 Iterator inb(input1, win_b);
226 Iterator out(output, win_out);
227
228 execute_window_loop(win_out, [&](const Coordinates & id)
229 {
230 if(id.x() > width_matrix_b)
231 {
232 return;
233 }
234
235 float32x4_t acc0 = vdupq_n_f32(0.f);
236 float32x4_t acc1 = vdupq_n_f32(0.f);
237 float32x4_t acc2 = vdupq_n_f32(0.f);
238 float32x4_t acc3 = vdupq_n_f32(0.f);
239
240 auto vec_a = reinterpret_cast<const float *>(ina.ptr());
241 auto matrix_b = reinterpret_cast<const float *>(inb.ptr());
242
243#if __arm__
244 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
245 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
246 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
247#endif /* __arm__ */
248
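    // The PLD hints above (ARMv7 builds only) ask the core to prefetch the start of the A vector and
    // the first rows of B before the accumulation loop touches them; on AArch64 builds __arm__ is not
    // defined and the code presumably relies on the hardware prefetcher instead.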
249 auto vec_a_end_addr = vec_a + num_elems_vec_a;
250 for(; vec_a <= (vec_a_end_addr - 4);)
251 {
252 float32x2_t a0l = vld1_f32(vec_a);
253
254 float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
255 float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
256 float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
257 float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
258
259 float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
260 float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
261 float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
262 float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
263
264#if __arm__
265 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
266 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
267 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
268 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
269 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
270#endif /* __arm__ */
271
272 acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
273 acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
274 acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
275 acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
276
277 acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
278 acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
279 acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
280 acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
281
282 vec_a += 2;
283 matrix_b += 2 * in_b_stride;
284
285 a0l = vld1_f32(vec_a);
286
287 b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
288 b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
289 b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
290 b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
291
292 b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
293 b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
294 b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
295 b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
296
297 acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
298 acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
299 acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
300 acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
301
302 acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
303 acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
304 acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
305 acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
306
307 vec_a += 2;
308 matrix_b += 2 * in_b_stride;
309 }
310
311 for(; vec_a < vec_a_end_addr;)
312 {
313 const float a0 = *vec_a;
314
315 const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
316 const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
317 const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
318 const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
319
320 acc0 = vmlaq_n_f32(acc0, b00, a0);
321 acc1 = vmlaq_n_f32(acc1, b01, a0);
322 acc2 = vmlaq_n_f32(acc2, b02, a0);
323 acc3 = vmlaq_n_f32(acc3, b03, a0);
324
325 vec_a += 1;
326 matrix_b += in_b_stride;
327 }
328
329 // Multiply by the weight of matrix product (alpha)
330 if(multiply_alpha)
331 {
332 const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
333 acc0 = vmulq_f32(acc0, alpha_f32);
334 acc1 = vmulq_f32(acc1, alpha_f32);
335 acc2 = vmulq_f32(acc2, alpha_f32);
336 acc3 = vmulq_f32(acc3, alpha_f32);
337 }
338
339 const auto vec_out = reinterpret_cast<float *>(out.ptr());
340
341 vst1q_f32(vec_out + 0, acc0);
342 vst1q_f32(vec_out + 4, acc1);
343 vst1q_f32(vec_out + 8, acc2);
344 vst1q_f32(vec_out + 12, acc3);
345 },
346 ina, inb, out);
347}
348
349template <bool multiply_alpha>
350void vector_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
351{
352 const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
353 const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
354 const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
355 const int fixed_point_position = input0->info()->fixed_point_position();
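    // Background on the QS8 arithmetic used below (illustrative, not part of the kernel): a QS8 value
    // is an 8-bit two's-complement number with fixed_point_position fractional bits, so with
    // fixed_point_position = 5 the raw value 48 represents 48 / 2^5 = 1.5. vqmlal_qs8 widens the
    // operands, multiplies, shifts the product back down by fixed_point_position and accumulates with
    // saturation, which is why the accumulators below are kept as qint16x8_t until the final
    // vqmovn_qs16 narrowing.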
356
357 // The implementation computes 32 elements per iteration
358 const int window_start_x = 32 * window.thread_id();
359 const int window_step_x = 32 * window.num_threads();
360 // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
361 const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
362
363 Window win_out(window);
364 win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
365 win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
366
367 Window win_a(window);
368 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
369 win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
370
371 Window win_b;
372 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
373 // This scenario can happen when the matrix multiplication is used to perform a convolution operation
374 if(input1->info()->num_dimensions() >= 3)
375 {
376 win_b = window;
377 }
378 win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
379 win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
380
381 Iterator ina(input0, win_a);
382 Iterator inb(input1, win_b);
383 Iterator out(output, win_out);
384
385 execute_window_loop(win_out, [&](const Coordinates & id)
386 {
387 if(id.x() > width_matrix_b)
388 {
389 return;
390 }
391
392 // Reset accumulators
393 qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
394 qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
395 qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
396 qint16x8_t acc03_qs16 = vdupq_n_qs16(0);
397
398 auto vec_a = reinterpret_cast<const qint8_t *>(ina.ptr());
399 auto matrix_b = reinterpret_cast<const qint8_t *>(inb.ptr());
400
401 auto vec_a_end_addr = vec_a + num_elems_vec_a;
402 for(; vec_a <= (vec_a_end_addr - 2);)
403 {
404 const qint8x8_t a0 = vld1_dup_qs8(vec_a + 0);
405 const qint8x8_t a1 = vld1_dup_qs8(vec_a + 1);
406
407 const qint8x8_t b00 = vld1_qs8(matrix_b + 0 + 0 * in_b_stride);
408 const qint8x8_t b01 = vld1_qs8(matrix_b + 8 + 0 * in_b_stride);
409 const qint8x8_t b02 = vld1_qs8(matrix_b + 16 + 0 * in_b_stride);
410 const qint8x8_t b03 = vld1_qs8(matrix_b + 24 + 0 * in_b_stride);
411 const qint8x8_t b10 = vld1_qs8(matrix_b + 0 + 1 * in_b_stride);
412 const qint8x8_t b11 = vld1_qs8(matrix_b + 8 + 1 * in_b_stride);
413 const qint8x8_t b12 = vld1_qs8(matrix_b + 16 + 1 * in_b_stride);
414 const qint8x8_t b13 = vld1_qs8(matrix_b + 24 + 1 * in_b_stride);
415
416 // First accumulation
417 acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
418 acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
419 acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
420 acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);
421
422 // Second accumulation
423 acc00_qs16 = vqmlal_qs8(acc00_qs16, b10, a1, fixed_point_position);
424 acc01_qs16 = vqmlal_qs8(acc01_qs16, b11, a1, fixed_point_position);
425 acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a1, fixed_point_position);
426 acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a1, fixed_point_position);
427
428 vec_a += 2;
429 matrix_b += 2 * in_b_stride;
430 }
431
432 for(; vec_a < vec_a_end_addr;)
433 {
434 const qint8x8_t a0 = vld1_dup_qs8(vec_a);
435
436 const qint8x8_t b00 = vld1_qs8(matrix_b + 0);
437 const qint8x8_t b01 = vld1_qs8(matrix_b + 8);
438 const qint8x8_t b02 = vld1_qs8(matrix_b + 16);
439 const qint8x8_t b03 = vld1_qs8(matrix_b + 24);
440
441 acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
442 acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
443 acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
444 acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);
445
446 vec_a += 1;
447 matrix_b += in_b_stride;
448 }
449
450 // Convert back to qint8x8_t and saturate
451 qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
452 qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
453 qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
454 qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);
455
456 // Multiply by the weight of the matrix product (alpha)
457 if(multiply_alpha)
458 {
459 const qint8x8_t alpha_qs8 = vdup_n_qs8(sqcvt_qs8_f32(alpha, fixed_point_position));
460 acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
461 acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
462 acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
463 acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
464 }
465
466 const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());
467
468 // Store 8x4 output elements
469 vst1_qs8(mtx_out0 + 0, acc00_qs8);
470 vst1_qs8(mtx_out0 + 8, acc01_qs8);
471 vst1_qs8(mtx_out0 + 16, acc02_qs8);
472 vst1_qs8(mtx_out0 + 24, acc03_qs8);
473 },
474 ina, inb, out);
475}
476
477template <bool multiply_alpha>
478void vector_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
479{
480 const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
481 const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
482 const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
483 const int fixed_point_position = input0->info()->fixed_point_position();
484
485 // The implementation computes 16 elements per iteration
486 const int window_start_x = 16 * window.thread_id();
487 const int window_step_x = 16 * window.num_threads();
488 // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
489 const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
490 ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
491
492 Window win_out(window);
493 win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
494 win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
495
496 Window win_a(window);
497 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
498 win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
499
500 Window win_b;
501 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
502 // This scenario can happen when the matrix multiplication is used to perform a convolution operation
503 if(input1->info()->num_dimensions() >= 3)
504 {
505 win_b = window;
506 }
507 win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
508 win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
509
510 Iterator ina(input0, win_a);
511 Iterator inb(input1, win_b);
512 Iterator out(output, win_out);
513
514 execute_window_loop(win_out, [&](const Coordinates & id)
515 {
516 if(id.x() > width_matrix_b)
517 {
518 return;
519 }
520
521 // Reset accumulators
522 qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
523 qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
524 qint32x4_t acc02_qs32 = vdupq_n_qs32(0);
525 qint32x4_t acc03_qs32 = vdupq_n_qs32(0);
526
527 auto vec_a = reinterpret_cast<const qint16_t *>(ina.ptr());
528 auto matrix_b = reinterpret_cast<const qint16_t *>(inb.ptr());
529
530 auto vec_a_end_addr = vec_a + num_elems_vec_a;
531 for(; vec_a <= (vec_a_end_addr - 2);)
532 {
533 const qint16x4_t a0 = vld1_dup_qs16(vec_a + 0);
534 const qint16x4_t a1 = vld1_dup_qs16(vec_a + 1);
535
536 const qint16x4_t b00 = vld1_qs16(matrix_b + 0 + 0 * in_b_stride);
537 const qint16x4_t b01 = vld1_qs16(matrix_b + 4 + 0 * in_b_stride);
538 const qint16x4_t b02 = vld1_qs16(matrix_b + 8 + 0 * in_b_stride);
539 const qint16x4_t b03 = vld1_qs16(matrix_b + 12 + 0 * in_b_stride);
540 const qint16x4_t b10 = vld1_qs16(matrix_b + 0 + 1 * in_b_stride);
541 const qint16x4_t b11 = vld1_qs16(matrix_b + 4 + 1 * in_b_stride);
542 const qint16x4_t b12 = vld1_qs16(matrix_b + 8 + 1 * in_b_stride);
543 const qint16x4_t b13 = vld1_qs16(matrix_b + 12 + 1 * in_b_stride);
544
545 // First accumulation
546 acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
547 acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
548 acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
549 acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);
550
551 // Second accumulation
552 acc00_qs32 = vqmlal_qs16(acc00_qs32, b10, a1, fixed_point_position);
553 acc01_qs32 = vqmlal_qs16(acc01_qs32, b11, a1, fixed_point_position);
554 acc02_qs32 = vqmlal_qs16(acc02_qs32, b12, a1, fixed_point_position);
555 acc03_qs32 = vqmlal_qs16(acc03_qs32, b13, a1, fixed_point_position);
556
557 vec_a += 2;
558 matrix_b += 2 * in_b_stride;
559 }
560
561 for(; vec_a < vec_a_end_addr;)
562 {
563 const qint16x4_t a0 = vld1_dup_qs16(vec_a);
564
565 const qint16x4_t b00 = vld1_qs16(matrix_b + 0);
566 const qint16x4_t b01 = vld1_qs16(matrix_b + 4);
567 const qint16x4_t b02 = vld1_qs16(matrix_b + 8);
568 const qint16x4_t b03 = vld1_qs16(matrix_b + 12);
569
570 acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
571 acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
572 acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
573 acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);
574
575 vec_a += 1;
576 matrix_b += in_b_stride;
577 }
578
579 // Convert back to qint16x4_t and saturate
580 qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
581 qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
582 qint16x4_t acc02_qs16 = vqmovn_qs32(acc02_qs32);
583 qint16x4_t acc03_qs16 = vqmovn_qs32(acc03_qs32);
584
585 // Multiply by the weight of the matrix product (alpha)
586 if(multiply_alpha)
587 {
588 const qint16x4_t alpha_qs16 = vdup_n_qs16(sqcvt_qs16_f32(alpha, fixed_point_position));
589 acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
590 acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
591 acc02_qs16 = vqmul_qs16(acc02_qs16, alpha_qs16, fixed_point_position);
592 acc03_qs16 = vqmul_qs16(acc03_qs16, alpha_qs16, fixed_point_position);
593 }
594
595 const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());
596
597 // Store 16x4 output elements
598 vst1_qs16(mtx_out0 + 0, acc00_qs16);
599 vst1_qs16(mtx_out0 + 4, acc01_qs16);
600 vst1_qs16(mtx_out0 + 8, acc02_qs16);
601 vst1_qs16(mtx_out0 + 12, acc03_qs16);
602 },
603 ina, inb, out);
604}
605
606template <bool multiply_alpha>
607void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
608{
609 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
610 const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
611 const size_t out_stride2 = out_stride1 * 2;
612 const size_t out_stride3 = out_stride1 * 3;
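    // out_stride1/2/3 are the output row strides expressed in elements; the stores at the end of the
    // window loop below use them to write the four rows of each accumulated 4x4 block.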
613 const int num_elems_matrix_b_x = input1->info()->dimension(0);
614
615 // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the input interleaved matrix A has 4 times fewer rows than the output matrix
616 Window win_a(window);
617 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
618 win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
619
620 Window win_b;
621 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
622 // This scenario can happen when the matrix multiplication is used to perform a convolution operation
623 if(input1->info()->num_dimensions() >= 3)
624 {
625 win_b = window;
626 }
627 // Set step_x and step_y for matrix B. Scale the X range by a factor of 4 as the input transposed matrix B has 4 times fewer columns than the output matrix
628 // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
629 win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
630 win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
631
632 Iterator ina(input0, win_a);
633 Iterator inb(input1, win_b);
634 Iterator out(output, window);
635
636 // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
637 // Reshaping the matrices keeps the implementation cache friendly and avoids the data re-arrangements needed for computing 16x4 elements per iteration
638 // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
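    // Rough scalar sketch of what one iteration of the window loop below accumulates (names here are
    // illustrative only):
    //
    //     for each accumulation step k:            // walks the common dimension of A and B
    //         for(int r = 0; r < 4; ++r)           // 4 rows taken from the interleaved A block
    //             for(int c = 0; c < 8; ++c)       // two adjacent 4-wide B blocks -> 8 output columns
    //                 out[r][c] += A[r][k] * B[k][c];
    //
    // Thanks to the interleaving of A and the 1xW transposition of B, the A[r][k] and B[k][c] reads
    // above become the sequential vld1q_dup_f32 / vld1q_f32 loads used in the loop.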
639 execute_window_loop(window, [&](const Coordinates & id)
640 {
641 auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
642 auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
643 auto mtx_b1 = mtx_b0 + in_b_stride;
644
645 float32x4_t acc00 = vdupq_n_f32(0.f);
646 float32x4_t acc10 = vdupq_n_f32(0.f);
647 float32x4_t acc20 = vdupq_n_f32(0.f);
648 float32x4_t acc30 = vdupq_n_f32(0.f);
649
650 float32x4_t acc01 = vdupq_n_f32(0.f);
651 float32x4_t acc11 = vdupq_n_f32(0.f);
652 float32x4_t acc21 = vdupq_n_f32(0.f);
653 float32x4_t acc31 = vdupq_n_f32(0.f);
654
655#if __arm__
656 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
657 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
658 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
659#endif /* __arm__ */
660
661 auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
662 for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
663 {
664 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
665 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
666 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
667 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
668
669 float32x4_t b00 = vld1q_f32(mtx_b0);
670 float32x4_t b10 = vld1q_f32(mtx_b1);
671 float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
672 float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
673
674#if __arm__
675 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
676 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
677 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
678#endif /* __arm__ */
679
680 // 4x4 block 0
681 acc00 = vmlaq_f32(acc00, b00, a0);
682 acc10 = vmlaq_f32(acc10, b00, a1);
683 acc20 = vmlaq_f32(acc20, b00, a2);
684 acc30 = vmlaq_f32(acc30, b00, a3);
685
686 float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
687 float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
688 float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
689 float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
690
691 // 4x4 block 1
692 acc01 = vmlaq_f32(acc01, b10, a0);
693 acc11 = vmlaq_f32(acc11, b10, a1);
694 acc21 = vmlaq_f32(acc21, b10, a2);
695 acc31 = vmlaq_f32(acc31, b10, a3);
696
697 // 4x4 block 0
698 acc00 = vmlaq_f32(acc00, b01, a4);
699 acc10 = vmlaq_f32(acc10, b01, a5);
700 acc20 = vmlaq_f32(acc20, b01, a6);
701 acc30 = vmlaq_f32(acc30, b01, a7);
702
703 // 4x4 block 1
704 acc01 = vmlaq_f32(acc01, b11, a4);
705 acc11 = vmlaq_f32(acc11, b11, a5);
706 acc21 = vmlaq_f32(acc21, b11, a6);
707 acc31 = vmlaq_f32(acc31, b11, a7);
708
709 mtx_a0 += 8;
710 mtx_b0 += 8;
711 mtx_b1 += 8;
712
713 a0 = vld1q_dup_f32(mtx_a0 + 0);
714 a1 = vld1q_dup_f32(mtx_a0 + 1);
715 a2 = vld1q_dup_f32(mtx_a0 + 2);
716 a3 = vld1q_dup_f32(mtx_a0 + 3);
717
718 b00 = vld1q_f32(mtx_b0);
719 b10 = vld1q_f32(mtx_b1);
720 b01 = vld1q_f32(mtx_b0 + 4);
721 b11 = vld1q_f32(mtx_b1 + 4);
722
723 // 4x4 block 0
724 acc00 = vmlaq_f32(acc00, b00, a0);
725 acc10 = vmlaq_f32(acc10, b00, a1);
726 acc20 = vmlaq_f32(acc20, b00, a2);
727 acc30 = vmlaq_f32(acc30, b00, a3);
728
729 a4 = vld1q_dup_f32(mtx_a0 + 4);
730 a5 = vld1q_dup_f32(mtx_a0 + 5);
731 a6 = vld1q_dup_f32(mtx_a0 + 6);
732 a7 = vld1q_dup_f32(mtx_a0 + 7);
733
734 // 4x4 block 1
735 acc01 = vmlaq_f32(acc01, b10, a0);
736 acc11 = vmlaq_f32(acc11, b10, a1);
737 acc21 = vmlaq_f32(acc21, b10, a2);
738 acc31 = vmlaq_f32(acc31, b10, a3);
739
740 // 4x4 block 0
741 acc00 = vmlaq_f32(acc00, b01, a4);
742 acc10 = vmlaq_f32(acc10, b01, a5);
743 acc20 = vmlaq_f32(acc20, b01, a6);
744 acc30 = vmlaq_f32(acc30, b01, a7);
745
746 // 4x4 block 1
747 acc01 = vmlaq_f32(acc01, b11, a4);
748 acc11 = vmlaq_f32(acc11, b11, a5);
749 acc21 = vmlaq_f32(acc21, b11, a6);
750 acc31 = vmlaq_f32(acc31, b11, a7);
751
752 mtx_a0 += 8;
753 mtx_b0 += 8;
754 mtx_b1 += 8;
755
756 a0 = vld1q_dup_f32(mtx_a0 + 0);
757 a1 = vld1q_dup_f32(mtx_a0 + 1);
758 a2 = vld1q_dup_f32(mtx_a0 + 2);
759 a3 = vld1q_dup_f32(mtx_a0 + 3);
760 b00 = vld1q_f32(mtx_b0);
761 b10 = vld1q_f32(mtx_b1);
762 b01 = vld1q_f32(mtx_b0 + 4);
763 b11 = vld1q_f32(mtx_b1 + 4);
764
765#if __arm__
766 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
767 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
768 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
769#endif /* __arm__ */
770
771 // 4x4 block 0
772 acc00 = vmlaq_f32(acc00, b00, a0);
773 acc10 = vmlaq_f32(acc10, b00, a1);
774 acc20 = vmlaq_f32(acc20, b00, a2);
775 acc30 = vmlaq_f32(acc30, b00, a3);
776
777 a4 = vld1q_dup_f32(mtx_a0 + 4);
778 a5 = vld1q_dup_f32(mtx_a0 + 5);
779 a6 = vld1q_dup_f32(mtx_a0 + 6);
780 a7 = vld1q_dup_f32(mtx_a0 + 7);
781
782 // 4x4 block 1
783 acc01 = vmlaq_f32(acc01, b10, a0);
784 acc11 = vmlaq_f32(acc11, b10, a1);
785 acc21 = vmlaq_f32(acc21, b10, a2);
786 acc31 = vmlaq_f32(acc31, b10, a3);
787
788 // 4x4 block 0
789 acc00 = vmlaq_f32(acc00, b01, a4);
790 acc10 = vmlaq_f32(acc10, b01, a5);
791 acc20 = vmlaq_f32(acc20, b01, a6);
792 acc30 = vmlaq_f32(acc30, b01, a7);
793
794 // 4x4 block 1
795 acc01 = vmlaq_f32(acc01, b11, a4);
796 acc11 = vmlaq_f32(acc11, b11, a5);
797 acc21 = vmlaq_f32(acc21, b11, a6);
798 acc31 = vmlaq_f32(acc31, b11, a7);
799
800 mtx_a0 += 8;
801 mtx_b0 += 8;
802 mtx_b1 += 8;
803
804 a0 = vld1q_dup_f32(mtx_a0 + 0);
805 a1 = vld1q_dup_f32(mtx_a0 + 1);
806 a2 = vld1q_dup_f32(mtx_a0 + 2);
807 a3 = vld1q_dup_f32(mtx_a0 + 3);
808 b00 = vld1q_f32(mtx_b0);
809 b10 = vld1q_f32(mtx_b1);
810 b01 = vld1q_f32(mtx_b0 + 4);
811 b11 = vld1q_f32(mtx_b1 + 4);
812
813 // 4x4 block 0
814 acc00 = vmlaq_f32(acc00, b00, a0);
815 acc10 = vmlaq_f32(acc10, b00, a1);
816 acc20 = vmlaq_f32(acc20, b00, a2);
817 acc30 = vmlaq_f32(acc30, b00, a3);
818
819 a4 = vld1q_dup_f32(mtx_a0 + 4);
820 a5 = vld1q_dup_f32(mtx_a0 + 5);
821 a6 = vld1q_dup_f32(mtx_a0 + 6);
822 a7 = vld1q_dup_f32(mtx_a0 + 7);
823
824 // 4x4 block 1
825 acc01 = vmlaq_f32(acc01, b10, a0);
826 acc11 = vmlaq_f32(acc11, b10, a1);
827 acc21 = vmlaq_f32(acc21, b10, a2);
828 acc31 = vmlaq_f32(acc31, b10, a3);
829
830 // 4x4 block 0
831 acc00 = vmlaq_f32(acc00, b01, a4);
832 acc10 = vmlaq_f32(acc10, b01, a5);
833 acc20 = vmlaq_f32(acc20, b01, a6);
834 acc30 = vmlaq_f32(acc30, b01, a7);
835
836 // 4x4 block 1
837 acc01 = vmlaq_f32(acc01, b11, a4);
838 acc11 = vmlaq_f32(acc11, b11, a5);
839 acc21 = vmlaq_f32(acc21, b11, a6);
840 acc31 = vmlaq_f32(acc31, b11, a7);
841
842 mtx_a0 += 8;
843 mtx_b0 += 8;
844 mtx_b1 += 8;
845 }
846
847 for(; mtx_b0 < mtx_b0_end_addr;)
848 {
849 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
850 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
851 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
852 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
853 float32x4_t b00 = vld1q_f32(mtx_b0);
854 float32x4_t b10 = vld1q_f32(mtx_b1);
855
856#if __arm__
857 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
858 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
859 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
860#endif /* __arm__ */
861 // 4x4 block 0
862 acc00 = vmlaq_f32(acc00, b00, a0);
863 acc10 = vmlaq_f32(acc10, b00, a1);
864 acc20 = vmlaq_f32(acc20, b00, a2);
865 acc30 = vmlaq_f32(acc30, b00, a3);
866
867 // 4x4 block 1
868 acc01 = vmlaq_f32(acc01, b10, a0);
869 acc11 = vmlaq_f32(acc11, b10, a1);
870 acc21 = vmlaq_f32(acc21, b10, a2);
871 acc31 = vmlaq_f32(acc31, b10, a3);
872
873 mtx_a0 += 4;
874 mtx_b0 += 4;
875 mtx_b1 += 4;
876 }
877
878 // Multiply by the weight of matrix product (alpha)
879 if(multiply_alpha)
880 {
881 const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
882 acc00 = vmulq_f32(acc00, alpha_f32);
883 acc10 = vmulq_f32(acc10, alpha_f32);
884 acc20 = vmulq_f32(acc20, alpha_f32);
885 acc30 = vmulq_f32(acc30, alpha_f32);
886 acc01 = vmulq_f32(acc01, alpha_f32);
887 acc11 = vmulq_f32(acc11, alpha_f32);
888 acc21 = vmulq_f32(acc21, alpha_f32);
889 acc31 = vmulq_f32(acc31, alpha_f32);
890 }
891
892 const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
893 const auto mtx_out1 = mtx_out0 + 4;
894
895 // Store the 4 blocks
896 vst1q_f32(mtx_out0, acc00);
897 vst1q_f32(mtx_out1, acc01);
898 vst1q_f32(mtx_out0 + out_stride1, acc10);
899 vst1q_f32(mtx_out1 + out_stride1, acc11);
900 vst1q_f32(mtx_out0 + out_stride2, acc20);
901 vst1q_f32(mtx_out1 + out_stride2, acc21);
902 vst1q_f32(mtx_out0 + out_stride3, acc30);
903 vst1q_f32(mtx_out1 + out_stride3, acc31);
904 },
905 ina, inb, out);
906}
907
908template <bool multiply_alpha>
909void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
910{
911#ifdef ARM_COMPUTE_ENABLE_FP16
912 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
913 const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
914 const int num_elems_matrix_b_x = input1->info()->dimension(0);
915
916 // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the input interleaved matrix A has 4 times fewer rows than the output matrix
917 Window win_a(window);
918 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
919 win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
920
921 Window win_b;
922 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
923 // This scenario can happen when the matrix multiplication is used to perform a convolution operation
924 if(input1->info()->num_dimensions() >= 3)
925 {
926 win_b = window;
927 }
928 // Set step_x and step_y for matrix B. Scale the X range by a factor of 8 as the input transposed matrix B has 8 times fewer columns than the output matrix
929 win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
930 win_b.set(Window::DimY, Window::Dimension(0, 1, 0));
931
932 Iterator ina(input0, win_a);
933 Iterator inb(input1, win_b);
934 Iterator out(output, window);
935
936 const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
937
938 execute_window_loop(window, [&](const Coordinates & id)
939 {
940 const auto *mtx_a0 = reinterpret_cast<const float16_t *>(ina.ptr());
941 const auto *mtx_b0 = reinterpret_cast<const float16_t *>(inb.ptr());
942 auto *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
943 float16x8x4_t c =
944 {
945 {
946 vdupq_n_f16(0.f),
947 vdupq_n_f16(0.f),
948 vdupq_n_f16(0.f),
949 vdupq_n_f16(0.f)
950 }
951 };
952
953 /*
954 This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
955 |a00 a01 a02 a03 | a04 a05 a06 a07|
956 |a10 a11 a12 a13 | a14 a15 a16 a17|
957 |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
958 |a30 a31 a32 a33 | a34 a35 a36 a37| | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
959 |a40 a41 a42 a43 | a44 a45 a46 a47|
960 |a50 a51 a52 a53 | a54 a55 a56 a57|
961 |a60 a61 a62 a63 | a64 a65 a66 a67|
962 |a70 a71 a72 a73 | a74 a75 a76 a77|
963
964 After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]
965
966 B Matrix has been transposed as shown below
967
968 |b00 b01 b02 b03 b04 b05 b06 b07|
969 |b10 b11 b12 b13 b14 b15 b16 b17|
970 |b20 b21 b22 b23 b24 b25 b26 b27|
971 |b30 b31 b32 b33 b34 b35 b36 b37|
972 ------------------->
973
974 |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|
975
976 c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
977 c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31
978
979 The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
980 */
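    // Each iteration of the vectorised loop below therefore consumes 16 interleaved A values
    // (4 rows x 4 accumulation steps, loaded as p00/p02) and 32 transposed B values
    // (4 accumulation steps x 8 output columns, loaded as q00..q06), updating the four 8-wide
    // accumulators c.val[0..3] that together hold one 4x8 block of the output.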
981 const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
982
983 for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
984
985 {
986 const float16x8_t p00 = vld1q_f16(mtx_a0);
987 const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);
988
989 const float16x8_t q00 = vld1q_f16(mtx_b0);
990 const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
991 const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
992 const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);
993
994 c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
995 c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
996 c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
997 c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));
998
999 c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
1000 c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
1001 c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
1002 c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));
1003
1004 c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
1005 c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
1006 c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
1007 c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));
1008
1009 c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
1010 c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
1011 c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
1012 c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));
1013
1014 mtx_a0 += 16;
1015 mtx_b0 += 32;
1016 }
1017
1018 for(; mtx_b0 < mtx_b0_end_addr;)
1019
1020 {
1021 const float16x4_t p00 = vld1_f16(mtx_a0);
1022 const float16x8_t q00 = vld1q_f16(mtx_b0);
1023
1024 c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
1025 c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
1026 c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
1027 c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));
1028
1029 mtx_a0 += 4;
1030 mtx_b0 += 8;
1031 }
1032
1033 if(multiply_alpha)
1034 {
1035 c.val[0] = vmulq_f16(c.val[0], alpha_f16);
1036 c.val[1] = vmulq_f16(c.val[1], alpha_f16);
1037 c.val[2] = vmulq_f16(c.val[2], alpha_f16);
1038 c.val[3] = vmulq_f16(c.val[3], alpha_f16);
1039 }
1040
1041 vst1q_f16(mtx_out + 0 * out_stride, c.val[0]);
1042 vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
1043 vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
1044 vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
1045 },
1046 ina, inb, out);
1047#else /* ARM_COMPUTE_ENABLE_FP16 */
1048 ARM_COMPUTE_ERROR("Not implemented");
1049#endif /* ARM_COMPUTE_ENABLE_FP16 */
1050}
1051
1052template <bool multiply_alpha>
1053void matrix_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
1054{
1055 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
1056 const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
1057 const size_t out_stride2 = out_stride1 * 2;
1058 const size_t out_stride3 = out_stride1 * 3;
1059 const int num_elems_matrix_b_x = input1->info()->dimension(0);
1060 const int fixed_point_position = input0->info()->fixed_point_position();
1061 const qint8x8_t alpha_qs8 = vdup_n_qs8(sqcvt_qs8_f32(alpha, fixed_point_position));
1062 ARM_COMPUTE_UNUSED(alpha_qs8);
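    // sqcvt_qs8_f32 converts alpha into the same Q format as the data, i.e. it scales by
    // 2^fixed_point_position and saturates to the qint8_t range; as an illustrative example,
    // alpha = 1.5 with fixed_point_position = 5 becomes the raw value 48.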
1063
1064 // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the input interleaved matrix A has 4 times fewer rows than the output matrix
1065 Window win_a(window);
1066 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
1067 win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
1068
1069 Window win_b;
1070 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
1071 // This scenario can happen when the matrix multiplication is used to perform a convolution operation
1072 if(input1->info()->num_dimensions() >= 3)
1073 {
1074 win_b = window;
1075 }
1076 // Set step_x and step_y for matrix B. Scale the X range by a factor of 16 as the input transposed matrix B has 16 times fewer columns than the output matrix
1077 // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 16x4
1078 win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, 2 * in_b_stride));
1079 win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
1080
1081 Iterator ina(input0, win_a);
1082 Iterator inb(input1, win_b);
1083 Iterator out(output, window);
1084
1085 // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
1086 // Reshaping the matrices keeps the implementation cache friendly and avoids the data re-arrangements needed for computing 16x4 elements per iteration
1087 // All the values needed for computing a single 32x4 block will be read from consecutive memory positions
1088 execute_window_loop(window, [&](const Coordinates & id)
1089 {
1090 auto mtx_a0 = reinterpret_cast<const qint8_t *>(ina.ptr());
1091 auto mtx_b0 = reinterpret_cast<const qint8_t *>(inb.ptr());
1092 auto mtx_b1 = mtx_b0 + in_b_stride;
1093
1094 qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
1095 qint16x8_t acc10_qs16 = vdupq_n_qs16(0);
1096 qint16x8_t acc20_qs16 = vdupq_n_qs16(0);
1097 qint16x8_t acc30_qs16 = vdupq_n_qs16(0);
1098
1099 qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
1100 qint16x8_t acc11_qs16 = vdupq_n_qs16(0);
1101 qint16x8_t acc21_qs16 = vdupq_n_qs16(0);
1102 qint16x8_t acc31_qs16 = vdupq_n_qs16(0);
1103
1104 qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
1105 qint16x8_t acc12_qs16 = vdupq_n_qs16(0);
1106 qint16x8_t acc22_qs16 = vdupq_n_qs16(0);
1107 qint16x8_t acc32_qs16 = vdupq_n_qs16(0);
1108
1109 qint16x8_t acc03_qs16 = vdupq_n_qs16(0);
1110 qint16x8_t acc13_qs16 = vdupq_n_qs16(0);
1111 qint16x8_t acc23_qs16 = vdupq_n_qs16(0);
1112 qint16x8_t acc33_qs16 = vdupq_n_qs16(0);
1113
1114 int k = 0;
1115 // This for loop performs 2 accumulations
1116 for(; k <= (num_elems_matrix_b_x - 32); k += 32)
1117 {
1118 const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
1119 const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
1120 const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
1121 const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);
1122 const qint8x8_t a4 = vld1_dup_qs8(mtx_a0 + 4);
1123 const qint8x8_t a5 = vld1_dup_qs8(mtx_a0 + 5);
1124 const qint8x8_t a6 = vld1_dup_qs8(mtx_a0 + 6);
1125 const qint8x8_t a7 = vld1_dup_qs8(mtx_a0 + 7);
1126
1127 const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
1128 const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
1129 const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
1130 const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);
1131
1132 // First accumulation
1133 acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
1134 acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
1135 acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
1136 acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
1137 acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
1138 acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
1139 acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
1140 acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);
1141
1142 const qint8x8_t b02 = vld1_qs8(mtx_b0 + 16);
1143 const qint8x8_t b03 = vld1_qs8(mtx_b0 + 24);
1144 const qint8x8_t b12 = vld1_qs8(mtx_b1 + 16);
1145 const qint8x8_t b13 = vld1_qs8(mtx_b1 + 24);
1146
1147 acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
1148 acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
1149 acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
1150 acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
1151 acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
1152 acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
1153 acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
1154 acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);
1155
1156#if __arm__
1157 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
1158 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
1159 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
1160#endif /* __arm__ */
1161
1162 // Second accumulation
1163 acc00_qs16 = vqmlal_qs8(acc00_qs16, b02, a4, fixed_point_position);
1164 acc10_qs16 = vqmlal_qs8(acc10_qs16, b02, a5, fixed_point_position);
1165 acc20_qs16 = vqmlal_qs8(acc20_qs16, b02, a6, fixed_point_position);
1166 acc30_qs16 = vqmlal_qs8(acc30_qs16, b02, a7, fixed_point_position);
1167 acc01_qs16 = vqmlal_qs8(acc01_qs16, b03, a4, fixed_point_position);
1168 acc11_qs16 = vqmlal_qs8(acc11_qs16, b03, a5, fixed_point_position);
1169 acc21_qs16 = vqmlal_qs8(acc21_qs16, b03, a6, fixed_point_position);
1170 acc31_qs16 = vqmlal_qs8(acc31_qs16, b03, a7, fixed_point_position);
1171 acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a4, fixed_point_position);
1172 acc12_qs16 = vqmlal_qs8(acc12_qs16, b12, a5, fixed_point_position);
1173 acc22_qs16 = vqmlal_qs8(acc22_qs16, b12, a6, fixed_point_position);
1174 acc32_qs16 = vqmlal_qs8(acc32_qs16, b12, a7, fixed_point_position);
1175 acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a4, fixed_point_position);
1176 acc13_qs16 = vqmlal_qs8(acc13_qs16, b13, a5, fixed_point_position);
1177 acc23_qs16 = vqmlal_qs8(acc23_qs16, b13, a6, fixed_point_position);
1178 acc33_qs16 = vqmlal_qs8(acc33_qs16, b13, a7, fixed_point_position);
1179
1180 mtx_a0 += 8;
1181 mtx_b0 += 32;
1182 mtx_b1 += 32;
1183 }
1184
1185 // This for loop performs the left over accumulations
1186 for(; k < num_elems_matrix_b_x; k += 16)
1187 {
1188 const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
1189 const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
1190 const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
1191 const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);
1192
1193 const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
1194 const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
1195 const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
1196 const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);
1197
1198 acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
1199 acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
1200 acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
1201 acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
1202 acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
1203 acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
1204 acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
1205 acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
1206 acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
1207 acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
1208 acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
1209 acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);
1210 acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
1211 acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
1212 acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
1213 acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);
1214
1215 mtx_a0 += 4;
1216 mtx_b0 += 16;
1217 mtx_b1 += 16;
1218 }
1219
1220 // Convert back to qint8x8_t and saturate
1221 qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
1222 qint8x8_t acc10_qs8 = vqmovn_qs16(acc10_qs16);
1223 qint8x8_t acc20_qs8 = vqmovn_qs16(acc20_qs16);
1224 qint8x8_t acc30_qs8 = vqmovn_qs16(acc30_qs16);
1225
1226 qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
1227 qint8x8_t acc11_qs8 = vqmovn_qs16(acc11_qs16);
1228 qint8x8_t acc21_qs8 = vqmovn_qs16(acc21_qs16);
1229 qint8x8_t acc31_qs8 = vqmovn_qs16(acc31_qs16);
1230
1231 qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
1232 qint8x8_t acc12_qs8 = vqmovn_qs16(acc12_qs16);
1233 qint8x8_t acc22_qs8 = vqmovn_qs16(acc22_qs16);
1234 qint8x8_t acc32_qs8 = vqmovn_qs16(acc32_qs16);
1235
1236 qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);
1237 qint8x8_t acc13_qs8 = vqmovn_qs16(acc13_qs16);
1238 qint8x8_t acc23_qs8 = vqmovn_qs16(acc23_qs16);
1239 qint8x8_t acc33_qs8 = vqmovn_qs16(acc33_qs16);
1240
1241 // Multiply by the weight of the matrix product (alpha)
1242 if(multiply_alpha)
1243 {
1244 acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
1245 acc10_qs8 = vqmul_qs8(acc10_qs8, alpha_qs8, fixed_point_position);
1246 acc20_qs8 = vqmul_qs8(acc20_qs8, alpha_qs8, fixed_point_position);
1247 acc30_qs8 = vqmul_qs8(acc30_qs8, alpha_qs8, fixed_point_position);
1248 acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
1249 acc11_qs8 = vqmul_qs8(acc11_qs8, alpha_qs8, fixed_point_position);
1250 acc21_qs8 = vqmul_qs8(acc21_qs8, alpha_qs8, fixed_point_position);
1251 acc31_qs8 = vqmul_qs8(acc31_qs8, alpha_qs8, fixed_point_position);
1252 acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
1253 acc12_qs8 = vqmul_qs8(acc12_qs8, alpha_qs8, fixed_point_position);
1254 acc22_qs8 = vqmul_qs8(acc22_qs8, alpha_qs8, fixed_point_position);
1255 acc32_qs8 = vqmul_qs8(acc32_qs8, alpha_qs8, fixed_point_position);
1256 acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
1257 acc13_qs8 = vqmul_qs8(acc13_qs8, alpha_qs8, fixed_point_position);
1258 acc23_qs8 = vqmul_qs8(acc23_qs8, alpha_qs8, fixed_point_position);
1259 acc33_qs8 = vqmul_qs8(acc33_qs8, alpha_qs8, fixed_point_position);
1260 }
1261
1262 const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());
1263
1264 // Store 32x4 output elements
1265 vst1_qs8(mtx_out0 + 0, acc00_qs8);
1266 vst1_qs8(mtx_out0 + 8, acc01_qs8);
1267 vst1_qs8(mtx_out0 + 16, acc02_qs8);
1268 vst1_qs8(mtx_out0 + 24, acc03_qs8);
1269 vst1_qs8(mtx_out0 + out_stride1 + 0, acc10_qs8);
1270 vst1_qs8(mtx_out0 + out_stride1 + 8, acc11_qs8);
1271 vst1_qs8(mtx_out0 + out_stride1 + 16, acc12_qs8);
1272 vst1_qs8(mtx_out0 + out_stride1 + 24, acc13_qs8);
1273 vst1_qs8(mtx_out0 + out_stride2 + 0, acc20_qs8);
1274 vst1_qs8(mtx_out0 + out_stride2 + 8, acc21_qs8);
1275 vst1_qs8(mtx_out0 + out_stride2 + 16, acc22_qs8);
1276 vst1_qs8(mtx_out0 + out_stride2 + 24, acc23_qs8);
1277 vst1_qs8(mtx_out0 + out_stride3 + 0, acc30_qs8);
1278 vst1_qs8(mtx_out0 + out_stride3 + 8, acc31_qs8);
1279 vst1_qs8(mtx_out0 + out_stride3 + 16, acc32_qs8);
1280 vst1_qs8(mtx_out0 + out_stride3 + 24, acc33_qs8);
1281 },
1282 ina, inb, out);
1283}
1284
1285template <bool multiply_alpha>
1286void matrix_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
1287{
1288 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
1289 const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
1290 const size_t out_stride2 = out_stride1 * 2;
1291 const size_t out_stride3 = out_stride1 * 3;
1292 const int num_elems_matrix_b_x = input1->info()->dimension(0);
1293 const int fixed_point_position = input0->info()->fixed_point_position();
1294 const qint16x4_t alpha_qs16 = vdup_n_qs16(sqcvt_qs16_f32(alpha, fixed_point_position));
1295 ARM_COMPUTE_UNUSED(alpha_qs16);
1296
1297 // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
1298 Window win_a(window);
1299 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
1300 win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
1301
1302 Window win_b;
1303 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A has more than 2
1304 // This scenario can happen when the matrix multiplication is used to perform a convolution operation
1305 if(input1->info()->num_dimensions() >= 3)
1306 {
1307 win_b = window;
1308 }
1309 // Set step_x and step_y for matrix B. Scale the X range by a factor of 8 as each row of the transposed input matrix B holds the data for 8 output columns
1310 win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
1311 win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
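// Each step along X in win_b advances by a full row of the reshaped matrix B (in_b_stride elements), i.e. to the block
// of values that produces the next 8 output columns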
1312
1313 Iterator ina(input0, win_a);
1314 Iterator inb(input1, win_b);
1315 Iterator out(output, window);
1316
1317 // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
1318 // The reshaping makes the implementation cache-friendly and avoids the data re-arrangements needed for computing 8x4 elements per iteration
1319 // All the values needed for computing a single 8x4 block are read from consecutive memory positions
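// Illustration of the assumed reshaped layouts for a single iteration (one K step):
//   interleaved A : [ a(0,k) a(1,k) a(2,k) a(3,k) | a(0,k+1) ... ]  -> loaded below as a0..a3
//   transposed  B : [ b(k,j) b(k,j+1) ... b(k,j+7) | b(k+1,j) ... ] -> loaded below as b00 and b01
// so each loop iteration contributes one K step to an 8x4 block of the output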
1320 execute_window_loop(window, [&](const Coordinates & id)
1321 {
1322 auto mtx_a0 = reinterpret_cast<const qint16_t *>(ina.ptr());
1323 auto mtx_b0 = reinterpret_cast<const qint16_t *>(inb.ptr());
1324 auto mtx_b1 = mtx_b0 + in_b_stride;
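// Note: mtx_b1 points to the following row of the reshaped matrix B; it is advanced in the loop below but not read in this QS16 path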
1325
1326 qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
1327 qint32x4_t acc10_qs32 = vdupq_n_qs32(0);
1328 qint32x4_t acc20_qs32 = vdupq_n_qs32(0);
1329 qint32x4_t acc30_qs32 = vdupq_n_qs32(0);
1330
1331 qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
1332 qint32x4_t acc11_qs32 = vdupq_n_qs32(0);
1333 qint32x4_t acc21_qs32 = vdupq_n_qs32(0);
1334 qint32x4_t acc31_qs32 = vdupq_n_qs32(0);
1335
1336 // Each iteration of this loop performs one accumulation step along the K dimension for the whole 8x4 output block
1337 for(int k = 0; k <= (num_elems_matrix_b_x - 8); k += 8)
1338 {
1339 const qint16x4_t a0 = vld1_dup_qs16(mtx_a0 + 0);
1340 const qint16x4_t a1 = vld1_dup_qs16(mtx_a0 + 1);
1341 const qint16x4_t a2 = vld1_dup_qs16(mtx_a0 + 2);
1342 const qint16x4_t a3 = vld1_dup_qs16(mtx_a0 + 3);
1343
1344 const qint16x4_t b00 = vld1_qs16(mtx_b0 + 0);
1345 const qint16x4_t b01 = vld1_qs16(mtx_b0 + 4);
1346
1347 acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
1348 acc10_qs32 = vqmlal_qs16(acc10_qs32, b00, a1, fixed_point_position);
1349 acc20_qs32 = vqmlal_qs16(acc20_qs32, b00, a2, fixed_point_position);
1350 acc30_qs32 = vqmlal_qs16(acc30_qs32, b00, a3, fixed_point_position);
1351 acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
1352 acc11_qs32 = vqmlal_qs16(acc11_qs32, b01, a1, fixed_point_position);
1353 acc21_qs32 = vqmlal_qs16(acc21_qs32, b01, a2, fixed_point_position);
1354 acc31_qs32 = vqmlal_qs16(acc31_qs32, b01, a3, fixed_point_position);
1355
1356 mtx_a0 += 4;
1357 mtx_b0 += 8;
1358 mtx_b1 += 8;
1359 }
1360
1361 // Convert back to qint16x4_t and saturate
1362 qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
1363 qint16x4_t acc10_qs16 = vqmovn_qs32(acc10_qs32);
1364 qint16x4_t acc20_qs16 = vqmovn_qs32(acc20_qs32);
1365 qint16x4_t acc30_qs16 = vqmovn_qs32(acc30_qs32);
1366
1367 qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
1368 qint16x4_t acc11_qs16 = vqmovn_qs32(acc11_qs32);
1369 qint16x4_t acc21_qs16 = vqmovn_qs32(acc21_qs32);
1370 qint16x4_t acc31_qs16 = vqmovn_qs32(acc31_qs32);
1371
1372 // Multiply by the weight of the matrix product (alpha)
1373 if(multiply_alpha)
1374 {
1375 acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
1376 acc10_qs16 = vqmul_qs16(acc10_qs16, alpha_qs16, fixed_point_position);
1377 acc20_qs16 = vqmul_qs16(acc20_qs16, alpha_qs16, fixed_point_position);
1378 acc30_qs16 = vqmul_qs16(acc30_qs16, alpha_qs16, fixed_point_position);
1379 acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
1380 acc11_qs16 = vqmul_qs16(acc11_qs16, alpha_qs16, fixed_point_position);
1381 acc21_qs16 = vqmul_qs16(acc21_qs16, alpha_qs16, fixed_point_position);
1382 acc31_qs16 = vqmul_qs16(acc31_qs16, alpha_qs16, fixed_point_position);
1383 }
1384
1385 const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());
1386
1387 // Store 8x4 output elements
1388 vst1_qs16(mtx_out0 + 0, acc00_qs16);
1389 vst1_qs16(mtx_out0 + 4, acc01_qs16);
1390 vst1_qs16(mtx_out0 + out_stride1 + 0, acc10_qs16);
1391 vst1_qs16(mtx_out0 + out_stride1 + 4, acc11_qs16);
1392 vst1_qs16(mtx_out0 + out_stride2 + 0, acc20_qs16);
1393 vst1_qs16(mtx_out0 + out_stride2 + 4, acc21_qs16);
1394 vst1_qs16(mtx_out0 + out_stride3 + 0, acc30_qs16);
1395 vst1_qs16(mtx_out0 + out_stride3 + 4, acc31_qs16);
1396 },
1397 ina, inb, out);
1398}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001399} // namespace
1400
1401NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
1402 : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
1403{
1404}
1405
1406void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha)
1407{
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001408 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001409 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
1410 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
1411
1412 if(output->info()->dimension(1) == 1)
1413 {
1414 ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
1415 }
1416
1417 _input0 = input0;
1418 _input1 = input1;
1419 _output = output;
1420 _alpha = alpha;
1421
1422 unsigned int num_elems_processed_per_iteration_x = 0;
1423 const unsigned int num_elems_processed_per_iteration_y = 4;
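// All the matrix-matrix kernels produce 4 output rows per iteration, matching the 4x4 interleave factor used for matrix A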
1424
1425 // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
1426 if((output->info()->dimension(1) == 1))
1427 {
1428 switch(input0->info()->data_type())
1429 {
1430 case DataType::F32:
1431 {
1432 num_elems_processed_per_iteration_x = 16;
1433 break;
1434 }
1435 case DataType::QS8:
1436 {
1437 num_elems_processed_per_iteration_x = 32;
1438 break;
1439 }
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001440 case DataType::QS16:
1441 {
1442 num_elems_processed_per_iteration_x = 16;
1443 break;
1444 }
Pablo Tello221f3812017-06-28 17:27:56 +01001445#ifdef ARM_COMPUTE_ENABLE_FP16
1446 case DataType::F16:
1447 {
1448 num_elems_processed_per_iteration_x = 32;
1449 break;
1450 }
1451#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001452 default:
1453 {
1454 ARM_COMPUTE_ERROR("Data type not supported");
1455 break;
1456 }
1457 }
1458
1459 // Configure kernel window
1460 Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
1461
1462 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration_x);
1463
1464 update_window_and_padding(win,
1465 AccessWindowHorizontal(input0->info(), 0, num_elems_processed_per_iteration_x),
1466 AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration_x),
1467 output_access);
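// update_window_and_padding adjusts the window and the tensors' padding so that the vectorized accesses of
// num_elems_processed_per_iteration_x elements stay within the allocated memory, including at the right-hand border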
1468
1469 Coordinates coord;
1470 coord.set_num_dimensions(output->info()->num_dimensions());
1471 output_access.set_valid_region(win, ValidRegion(coord, output->info()->tensor_shape()));
1472
1473 INEKernel::configure(win);
1474 }
1475 else
1476 {
1477 switch(input0->info()->data_type())
1478 {
1479 case DataType::F32:
1480 {
1481 num_elems_processed_per_iteration_x = 8;
1482 break;
1483 }
1484 case DataType::QS8:
1485 {
1486 num_elems_processed_per_iteration_x = 32;
1487 break;
1488 }
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001489 case DataType::QS16:
1490 {
1491 num_elems_processed_per_iteration_x = 8;
1492 break;
1493 }
Pablo Tello221f3812017-06-28 17:27:56 +01001494#ifdef ARM_COMPUTE_ENABLE_FP16
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001495 case DataType::F16:
1496 {
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001497 num_elems_processed_per_iteration_x = 8;
1498 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001499 }
Pablo Tello221f3812017-06-28 17:27:56 +01001500#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001501 default:
1502 {
1503 ARM_COMPUTE_ERROR("Data type not supported");
1504 break;
1505 }
1506 }
1507
1508 // Configure kernel window
1509 Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
1510
1511 AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
1512
1513 update_window_and_padding(win,
1514 AccessWindowRectangle(input0->info(), 0, 0, 4, 1, 1.f, 0.25f),
1515 AccessWindowTranspose(input1->info(), 0, 0, 4, 1, 0.f, 0.25f),
1516 output_access);
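// The 0.25f scale factors reflect the reshaped inputs: presumably the interleaved matrix A and the transposed matrix B are
// traversed at a quarter of the output's Y rate, since 4 output rows are computed per iteration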
1517
1518 output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));
1519
1520 INEKernel::configure(win);
1521 }
1522}
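// The reshaping referenced in the kernels above is normally performed by the caller before this kernel runs. A minimal
// usage sketch, assuming the NEGEMMInterleave4x4Kernel and NEGEMMTranspose1xWKernel reshape kernels from this library
// and caller-provided tensors a, b, a_interleaved, b_transposed, dst and a float alpha:
//
//   NEGEMMInterleave4x4Kernel  interleave_kernel;
//   NEGEMMTranspose1xWKernel   transpose_kernel;
//   NEGEMMMatrixMultiplyKernel mm_kernel;
//
//   interleave_kernel.configure(&a, &a_interleaved);                 // A -> interleaved 4x4 layout
//   transpose_kernel.configure(&b, &b_transposed);                   // B -> transposed 1xW layout
//   mm_kernel.configure(&a_interleaved, &b_transposed, &dst, alpha);
//
// For the vector-matrix path (output height of 1) the inputs are expected to be used as-is, without reshaping.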
1523
1524void NEGEMMMatrixMultiplyKernel::run(const Window &window)
1525{
1526 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1527 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1528
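// Treat alpha as 1.0f within a small tolerance so that the extra per-element alpha multiplication can be skipped in the common C = A * B case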
1529 const bool multiply_alpha = std::abs(1.0f - _alpha) > 0.00001f;
1530
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001531 // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001532 if((_output->info()->dimension(1) == 1))
1533 {
1534 switch(_input0->info()->data_type())
1535 {
1536 case DataType::F32:
1537 {
1538 multiply_alpha ? vector_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
1539 vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
1540 break;
1541 }
1542 case DataType::QS8:
1543 {
1544 multiply_alpha ? vector_matrix_multiply_qs8<true>(_input0, _input1, _output, window, _alpha) :
1545 vector_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
1546 break;
1547 }
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001548 case DataType::QS16:
1549 {
1550 multiply_alpha ? vector_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
1551 vector_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
1552 break;
1553 }
Pablo Tello221f3812017-06-28 17:27:56 +01001554#ifdef ARM_COMPUTE_ENABLE_FP16
1555 case DataType::F16:
1556 {
1557 multiply_alpha ? vector_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
1558 vector_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
1559 break;
1560 }
1561#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001562 default:
1563 {
1564 ARM_COMPUTE_ERROR("Data type not supported");
1565 break;
1566 }
1567 }
1568 }
1569 else
1570 {
1571 switch(_input0->info()->data_type())
1572 {
1573 case DataType::F32:
1574 {
1575 multiply_alpha ? matrix_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
1576 matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
1577 break;
1578 }
1579 case DataType::QS8:
1580 {
1581 multiply_alpha ? matrix_matrix_multiply_qs8<true>(_input0, _input1, _output, window, _alpha) :
1582 matrix_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
1583 break;
1584 }
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001585 case DataType::QS16:
1586 {
1587 multiply_alpha ? matrix_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
1588 matrix_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
1589 break;
1590 }
Pablo Tello221f3812017-06-28 17:27:56 +01001591#ifdef ARM_COMPUTE_ENABLE_FP16
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001592 case DataType::F16:
1593 {
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001594 multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
1595 matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
1596 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001597 }
Pablo Tello221f3812017-06-28 17:27:56 +01001598#endif /* ARM_COMPUTE_ENABLE_FP16 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001599 default:
1600 {
1601 ARM_COMPUTE_ERROR("Data type not supported");
1602 break;
1603 }
1604 }
1605 }
1606}