/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h"
#include "src/core/utils/helpers/float_ops.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace cpu
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size());
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, "(window_end_x - window_start_x) must be a multiple of window_step_x");
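    // Threads interleave 32-wide column blocks of dst: thread t starts at column 32 * t and strides by
    // 32 * num_threads; window_end_x is rounded up so each thread's range spans a whole number of steps.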

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A has more than 2.
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation.
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // window_end_x is rounded up above and may exceed width_matrix_b, so the loop bound alone does not
        // prevent out-of-bound writes to dst; the explicit width check inside the loop guards against that.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x8_t acc0 = vdupq_n_f16(0.f);
            float16x8_t acc1 = vdupq_n_f16(0.f);
            float16x8_t acc2 = vdupq_n_f16(0.f);
            float16x8_t acc3 = vdupq_n_f16(0.f);
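            // acc0..acc3 together hold the 32 consecutive f16 dst elements [x, x + 32) produced by this iteration.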

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

                matrix_b += 2 * in_b_stride;

                b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

                vec_a += 4;
                matrix_b += 2 * in_b_stride;
            }

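            // Tail: fewer than 4 elements of vec_a remain; accumulate them one at a time.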
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t   a0  = *vec_a;
                const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
                acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
                acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
                acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f16(acc0, alpha_f16);
                acc1 = vmulq_f16(acc1, alpha_f16);
                acc2 = vmulq_f16(acc2, alpha_f16);
                acc3 = vmulq_f16(acc3, alpha_f16);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            vst1q_f16(vec_out + 0, acc0);
            vst1q_f16(vec_out + 8, acc1);
            vst1q_f16(vec_out + 16, acc2);
            vst1q_f16(vec_out + 24, acc3);
        }

        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x4_t vacc = vdup_n_f16(0.f);

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

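                // Gather the x-th element from four consecutive rows of matrix B.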
                const float16x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

                vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));

                matrix_b += 4 * in_b_stride;
            }

            float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);

            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t a0  = *vec_a;
                const float16_t b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= static_cast<float16_t>(alpha);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            *(vec_out) = acc;
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A has more than 2.
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation.
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // window_end_x is rounded up above and may exceed width_matrix_b, so the loop bound alone does not
        // prevent out-of-bound writes to dst; the explicit width check inside the loop guards against that.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t acc0 = vdupq_n_f32(0.f);
            float32x4_t acc1 = vdupq_n_f32(0.f);
            float32x4_t acc2 = vdupq_n_f32(0.f);
            float32x4_t acc3 = vdupq_n_f32(0.f);
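            // acc0..acc3 together hold the 16 consecutive f32 dst elements [x, x + 16) produced by this iteration.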

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */
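            // On 32-bit Arm, the PLD hints above prefetch the A vector and the first rows of B into the cache
            // ahead of the streaming loads in the loop below.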

            auto vec_a_end_addr = vec_a + num_elems_vec_a;
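            // Main loop: each iteration consumes 4 elements of vec_a in two unrolled steps of 2, multiplying
            // them against columns [x, x + 16) of four consecutive rows of matrix B.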
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                float32x2_t a0l = vld1_f32(vec_a);

                float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;

                a0l = vld1_f32(vec_a);

                b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;
            }

            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                acc0 = vmlaq_n_f32(acc0, b00, a0);
                acc1 = vmlaq_n_f32(acc1, b01, a0);
                acc2 = vmlaq_n_f32(acc2, b02, a0);
                acc3 = vmlaq_n_f32(acc3, b03, a0);

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f32(acc0, alpha_f32);
                acc1 = vmulq_f32(acc1, alpha_f32);
                acc2 = vmulq_f32(acc2, alpha_f32);
                acc3 = vmulq_f32(acc3, alpha_f32);
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            vst1q_f32(vec_out + 0, acc0);
            vst1q_f32(vec_out + 4, acc1);
            vst1q_f32(vec_out + 8, acc2);
            vst1q_f32(vec_out + 12, acc3);
        }

        // Left-over loop
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t vacc = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float32x4_t a0l = vld1q_f32(vec_a);

                const float32x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

                vacc = vmlaq_f32(vacc, b_col, a0l);

                matrix_b += 4 * in_b_stride;
            }

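            // Horizontally reduce the four partial sums into a single scalar for dst column x.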
            float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);

            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= alpha;
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            *vec_out = acc;
        }
    },
    ina, inb, out);
}

void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    ARM_COMPUTE_UNUSED(info);
    const int    out_width            = static_cast<int>(dst->info()->dimension(0));
    const int    out_height           = static_cast<int>(dst->info()->dimension(1));
    const size_t in_b_stride          = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
    const size_t out_stride1          = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
    const size_t out_stride2          = out_stride1 * 2;
    const size_t out_stride3          = out_stride1 * 3;
    const int    num_elems_matrix_b_x = rhs->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the dst matrix.
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A has more than 2.
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation.
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 4, as the transposed input matrix B has 4 times fewer columns than the dst matrix.
    // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4.
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, window);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    // The implementation assumes that matrix A and matrix B have been reshaped with CpuGemmInterleave4x4 and CpuGemmTranspose1xW respectively.
    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration.
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions.
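    // Each window step below therefore produces an 8-column x 4-row block of dst: the left 4x4 half is fed
    // from mtx_b0 and the right 4x4 half from mtx_b1 (one in_b_stride further along the reshaped B).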
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        float32x4_t acc00 = vdupq_n_f32(0.f);
        float32x4_t acc10 = vdupq_n_f32(0.f);
        float32x4_t acc20 = vdupq_n_f32(0.f);
        float32x4_t acc30 = vdupq_n_f32(0.f);

        float32x4_t acc01 = vdupq_n_f32(0.f);
        float32x4_t acc11 = vdupq_n_f32(0.f);
        float32x4_t acc21 = vdupq_n_f32(0.f);
        float32x4_t acc31 = vdupq_n_f32(0.f);
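        // accN0 accumulate row N of the left 4x4 block (dst columns 0-3); accN1 accumulate row N of the right 4x4 block (columns 4-7).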

#if __arm__
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

        auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
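        // Main loop: unrolled 4x. Each unrolled step consumes 8 interleaved values of A and 8 values from each
        // of mtx_b0/mtx_b1, so one iteration advances mtx_b0 by 32 elements.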
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);

            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);
            float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
            float32x4_t b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
            float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
            float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
            float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);

            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }

        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            float32x4_t a0  = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1  = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2  = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3  = vld1q_dup_f32(mtx_a0 + 3);
            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);

#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            mtx_a0 += 4;
            mtx_b0 += 4;
            mtx_b1 += 4;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            acc00 = vmulq_f32(acc00, alpha_f32);
            acc10 = vmulq_f32(acc10, alpha_f32);
            acc20 = vmulq_f32(acc20, alpha_f32);
            acc30 = vmulq_f32(acc30, alpha_f32);
            acc01 = vmulq_f32(acc01, alpha_f32);
            acc11 = vmulq_f32(acc11, alpha_f32);
            acc21 = vmulq_f32(acc21, alpha_f32);
            acc31 = vmulq_f32(acc31, alpha_f32);
        }

        const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
        const auto mtx_out1 = mtx_out0 + 4;

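        // Store the 8x4 output block, guarding the right edge (partial columns) and the bottom edge (partial rows) of dst.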
        if(id.x() < (out_width - 8))
        {
            vst1q_f32(mtx_out0, acc00);
            vst1q_f32(mtx_out1, acc01);
            if(id.y() + 1 < out_height)
            {
                vst1q_f32(mtx_out0 + out_stride1, acc10);
                vst1q_f32(mtx_out1 + out_stride1, acc11);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f32(mtx_out0 + out_stride2, acc20);
                    vst1q_f32(mtx_out1 + out_stride2, acc21);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f32(mtx_out0 + out_stride3, acc30);
                        vst1q_f32(mtx_out1 + out_stride3, acc31);
                    }
                }
            }
        }
        else if(id.x() < (out_width - 4))
        {
            vst1q_f32(mtx_out0, acc00);
            if(id.y() + 1 < out_height)
            {
                vst1q_f32(mtx_out0 + out_stride1, acc10);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f32(mtx_out0 + out_stride2, acc20);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f32(mtx_out0 + out_stride3, acc30);
                    }
                }
            }
            // Left-over columns
            const int columns_left = out_width - id.x() - 4;
            for(auto x = 0; x < columns_left; ++x)
            {
                *(mtx_out1 + x) = acc01[x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out1 + x + out_stride1) = acc11[x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out1 + x + out_stride2) = acc21[x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out1 + x + out_stride3) = acc31[x];
                        }
                    }
                }
            }
        }
        else
        {
            // Left-over columns
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out0 + x) = acc00[x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out0 + x + out_stride1) = acc10[x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out0 + x + out_stride2) = acc20[x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out0 + x + out_stride3) = acc30[x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    ARM_COMPUTE_UNUSED(info);
    const int    out_width            = static_cast<int>(dst->info()->dimension(0));
    const int    out_height           = static_cast<int>(dst->info()->dimension(1));
    const size_t in_b_stride          = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
    const size_t out_stride           = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
    const int    num_elems_matrix_b_x = rhs->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the dst matrix.
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A has more than 2.
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation.
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 8, as the transposed input matrix B has 8 times fewer columns than the dst matrix.
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, window);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        /*
        This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
             |a00 a01 a02 a03 | a04 a05 a06 a07|
             |a10 a11 a12 a13 | a14 a15 a16 a17|
             |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
             |a30 a31 a32 a33 | a34 a35 a36 a37|   | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
             |a40 a41 a42 a43 | a44 a45 a46 a47|
             |a50 a51 a52 a53 | a54 a55 a56 a57|
             |a60 a61 a62 a63 | a64 a65 a66 a67|
             |a70 a71 a72 a73 | a74 a75 a76 a77|

             After this operation, the dst matrix will have the following shape: [ height * 4, width / 4 ]

        B Matrix has been transposed as shown below

           |b00 b01 b02 b03 b04 b05 b06 b07|
           |b10 b11 b12 b13 b14 b15 b16 b17|
           |b20 b21 b22 b23 b24 b25 b26 b27|
           |b30 b31 b32 b33 b34 b35 b36 b37|
           ------------------->

           |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|

            c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
            c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31

        The size of the dst tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
        */
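        // Each main-loop iteration below consumes 16 interleaved values of A and 32 transposed values of B,
        // updating the 4x8 block of dst held in c.val[0..3] (one dst row per c.val entry).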
        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }

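        // Tail: fewer than 32 B values remain; consume 4 interleaved A values and 8 B values per step.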
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }

        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }

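        // Store the 4x8 output block, guarding the right edge (partial columns) and the bottom edge (partial rows) of dst.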
        if(id.x() < (out_width - 8))
        {
            vst1q_f16(mtx_out, c.val[0]);
            if(id.y() + 1 < out_height)
            {
                vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
                    }
                }
            }
        }
        else
        {
            // Left-over columns
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out + x) = c.val[0][x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out + x + 1 * out_stride) = c.val[1][x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out + x + 2 * out_stride) = c.val[2][x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out + x + 3 * out_stride) = c.val[3][x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

} // namespace cpu

} // namespace arm_compute