/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/AccessWindowTranspose.h"
#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"

#include "arm_compute/core/utils/misc/ShapeCalculator.h"

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>
#include <tuple>
using namespace arm_compute;

namespace arm_compute
{
class Coordinates;
} // namespace arm_compute

namespace
{
template <bool multiply_alpha>
void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");

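    // Each thread owns an interleaved set of 32-column stripes of matrix B: thread t starts at
    // column 32 * t and advances by 32 * num_threads, so the threads cover the width of B without overlap.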
    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
    ARM_COMPUTE_UNUSED(alpha_f16);

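    // Each loop iteration produces 32 fp16 outputs: four 8-lane accumulators spanning 32 consecutive
    // columns of B. The A vector is consumed four elements at a time in the main loop below.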
    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        float16x8_t acc0 = vdupq_n_f16(0.f);
        float16x8_t acc1 = vdupq_n_f16(0.f);
        float16x8_t acc2 = vdupq_n_f16(0.f);
        float16x8_t acc3 = vdupq_n_f16(0.f);

        auto vec_a    = reinterpret_cast<const float16_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr());

        const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 4);)
        {
            const float16x4_t a0l = vld1_f16(vec_a);

            float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

            matrix_b += 2 * in_b_stride;

            b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

            vec_a += 4;
            matrix_b += 2 * in_b_stride;
        }

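        // Leftover loop: accumulate the remaining (num_elems_vec_a % 4) elements of A one at a time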
        for(; vec_a < vec_a_end_addr;)
        {
            const float16_t   a0  = *vec_a;
            const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
            acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
            acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
            acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            acc0 = vmulq_f16(acc0, alpha_f16);
            acc1 = vmulq_f16(acc1, alpha_f16);
            acc2 = vmulq_f16(acc2, alpha_f16);
            acc3 = vmulq_f16(acc3, alpha_f16);
        }

        const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());

        vst1q_f16(vec_out + 0, acc0);
        vst1q_f16(vec_out + 8, acc1);
        vst1q_f16(vec_out + 16, acc2);
        vst1q_f16(vec_out + 24, acc3);

    },
    ina, inb, out);
#else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    ARM_COMPUTE_UNUSED(input0);
    ARM_COMPUTE_UNUSED(input1);
    ARM_COMPUTE_UNUSED(output);
    ARM_COMPUTE_UNUSED(window);
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_ERROR("Not implemented");
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
}

template <bool multiply_alpha>
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        float32x4_t acc0 = vdupq_n_f32(0.f);
        float32x4_t acc1 = vdupq_n_f32(0.f);
        float32x4_t acc2 = vdupq_n_f32(0.f);
        float32x4_t acc3 = vdupq_n_f32(0.f);

        auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float *>(inb.ptr());

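        // On AArch32, issue explicit PLD prefetch hints so the upcoming A and B cache lines are
        // already in flight when the loads below execute; the block compiles out on other targets.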
#if __arm__
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

        auto vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 4);)
        {
            float32x2_t a0l = vld1_f32(vec_a);

            float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;

            a0l = vld1_f32(vec_a);

            b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

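        // Leftover loop: accumulate the remaining A elements one at a time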
        for(; vec_a < vec_a_end_addr;)
        {
            const float a0 = *vec_a;

            const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            acc0 = vmlaq_n_f32(acc0, b00, a0);
            acc1 = vmlaq_n_f32(acc1, b01, a0);
            acc2 = vmlaq_n_f32(acc2, b02, a0);
            acc3 = vmlaq_n_f32(acc3, b03, a0);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
            acc0 = vmulq_f32(acc0, alpha_f32);
            acc1 = vmulq_f32(acc1, alpha_f32);
            acc2 = vmulq_f32(acc2, alpha_f32);
            acc3 = vmulq_f32(acc3, alpha_f32);
        }

        const auto vec_out = reinterpret_cast<float *>(out.ptr());

        vst1q_f32(vec_out + 0, acc0);
        vst1q_f32(vec_out + 4, acc1);
        vst1q_f32(vec_out + 8, acc2);
        vst1q_f32(vec_out + 12, acc3);
    },
    ina, inb, out);
}

template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride1          = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const size_t out_stride2          = out_stride1 * 2;
    const size_t out_stride3          = out_stride1 * 3;
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 4 as the transposed input matrix B has 4 times fewer columns than the output matrix
    // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
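    // Each iteration of the unrolled main loop below consumes 8 interleaved A values and 8 B values per
    // B row pair, repeated four times (32 B elements per pointer), accumulating two adjacent 4x4 output
    // blocks (acc*0 from mtx_b0 and acc*1 from mtx_b1) entirely in registers.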
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        float32x4_t acc00 = vdupq_n_f32(0.f);
        float32x4_t acc10 = vdupq_n_f32(0.f);
        float32x4_t acc20 = vdupq_n_f32(0.f);
        float32x4_t acc30 = vdupq_n_f32(0.f);

        float32x4_t acc01 = vdupq_n_f32(0.f);
        float32x4_t acc11 = vdupq_n_f32(0.f);
        float32x4_t acc21 = vdupq_n_f32(0.f);
        float32x4_t acc31 = vdupq_n_f32(0.f);

#if __arm__
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

        auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);

            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);
            float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
            float32x4_t b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
            float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
            float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
            float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);

            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }

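        // Leftover loop: process the remaining B elements four at a time, still updating both 4x4 blocks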
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            float32x4_t a0  = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1  = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2  = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3  = vld1q_dup_f32(mtx_a0 + 3);
            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);

#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            mtx_a0 += 4;
            mtx_b0 += 4;
            mtx_b1 += 4;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
            acc00 = vmulq_f32(acc00, alpha_f32);
            acc10 = vmulq_f32(acc10, alpha_f32);
            acc20 = vmulq_f32(acc20, alpha_f32);
            acc30 = vmulq_f32(acc30, alpha_f32);
            acc01 = vmulq_f32(acc01, alpha_f32);
            acc11 = vmulq_f32(acc11, alpha_f32);
            acc21 = vmulq_f32(acc21, alpha_f32);
            acc31 = vmulq_f32(acc31, alpha_f32);
        }

        const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
        const auto mtx_out1 = mtx_out0 + 4;

        // Store the two 4x4 output blocks
        vst1q_f32(mtx_out0, acc00);
        vst1q_f32(mtx_out1, acc01);
        vst1q_f32(mtx_out0 + out_stride1, acc10);
        vst1q_f32(mtx_out1 + out_stride1, acc11);
        vst1q_f32(mtx_out0 + out_stride2, acc20);
        vst1q_f32(mtx_out1 + out_stride2, acc21);
        vst1q_f32(mtx_out0 + out_stride3, acc30);
        vst1q_f32(mtx_out1 + out_stride3, acc31);
    },
    ina, inb, out);
}

template <bool multiply_alpha>
void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride           = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 8 as the transposed input matrix B has 8 times fewer columns than the output matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        /*
        This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
             |a00 a01 a02 a03 | a04 a05 a06 a07|
             |a10 a11 a12 a13 | a14 a15 a16 a17|
             |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
             |a30 a31 a32 a33 | a34 a35 a36 a37|   | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
             |a40 a41 a42 a43 | a44 a45 a46 a47|
             |a50 a51 a52 a53 | a54 a55 a56 a57|
             |a60 a61 a62 a63 | a64 a65 a66 a67|
             |a70 a71 a72 a73 | a74 a75 a76 a77|

        After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]

        B Matrix has been transposed as shown below

           |b00 b01 b02 b03 b04 b05 b06 b07|
           |b10 b11 b12 b13 b14 b15 b16 b17|
           |b20 b21 b22 b23 b24 b25 b26 b27|
           |b30 b31 b32 b33 b34 b35 b36 b37|
          ------------------->

           |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|

            c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
            c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31

        The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
        */
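        // Main loop: each iteration consumes 16 interleaved A values and 32 transposed B values,
        // accumulating a 4x8 block of the output in the four rows of 'c'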
        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }

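        // Leftover loop: consume 4 A values and 8 B values per iteration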
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }

        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }

        vst1q_f16(mtx_out + 0 * out_stride, c.val[0]);
        vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
        vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
    },
    ina, inb, out);
#else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    ARM_COMPUTE_UNUSED(input0);
    ARM_COMPUTE_UNUSED(input1);
    ARM_COMPUTE_UNUSED(output);
    ARM_COMPUTE_UNUSED(window);
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_ERROR("Not implemented");
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
}

inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_UNUSED(alpha);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input0);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);

    if(!is_interleaved)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));

        if(output->total_size() != 0)
        {
            ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
            ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
        }
    }
    else
    {
        const int m                         = reshape_info.m();
        const int n                         = reshape_info.n();
        const int k                         = reshape_info.k();
        const int mult_transpose1xW_width   = reshape_info.mult_transpose1xW_width();
        const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();

        /* Interleave */
        TensorShape tensor_shape0{ input0->tensor_shape() };
        tensor_shape0.set(0, k);
        tensor_shape0.set(1, m);

        const TensorInfo tensor_info0          = input0->clone()->set_tensor_shape(tensor_shape0);
        const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);

        if(n != 0) /* Transpose */
        {
            TensorShape tensor_shape1{ input1->tensor_shape() };
            tensor_shape1.set(0, n);
            tensor_shape1.set(1, k);

            const TensorInfo tensor_info1          = input1->clone()->set_tensor_shape(tensor_shape1);
            const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
        }

        if(output->total_size() != 0)
        {
            if(n != 0)
            {
                ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
            }
            ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
        }
    }

    return Status{};
}

inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
{
    bool   window_changed{};
    Window win{};

    unsigned int       num_elems_processed_per_iteration_x = 0;
    const unsigned int num_elems_processed_per_iteration_y = 4;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if((output->dimension(1) == 1))
    {
        switch(input0->data_type())
        {
            case DataType::F32:
            {
                num_elems_processed_per_iteration_x = 16;
                break;
            }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case DataType::F16:
            {
                num_elems_processed_per_iteration_x = 32;
                break;
            }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }

        // Configure kernel window
        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));

        AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);

        window_changed = update_window_and_padding(win,
                                                   AccessWindowStatic(input0, 0, 0, input0->tensor_shape().x(), 1),
                                                   AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration_x),
                                                   output_access);

        Coordinates coord;
        coord.set_num_dimensions(output->num_dimensions());
        output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
    }
    else
    {
        switch(input0->data_type())
        {
            case DataType::F32:
            {
                num_elems_processed_per_iteration_x = 8;
                break;
            }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case DataType::F16:
            {
                num_elems_processed_per_iteration_x = 8;
                break;
            }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }

        // Configure kernel window
        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));

        AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);

        window_changed = update_window_and_padding(win,
                                                   AccessWindowRectangle(input0, 0, 0, 4, 1, 1.f, 0.25f),
                                                   AccessWindowStatic(input1, 0, 0, input1->tensor_shape().x(), ceil_to_multiple(input1->tensor_shape().y(), 4)),
                                                   output_access);

        output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
    }

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
{
}

void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);

    // Output tensor auto initialization if not yet initialized
    TensorShape tensor_shape{ input0->info()->tensor_shape() };
    tensor_shape.set(0, is_interleaved ? reshape_info.n() : input1->info()->dimension(0));
    tensor_shape.set(1, is_interleaved ? reshape_info.m() : input0->info()->dimension(1));

    auto_init_if_empty(*output->info(), input0->info()->clone()->set_tensor_shape(tensor_shape));

    // Perform validate step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, is_interleaved, reshape_info));

    _input0 = input0;
    _input1 = input1;
    _output = output;
    _alpha  = alpha;

    // Configure kernel window
    auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}

Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved,
                                            const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);

    return Status{};
}

void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

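    // Treat alpha as 1 (and skip the extra multiplication) when it is within a small tolerance of 1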
    bool multiply_alpha = std::abs(1.0f - _alpha) > 0.00001f;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if((_output->info()->dimension(1) == 1))
    {
        switch(_input0->info()->data_type())
        {
            case DataType::F32:
            {
                multiply_alpha ? vector_matrix_multiply_f32<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case DataType::F16:
            {
                multiply_alpha ? vector_matrix_multiply_f16<true>(_input0, _input1, _output, window, info, _alpha) :
                                 vector_matrix_multiply_f16<false>(_input0, _input1, _output, window, info, _alpha);
                break;
            }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }
    }
    else
    {
        switch(_input0->info()->data_type())
        {
            case DataType::F32:
            {
                multiply_alpha ? matrix_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case DataType::F16:
            {
                multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
                                 matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
                break;
            }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
            default:
            {
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
            }
        }
    }
}