blob: 6f74e3fc06821b11bff81b8c6a487750faa6f8a5 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Georgios Pinitasddb93bb2020-10-02 16:38:59 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
25
Anthony Barbiereaefd002018-07-20 17:49:35 +010026#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/ITensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/TensorInfo.h"
31#include "arm_compute/core/Types.h"
32#include "arm_compute/core/Utils.h"
33#include "arm_compute/core/Validate.h"
34#include "arm_compute/core/Window.h"
Gian Marco Iodice82d9dd12019-06-10 16:45:40 +010035#include "arm_compute/core/utils/helpers/float_ops.h"
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +000036#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010037#include "src/core/NEON/NEFixedPoint.h"
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +000038
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039#include <arm_neon.h>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010040
41namespace arm_compute
42{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010043namespace
44{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** FP16 vector-by-matrix multiply: output = alpha * (input0 x input1).
 *
 * input0 is a row vector (1 x K), input1 a matrix (K x N) and output a row
 * vector (1 x N). Work is interleaved across threads along the output X
 * dimension in chunks of 32 fp16 elements (four 8-lane accumulators); each
 * thread starts at chunk `thread_id` and strides by `num_threads` chunks.
 *
 * NOTE(review): the main loop stores full 32-element chunks without an exact
 * tail check (the early-out is `x > width_matrix_b`), so it appears to rely on
 * the output tensor being padded up to the chunk boundary — confirm against
 * the kernel's window/padding configuration.
 *
 * @param input0 Vector A (LHS); its X dimension is the accumulation depth K.
 * @param input1 Matrix B (RHS); its row stride is taken from its Y byte stride.
 * @param output Destination vector.
 * @param window Execution window (slices the Z dimension for 3D inputs).
 * @param info   Thread info used to interleave output chunks across threads.
 * @param alpha  Scalar weight applied to the product when it is not 1.
 */
void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    // Row stride of B expressed in elements (strides_in_bytes()[1] is the Y stride in bytes).
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / input1->info()->element_size());
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");

    // Collapse X/Y of the output window: X is driven manually below so every
    // thread sees the same Z slices and partitions X itself via thread_id.
    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    // Vector A is fully re-read for every output chunk: pin its X/Y to 0.
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    // Skip the alpha scaling entirely when alpha == 1.
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // The loop bound is (window_end_x - window_step_x) rather than window_end_x:
        // window_end_x is rounded up above, so iterating the final (partial) chunk
        // here could write out of bounds. The leftover loop below handles it.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            // This thread's chunks past the end of B: nothing left to do.
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x8_t acc0 = vdupq_n_f16(0.f);
            float16x8_t acc1 = vdupq_n_f16(0.f);
            float16x8_t acc2 = vdupq_n_f16(0.f);
            float16x8_t acc3 = vdupq_n_f16(0.f);

            auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            // Main K loop: consume 4 elements of A (and 4 rows of B) per iteration,
            // split as 2 + 2 rows with a pointer bump in between.
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                // Rows k and k+1 of B, 32 columns each.
                float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                // acc += a[k] * B[k][:] and acc += a[k+1] * B[k+1][:] (lanes 0 and 1 of a0l).
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

                matrix_b += 2 * in_b_stride;

                // Rows k+2 and k+3 of B.
                b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                // Lanes 2 and 3 of a0l pair with rows k+2 and k+3.
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

                vec_a += 4;
                matrix_b += 2 * in_b_stride;
            }

            // K tail: one element of A / one row of B at a time.
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t a0 = *vec_a;
                const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
                acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
                acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
                acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f16(acc0, alpha_f16);
                acc1 = vmulq_f16(acc1, alpha_f16);
                acc2 = vmulq_f16(acc2, alpha_f16);
                acc3 = vmulq_f16(acc3, alpha_f16);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            // Store the 32-wide chunk.
            vst1q_f16(vec_out + 0, acc0);
            vst1q_f16(vec_out + 8, acc1);
            vst1q_f16(vec_out + 16, acc2);
            vst1q_f16(vec_out + 24, acc3);
        }

        // Leftover loop: remaining columns handled one at a time (scalar output).
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x4_t vacc = vdup_n_f16(0.f);

            auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            // 4 K-elements per iteration: gather one column of B across 4 rows.
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                const float16x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

                vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));

                matrix_b += 4 * in_b_stride;
            }

            // Horizontal reduction of the 4 partial sums.
            float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);

            // K tail, fully scalar.
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t a0 = *vec_a;
                const float16_t b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= static_cast<float16_t>(alpha);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            *(vec_out) = acc;
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello221f3812017-06-28 17:27:56 +0100240
/** FP32 vector-by-matrix multiply: output = alpha * (input0 x input1).
 *
 * input0 is a row vector (1 x K), input1 a matrix (K x N) and output a row
 * vector (1 x N). Work is interleaved across threads along the output X
 * dimension in chunks of 16 fp32 elements (four 4-lane accumulators); each
 * thread starts at chunk `thread_id` and strides by `num_threads` chunks.
 * On 32-bit Arm, PLD prefetch hints are issued ahead of the A and B streams.
 *
 * NOTE(review): as in the fp16 variant, the main loop's early-out is
 * `x > width_matrix_b` and full 16-element chunks are stored unconditionally,
 * which appears to rely on output padding up to the chunk boundary — confirm
 * against the kernel's window/padding configuration.
 *
 * @param input0 Vector A (LHS); its X dimension is the accumulation depth K.
 * @param input1 Matrix B (RHS); its row stride is taken from its Y byte stride.
 * @param output Destination vector.
 * @param window Execution window (slices the Z dimension for 3D inputs).
 * @param info   Thread info used to interleave output chunks across threads.
 * @param alpha  Scalar weight applied to the product when it is not 1.
 */
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    // Row stride of B expressed in elements (strides_in_bytes()[1] is the Y stride in bytes).
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    // Collapse X/Y of the output window: X is driven manually below so every
    // thread sees the same Z slices and partitions X itself via thread_id.
    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    // Vector A is fully re-read for every output chunk: pin its X/Y to 0.
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    // Skip the alpha scaling entirely when alpha == 1.
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // The loop bound is (window_end_x - window_step_x) rather than window_end_x:
        // window_end_x is rounded up above, so iterating the final (partial) chunk
        // here could write out of bounds. The leftover loop below handles it.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            // This thread's chunks past the end of B: nothing left to do.
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t acc0 = vdupq_n_f32(0.f);
            float32x4_t acc1 = vdupq_n_f32(0.f);
            float32x4_t acc2 = vdupq_n_f32(0.f);
            float32x4_t acc3 = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;

#if __arm__
            // Prefetch the start of the A vector and the first two rows of B (AArch32 only).
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            // Main K loop: consume 4 elements of A (and 4 rows of B) per iteration,
            // split as 2 + 2 rows with a pointer bump in between.
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                float32x2_t a0l = vld1_f32(vec_a);

                // Rows k and k+1 of B, 16 columns each.
                float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
                // Prefetch ahead on A and the next few rows of B.
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

                // acc += a[k] * B[k][:] (lane 0) and acc += a[k+1] * B[k+1][:] (lane 1).
                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;

                // Second pair of A elements (k+2, k+3) and rows of B.
                a0l = vld1_f32(vec_a);

                b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;
            }

            // K tail: one element of A / one row of B at a time.
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                acc0 = vmlaq_n_f32(acc0, b00, a0);
                acc1 = vmlaq_n_f32(acc1, b01, a0);
                acc2 = vmlaq_n_f32(acc2, b02, a0);
                acc3 = vmlaq_n_f32(acc3, b03, a0);

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f32(acc0, alpha_f32);
                acc1 = vmulq_f32(acc1, alpha_f32);
                acc2 = vmulq_f32(acc2, alpha_f32);
                acc3 = vmulq_f32(acc3, alpha_f32);
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            // Store the 16-wide chunk.
            vst1q_f32(vec_out + 0, acc0);
            vst1q_f32(vec_out + 4, acc1);
            vst1q_f32(vec_out + 8, acc2);
            vst1q_f32(vec_out + 12, acc3);
        }

        // Left-over loop
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t vacc = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            // 4 K-elements per iteration: gather one column of B across 4 rows.
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float32x4_t a0l = vld1q_f32(vec_a);

                const float32x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

                vacc = vmlaq_f32(vacc, b_col, a0l);

                matrix_b += 4 * in_b_stride;
            }

            // Horizontal reduction of the 4 partial sums.
            float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);

            // K tail, fully scalar.
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= alpha;
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            *vec_out = acc;
        }
    },
    ina, inb, out);
}
472
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100473void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
474{
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100475 const int out_width = static_cast<int>(output->info()->dimension(0));
476 const int out_height = static_cast<int>(output->info()->dimension(1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100477 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
478 const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
479 const size_t out_stride2 = out_stride1 * 2;
480 const size_t out_stride3 = out_stride1 * 3;
481 const int num_elems_matrix_b_x = input1->info()->dimension(0);
482
483 // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
484 Window win_a(window);
485 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
486 win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
487
488 Window win_b;
489 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
490 // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
491 if(input1->info()->num_dimensions() >= 3)
492 {
493 win_b = window;
494 }
495 // Set step_x and step_y for matrix B. Scale by a factor of 4 the X range as the input transposed matrix A has 4 times less the cols of the output matrix
496 // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
497 win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
498 win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
499
500 Iterator ina(input0, win_a);
501 Iterator inb(input1, win_b);
502 Iterator out(output, window);
503
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100504 const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
505
506 const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
507
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100508 // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
509 // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
510 // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100511 execute_window_loop(window, [&](const Coordinates & id)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100512 {
513 auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
514 auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
515 auto mtx_b1 = mtx_b0 + in_b_stride;
516
517 float32x4_t acc00 = vdupq_n_f32(0.f);
518 float32x4_t acc10 = vdupq_n_f32(0.f);
519 float32x4_t acc20 = vdupq_n_f32(0.f);
520 float32x4_t acc30 = vdupq_n_f32(0.f);
521
522 float32x4_t acc01 = vdupq_n_f32(0.f);
523 float32x4_t acc11 = vdupq_n_f32(0.f);
524 float32x4_t acc21 = vdupq_n_f32(0.f);
525 float32x4_t acc31 = vdupq_n_f32(0.f);
526
527#if __arm__
528 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
529 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
530 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
Anthony Barbierac69aa12017-07-03 17:39:37 +0100531#endif /* __arm__ */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100532
533 auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
534 for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
535 {
536 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
537 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
538 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
539 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
540
541 float32x4_t b00 = vld1q_f32(mtx_b0);
542 float32x4_t b10 = vld1q_f32(mtx_b1);
543 float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
544 float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
545
546#if __arm__
547 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
548 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
549 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
Anthony Barbierac69aa12017-07-03 17:39:37 +0100550#endif /* __arm__ */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100551
552 // 4x4 block 0
553 acc00 = vmlaq_f32(acc00, b00, a0);
554 acc10 = vmlaq_f32(acc10, b00, a1);
555 acc20 = vmlaq_f32(acc20, b00, a2);
556 acc30 = vmlaq_f32(acc30, b00, a3);
557
558 float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
559 float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
560 float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
561 float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
562
563 // 4x4 block 1
564 acc01 = vmlaq_f32(acc01, b10, a0);
565 acc11 = vmlaq_f32(acc11, b10, a1);
566 acc21 = vmlaq_f32(acc21, b10, a2);
567 acc31 = vmlaq_f32(acc31, b10, a3);
568
569 // 4x4 block 0
570 acc00 = vmlaq_f32(acc00, b01, a4);
571 acc10 = vmlaq_f32(acc10, b01, a5);
572 acc20 = vmlaq_f32(acc20, b01, a6);
573 acc30 = vmlaq_f32(acc30, b01, a7);
574
575 // 4x4 block 1
576 acc01 = vmlaq_f32(acc01, b11, a4);
577 acc11 = vmlaq_f32(acc11, b11, a5);
578 acc21 = vmlaq_f32(acc21, b11, a6);
579 acc31 = vmlaq_f32(acc31, b11, a7);
580
581 mtx_a0 += 8;
582 mtx_b0 += 8;
583 mtx_b1 += 8;
584
585 a0 = vld1q_dup_f32(mtx_a0 + 0);
586 a1 = vld1q_dup_f32(mtx_a0 + 1);
587 a2 = vld1q_dup_f32(mtx_a0 + 2);
588 a3 = vld1q_dup_f32(mtx_a0 + 3);
589
590 b00 = vld1q_f32(mtx_b0);
591 b10 = vld1q_f32(mtx_b1);
592 b01 = vld1q_f32(mtx_b0 + 4);
593 b11 = vld1q_f32(mtx_b1 + 4);
594
595 // 4x4 block 0
596 acc00 = vmlaq_f32(acc00, b00, a0);
597 acc10 = vmlaq_f32(acc10, b00, a1);
598 acc20 = vmlaq_f32(acc20, b00, a2);
599 acc30 = vmlaq_f32(acc30, b00, a3);
600
601 a4 = vld1q_dup_f32(mtx_a0 + 4);
602 a5 = vld1q_dup_f32(mtx_a0 + 5);
603 a6 = vld1q_dup_f32(mtx_a0 + 6);
604 a7 = vld1q_dup_f32(mtx_a0 + 7);
605
606 // 4x4 block 1
607 acc01 = vmlaq_f32(acc01, b10, a0);
608 acc11 = vmlaq_f32(acc11, b10, a1);
609 acc21 = vmlaq_f32(acc21, b10, a2);
610 acc31 = vmlaq_f32(acc31, b10, a3);
611
612 // 4x4 block 0
613 acc00 = vmlaq_f32(acc00, b01, a4);
614 acc10 = vmlaq_f32(acc10, b01, a5);
615 acc20 = vmlaq_f32(acc20, b01, a6);
616 acc30 = vmlaq_f32(acc30, b01, a7);
617
618 // 4x4 block 1
619 acc01 = vmlaq_f32(acc01, b11, a4);
620 acc11 = vmlaq_f32(acc11, b11, a5);
621 acc21 = vmlaq_f32(acc21, b11, a6);
622 acc31 = vmlaq_f32(acc31, b11, a7);
623
624 mtx_a0 += 8;
625 mtx_b0 += 8;
626 mtx_b1 += 8;
627
628 a0 = vld1q_dup_f32(mtx_a0 + 0);
629 a1 = vld1q_dup_f32(mtx_a0 + 1);
630 a2 = vld1q_dup_f32(mtx_a0 + 2);
631 a3 = vld1q_dup_f32(mtx_a0 + 3);
632 b00 = vld1q_f32(mtx_b0);
633 b10 = vld1q_f32(mtx_b1);
634 b01 = vld1q_f32(mtx_b0 + 4);
635 b11 = vld1q_f32(mtx_b1 + 4);
636
637#if __arm__
638 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
639 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
640 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
Anthony Barbierac69aa12017-07-03 17:39:37 +0100641#endif /* __arm__ */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100642
643 // 4x4 block 0
644 acc00 = vmlaq_f32(acc00, b00, a0);
645 acc10 = vmlaq_f32(acc10, b00, a1);
646 acc20 = vmlaq_f32(acc20, b00, a2);
647 acc30 = vmlaq_f32(acc30, b00, a3);
648
649 a4 = vld1q_dup_f32(mtx_a0 + 4);
650 a5 = vld1q_dup_f32(mtx_a0 + 5);
651 a6 = vld1q_dup_f32(mtx_a0 + 6);
652 a7 = vld1q_dup_f32(mtx_a0 + 7);
653
654 // 4x4 block 1
655 acc01 = vmlaq_f32(acc01, b10, a0);
656 acc11 = vmlaq_f32(acc11, b10, a1);
657 acc21 = vmlaq_f32(acc21, b10, a2);
658 acc31 = vmlaq_f32(acc31, b10, a3);
659
660 // 4x4 block 0
661 acc00 = vmlaq_f32(acc00, b01, a4);
662 acc10 = vmlaq_f32(acc10, b01, a5);
663 acc20 = vmlaq_f32(acc20, b01, a6);
664 acc30 = vmlaq_f32(acc30, b01, a7);
665
666 // 4x4 block 1
667 acc01 = vmlaq_f32(acc01, b11, a4);
668 acc11 = vmlaq_f32(acc11, b11, a5);
669 acc21 = vmlaq_f32(acc21, b11, a6);
670 acc31 = vmlaq_f32(acc31, b11, a7);
671
672 mtx_a0 += 8;
673 mtx_b0 += 8;
674 mtx_b1 += 8;
675
676 a0 = vld1q_dup_f32(mtx_a0 + 0);
677 a1 = vld1q_dup_f32(mtx_a0 + 1);
678 a2 = vld1q_dup_f32(mtx_a0 + 2);
679 a3 = vld1q_dup_f32(mtx_a0 + 3);
680 b00 = vld1q_f32(mtx_b0);
681 b10 = vld1q_f32(mtx_b1);
682 b01 = vld1q_f32(mtx_b0 + 4);
683 b11 = vld1q_f32(mtx_b1 + 4);
684
685 // 4x4 block 0
686 acc00 = vmlaq_f32(acc00, b00, a0);
687 acc10 = vmlaq_f32(acc10, b00, a1);
688 acc20 = vmlaq_f32(acc20, b00, a2);
689 acc30 = vmlaq_f32(acc30, b00, a3);
690
691 a4 = vld1q_dup_f32(mtx_a0 + 4);
692 a5 = vld1q_dup_f32(mtx_a0 + 5);
693 a6 = vld1q_dup_f32(mtx_a0 + 6);
694 a7 = vld1q_dup_f32(mtx_a0 + 7);
695
696 // 4x4 block 1
697 acc01 = vmlaq_f32(acc01, b10, a0);
698 acc11 = vmlaq_f32(acc11, b10, a1);
699 acc21 = vmlaq_f32(acc21, b10, a2);
700 acc31 = vmlaq_f32(acc31, b10, a3);
701
702 // 4x4 block 0
703 acc00 = vmlaq_f32(acc00, b01, a4);
704 acc10 = vmlaq_f32(acc10, b01, a5);
705 acc20 = vmlaq_f32(acc20, b01, a6);
706 acc30 = vmlaq_f32(acc30, b01, a7);
707
708 // 4x4 block 1
709 acc01 = vmlaq_f32(acc01, b11, a4);
710 acc11 = vmlaq_f32(acc11, b11, a5);
711 acc21 = vmlaq_f32(acc21, b11, a6);
712 acc31 = vmlaq_f32(acc31, b11, a7);
713
714 mtx_a0 += 8;
715 mtx_b0 += 8;
716 mtx_b1 += 8;
717 }
718
719 for(; mtx_b0 < mtx_b0_end_addr;)
720 {
721 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
722 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
723 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
724 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
725 float32x4_t b00 = vld1q_f32(mtx_b0);
726 float32x4_t b10 = vld1q_f32(mtx_b1);
727
728#if __arm__
729 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
730 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
731 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
Anthony Barbierac69aa12017-07-03 17:39:37 +0100732#endif /* __arm__ */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100733 // 4x4 block 0
734 acc00 = vmlaq_f32(acc00, b00, a0);
735 acc10 = vmlaq_f32(acc10, b00, a1);
736 acc20 = vmlaq_f32(acc20, b00, a2);
737 acc30 = vmlaq_f32(acc30, b00, a3);
738
739 // 4x4 block 1
740 acc01 = vmlaq_f32(acc01, b10, a0);
741 acc11 = vmlaq_f32(acc11, b10, a1);
742 acc21 = vmlaq_f32(acc21, b10, a2);
743 acc31 = vmlaq_f32(acc31, b10, a3);
744
745 mtx_a0 += 4;
746 mtx_b0 += 4;
747 mtx_b1 += 4;
748 }
749
750 // Multiply by the weight of matrix product (alpha)
751 if(multiply_alpha)
752 {
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100753 acc00 = vmulq_f32(acc00, alpha_f32);
754 acc10 = vmulq_f32(acc10, alpha_f32);
755 acc20 = vmulq_f32(acc20, alpha_f32);
756 acc30 = vmulq_f32(acc30, alpha_f32);
757 acc01 = vmulq_f32(acc01, alpha_f32);
758 acc11 = vmulq_f32(acc11, alpha_f32);
759 acc21 = vmulq_f32(acc21, alpha_f32);
760 acc31 = vmulq_f32(acc31, alpha_f32);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100761 }
762
763 const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
764 const auto mtx_out1 = mtx_out0 + 4;
765
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100766 if(id.x() < (out_width - 8))
767 {
768 vst1q_f32(mtx_out0, acc00);
769 vst1q_f32(mtx_out1, acc01);
770 if(id.y() + 1 < out_height)
771 {
772 vst1q_f32(mtx_out0 + out_stride1, acc10);
773 vst1q_f32(mtx_out1 + out_stride1, acc11);
774 if(id.y() + 2 < out_height)
775 {
776 vst1q_f32(mtx_out0 + out_stride2, acc20);
777 vst1q_f32(mtx_out1 + out_stride2, acc21);
778 if(id.y() + 3 < out_height)
779 {
780 vst1q_f32(mtx_out0 + out_stride3, acc30);
781 vst1q_f32(mtx_out1 + out_stride3, acc31);
782 }
783 }
784 }
785 }
786 else if(id.x() < (out_width - 4))
787 {
788 vst1q_f32(mtx_out0, acc00);
789 if(id.y() + 1 < out_height)
790 {
791 vst1q_f32(mtx_out0 + out_stride1, acc10);
792 if(id.y() + 2 < out_height)
793 {
794 vst1q_f32(mtx_out0 + out_stride2, acc20);
795 if(id.y() + 3 < out_height)
796 {
797 vst1q_f32(mtx_out0 + out_stride3, acc30);
798 }
799 }
800 }
801 // Left-over columns
802 const int columns_left = out_width - id.x() - 4;
803 for(auto x = 0; x < columns_left; ++x)
804 {
805 *(mtx_out1 + x) = acc01[x];
806 if(id.y() + 1 < out_height)
807 {
808 *(mtx_out1 + x + out_stride1) = acc11[x];
809 if(id.y() + 2 < out_height)
810 {
811 *(mtx_out1 + x + out_stride2) = acc21[x];
812 if(id.y() + 3 < out_height)
813 {
814 *(mtx_out1 + x + out_stride3) = acc31[x];
815 }
816 }
817 }
818 }
819 }
820 else
821 {
822 // Left-over columns
823 const int columns_left = out_width - id.x();
824 for(int x = 0; x < columns_left; ++x)
825 {
826 *(mtx_out0 + x) = acc00[x];
827 if(id.y() + 1 < out_height)
828 {
829 *(mtx_out0 + x + out_stride1) = acc10[x];
830 if(id.y() + 2 < out_height)
831 {
832 *(mtx_out0 + x + out_stride2) = acc20[x];
833 if(id.y() + 3 < out_height)
834 {
835 *(mtx_out0 + x + out_stride3) = acc30[x];
836 }
837 }
838 }
839 }
840 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100841 },
842 ina, inb, out);
843}
844
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100845#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** FP16 matrix-matrix multiplication: for each window step, computes a 4-row x 8-column
 * block of output = alpha * input0 * input1.
 *
 * input0 holds matrix A with 4x4 blocks interleaved on the same row and input1 holds
 * matrix B transposed into 1x8 strips (see the diagram inside the loop body).
 * Partial tiles at the right/bottom edges of the output are written through the
 * scalar leftover paths at the end of the lambda.
 *
 * @param[in]  input0 Interleaved LHS matrix A (F16)
 * @param[in]  input1 Transposed RHS matrix B (F16)
 * @param[out] output Destination matrix (F16)
 * @param[in]  window Execution window, expressed in output-tensor coordinates
 * @param[in]  alpha  Scalar weight of the matrix product (the multiply is skipped when alpha == 1)
 */
void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const int    out_width  = static_cast<int>(output->info()->dimension(0));
    const int    out_height = static_cast<int>(output->info()->dimension(1));
    // Strides converted from bytes to elements so they can be added to typed float16_t pointers
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride           = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix A has 8 times less the cols of the output matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
        // c.val[r] accumulates 8 output columns of output row (id.y() + r)
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        /*
        This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
        |a00 a01 a02 a03 | a04 a05 a06 a07|
        |a10 a11 a12 a13 | a14 a15 a16 a17|
        |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
        |a30 a31 a32 a33 | a34 a35 a36 a37|   | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
        |a40 a41 a42 a43 | a44 a45 a46 a47|
        |a50 a51 a52 a53 | a54 a55 a56 a57|
        |a60 a61 a62 a63 | a64 a65 a66 a67|
        |a70 a71 a72 a73 | a74 a75 a76 a77|

        After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]

        B Matrix has been transposed as shown below

        |b00 b01 b02 b03 b04 b05 b06 b07|
        |b10 b11 b12 b13 b14 b15 b16 b17|
        |b20 b21 b22 b23 b24 b25 b26 b27|
        |b30 b31 b32 b33 b34 b35 b36 b37|
        ------------------->

        |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|

        c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
        c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31

        The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
        */
        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

        // Main loop: consume 16 A-values and 32 B-values per iteration,
        // i.e. 4 accumulation steps along K for the whole 4x8 output block
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            // Each lane of p00/p02 is one interleaved A value; it scales a full strip of 8 B columns
            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }

        // Leftover loop: one K-step per iteration (4 A-values, 8 B-values)
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }

        // Multiply by the weight of the matrix product (alpha); skipped when alpha == 1
        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }

        // Store: full 8-wide vector stores when the whole tile fits; the nested
        // id.y() checks guard against rows past the bottom edge of the output
        if(id.x() < (out_width - 8))
        {
            vst1q_f16(mtx_out, c.val[0]);
            if(id.y() + 1 < out_height)
            {
                vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
                    }
                }
            }
        }
        else
        {
            // Left-over columns: scalar stores, one accumulator lane per remaining column
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out + x) = c.val[0][x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out + x + 1 * out_stride) = c.val[1][x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out + x + 2 * out_stride) = c.val[2][x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out + x + 3 * out_stride) = c.val[3][x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001021#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001022
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001023inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001024{
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001025 ARM_COMPUTE_UNUSED(alpha);
1026
Anthony Barbiereaefd002018-07-20 17:49:35 +01001027 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input0);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01001028 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001029 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001030
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001031 if(!is_interleaved)
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001032 {
1033 ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001034
1035 if(output->total_size() != 0)
1036 {
1037 ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
1038 ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
1039 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001040 }
1041 }
1042 else
1043 {
1044 const int m = reshape_info.m();
1045 const int n = reshape_info.n();
1046 const int k = reshape_info.k();
1047 const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
1048 const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
1049
1050 /* Interleave */
1051 TensorShape tensor_shape0{ input0->tensor_shape() };
1052 tensor_shape0.set(0, k);
1053 tensor_shape0.set(1, m);
1054
1055 const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
1056 const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
1057 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
1058
1059 if(n != 0) /* Transpose */
1060 {
1061 TensorShape tensor_shape1{ input1->tensor_shape() };
1062 tensor_shape1.set(0, n);
1063 tensor_shape1.set(1, k);
1064
1065 const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
1066 const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
1067 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
1068 }
1069
1070 if(output->total_size() != 0)
1071 {
1072 if(n != 0)
1073 {
1074 ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
1075 }
1076 ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
1077 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001078 }
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001079 }
1080
1081 return Status{};
1082}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001083} // namespace
1084
// Default constructor: no tensors bound yet; alpha initialized to 1.0f,
// i.e. no scaling of the matrix product until configure() is called.
NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
{
}
1089
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001090void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001091{
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001092 ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001093
1094 // Output tensor auto inizialitation if not yet initialized
1095 TensorShape tensor_shape{ input0->info()->tensor_shape() };
1096 tensor_shape.set(0, is_interleaved ? reshape_info.n() : input1->info()->dimension(0));
1097 tensor_shape.set(1, is_interleaved ? reshape_info.m() : input0->info()->dimension(1));
1098
1099 auto_init_if_empty(*output->info(), input0->info()->clone()->set_tensor_shape(tensor_shape));
1100
1101 // Perform validate step
1102 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, is_interleaved, reshape_info));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001103
1104 _input0 = input0;
1105 _input1 = input1;
1106 _output = output;
1107 _alpha = alpha;
1108
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001109 // Configure kernel window
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001110 Window win{};
1111
1112 // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication
1113 if((output->info()->dimension(1) == 1))
1114 {
1115 const unsigned int num_elems_processed_per_iteration_x = (input0->info()->data_type() == DataType::F32) ? 16 : 32;
1116
1117 win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
1118 }
1119 else
1120 {
1121 constexpr unsigned int num_elems_processed_per_iteration_x = 8;
1122 constexpr unsigned int num_elems_processed_per_iteration_y = 4;
1123
1124 win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
1125 }
1126
1127 Coordinates coord;
1128 coord.set_num_dimensions(output->info()->num_dimensions());
1129 output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
1130 INEKernel::configure(win);
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001131}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001132
// Static validation entry point: forwards to the file-local validate_arguments()
// and propagates the first failing check as an error Status.
Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved,
                                            const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info));

    return Status{};
}
1140
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001141void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001142{
1143 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1144 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1145
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +01001146 // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001147 const bool is_output_vector = (_output->info()->dimension(1) == 1);
1148 switch(_input0->info()->data_type())
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001149 {
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001150 case DataType::F32:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001151 {
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001152 is_output_vector ? vector_matrix_multiply_f32(_input0, _input1, _output, window, info, _alpha) :
1153 matrix_matrix_multiply_f32(_input0, _input1, _output, window, _alpha);
1154 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001155 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001156#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001157 case DataType::F16:
1158 {
1159 is_output_vector ? vector_matrix_multiply_f16(_input0, _input1, _output, window, info, _alpha) :
1160 matrix_matrix_multiply_f16(_input0, _input1, _output, window, _alpha);
1161 break;
1162 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001163#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001164 default:
1165 {
1166 ARM_COMPUTE_ERROR("Data type not supported");
1167 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001168 }
1169 }
1170}
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001171} // namespace arm_compute