blob: d86ea064de1d9b884652df518c4bb0a176c6d618 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
SiCongLib88272e2021-02-24 15:40:57 +00002 * Copyright (c) 2017-2021 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michele Di Giorgio53832b22021-06-21 14:45:44 +010024#include "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/TensorInfo.h"
28#include "arm_compute/core/Types.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Validate.h"
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +000030#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010031#include "src/core/CPP/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010032#include "src/core/helpers/AutoConfiguration.h"
33#include "src/core/helpers/WindowHelpers.h"
34#include "src/core/utils/helpers/float_ops.h"
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +000035
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include <arm_neon.h>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010037
38namespace arm_compute
39{
Michele Di Giorgio53832b22021-06-21 14:45:44 +010040namespace cpu
41{
42namespace kernels
43{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010044namespace
45{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Compute dst = alpha * (vector A x matrix B) for FP16 data.
 *
 * A is the 1-D lhs vector (length = lhs->dimension(0)), B is the non-reshaped rhs
 * matrix and dst is a 1-D row of width dst->dimension(0). The output x-range is
 * split across threads via @p info: each thread owns strided chunks of 32 output
 * elements (thread t starts at 32*t and advances by 32*num_threads).
 *
 * @param[in]  lhs    Input vector A (FP16).
 * @param[in]  rhs    Input matrix B (FP16, not reshaped).
 * @param[out] dst    Output row (FP16).
 * @param[in]  window Region over which to run the loop (collapsed to a single x/y point internally).
 * @param[in]  info   Thread info used to partition the x-range.
 * @param[in]  alpha  Scalar weight of the matrix product.
 */
void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    // Stride between consecutive rows of B, expressed in elements (bytes / element size).
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size());
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    // Rounded UP so the span is an exact multiple of window_step_x; this means x can
    // run past width_matrix_b, which the in-loop guards below must catch.
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");

    // Collapse the x/y dims of the output window: the whole row is produced in one lambda call.
    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    // Skip the final scaling pass entirely when alpha == 1.
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // Here we don't check for x lower equal than (window_end_x - window_step_x) because of
        // window_end_x is computed above which may cause out-of-bound writes to the dst.
        // Main loop: 32 outputs per iteration held in four 8-lane FP16 accumulators.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            // window_end_x was rounded up, so bail out once x passes the real row width.
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x8_t acc0 = vdupq_n_f16(0.f);
            float16x8_t acc1 = vdupq_n_f16(0.f);
            float16x8_t acc2 = vdupq_n_f16(0.f);
            float16x8_t acc3 = vdupq_n_f16(0.f);

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            // Unrolled over K: each iteration consumes 4 elements of A (2 + 2) and
            // the matching 4 rows of B, 32 columns wide.
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

                matrix_b += 2 * in_b_stride;

                b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

                vec_a += 4;
                matrix_b += 2 * in_b_stride;
            }

            // Scalar tail over the remaining (num_elems_vec_a % 4) elements of A.
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t   a0  = *vec_a;
                const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
                acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
                acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
                acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f16(acc0, alpha_f16);
                acc1 = vmulq_f16(acc1, alpha_f16);
                acc2 = vmulq_f16(acc2, alpha_f16);
                acc3 = vmulq_f16(acc3, alpha_f16);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            vst1q_f16(vec_out + 0, acc0);
            vst1q_f16(vec_out + 8, acc1);
            vst1q_f16(vec_out + 16, acc2);
            vst1q_f16(vec_out + 24, acc3);
        }

        // Leftover loop: one output element at a time for the tail of the row.
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x4_t vacc = vdup_n_f16(0.f);

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            // Dot-product 4 elements of A against 4 rows of the current B column at a time.
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                const float16x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

                vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));

                matrix_b += 4 * in_b_stride;
            }

            // Horizontal reduction of the 4-lane accumulator.
            float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);

            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t a0  = *vec_a;
                const float16_t b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= static_cast<float16_t>(alpha);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            *(vec_out) = acc;
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello221f3812017-06-28 17:27:56 +0100241
/** Compute dst = alpha * (vector A x matrix B) for FP32 data.
 *
 * A is the 1-D lhs vector (length = lhs->dimension(0)), B is the non-reshaped rhs
 * matrix and dst is a 1-D row of width dst->dimension(0). The output x-range is
 * split across threads via @p info: each thread owns strided chunks of 16 output
 * elements (thread t starts at 16*t and advances by 16*num_threads).
 *
 * @param[in]  lhs    Input vector A (FP32).
 * @param[in]  rhs    Input matrix B (FP32, not reshaped).
 * @param[out] dst    Output row (FP32).
 * @param[in]  window Region over which to run the loop (collapsed to a single x/y point internally).
 * @param[in]  info   Thread info used to partition the x-range.
 * @param[in]  alpha  Scalar weight of the matrix product.
 */
void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    // Stride between consecutive rows of B, expressed in elements (bytes / element size).
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    // (rounded UP, so x can run past width_matrix_b; the in-loop guards catch that).
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    // Collapse the x/y dims of the output window: the whole row is produced in one lambda call.
    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    // Skip the final scaling pass entirely when alpha == 1.
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // Here we don't check for x lower equal than (window_end_x - window_step_x) because of
        // window_end_x is computed above which may cause out-of-bound writes to the dst.
        // Main loop: 16 outputs per iteration held in four 4-lane FP32 accumulators.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            // window_end_x was rounded up, so bail out once x passes the real row width.
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t acc0 = vdupq_n_f32(0.f);
            float32x4_t acc1 = vdupq_n_f32(0.f);
            float32x4_t acc2 = vdupq_n_f32(0.f);
            float32x4_t acc3 = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;

#if __arm__
            // Software prefetch of A and the first two rows of B (AArch32 only).
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            // Unrolled over K: each iteration consumes 4 elements of A (2 + 2) and
            // the matching 4 rows of B, 16 columns wide.
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                float32x2_t a0l = vld1_f32(vec_a);

                float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;

                a0l = vld1_f32(vec_a);

                b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;
            }

            // Scalar tail over the remaining (num_elems_vec_a % 4) elements of A.
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                acc0 = vmlaq_n_f32(acc0, b00, a0);
                acc1 = vmlaq_n_f32(acc1, b01, a0);
                acc2 = vmlaq_n_f32(acc2, b02, a0);
                acc3 = vmlaq_n_f32(acc3, b03, a0);

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f32(acc0, alpha_f32);
                acc1 = vmulq_f32(acc1, alpha_f32);
                acc2 = vmulq_f32(acc2, alpha_f32);
                acc3 = vmulq_f32(acc3, alpha_f32);
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            vst1q_f32(vec_out + 0, acc0);
            vst1q_f32(vec_out + 4, acc1);
            vst1q_f32(vec_out + 8, acc2);
            vst1q_f32(vec_out + 12, acc3);
        }

        // Left-over loop
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t vacc = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            // Dot-product 4 elements of A against 4 rows of the current B column at a time.
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float32x4_t a0l = vld1q_f32(vec_a);

                const float32x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

                vacc = vmlaq_f32(vacc, b_col, a0l);

                matrix_b += 4 * in_b_stride;
            }

            // Horizontal reduction of the 4-lane accumulator.
            float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);

            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= alpha;
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            *vec_out = acc;
        }
    },
    ina, inb, out);
}
473
Michele Di Giorgio53832b22021-06-21 14:45:44 +0100474void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100475{
Michele Di Giorgio53832b22021-06-21 14:45:44 +0100476 ARM_COMPUTE_UNUSED(info);
477 const int out_width = static_cast<int>(dst->info()->dimension(0));
478 const int out_height = static_cast<int>(dst->info()->dimension(1));
479 const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
480 const size_t out_stride1 = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100481 const size_t out_stride2 = out_stride1 * 2;
482 const size_t out_stride3 = out_stride1 * 3;
Michele Di Giorgio53832b22021-06-21 14:45:44 +0100483 const int num_elems_matrix_b_x = rhs->info()->dimension(0);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100484
Michele Di Giorgio53832b22021-06-21 14:45:44 +0100485 // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the dst matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100486 Window win_a(window);
487 win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
488 win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
489
490 Window win_b;
491 // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
492 // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
Michele Di Giorgio53832b22021-06-21 14:45:44 +0100493 if(rhs->info()->num_dimensions() >= 3)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100494 {
495 win_b = window;
496 }
Michele Di Giorgio53832b22021-06-21 14:45:44 +0100497 // Set step_x and step_y for matrix B. Scale by a factor of 4 the X range as the input transposed matrix A has 4 times less the cols of the dst matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100498 // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
499 win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
500 win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
501
Michele Di Giorgio53832b22021-06-21 14:45:44 +0100502 Iterator ina(lhs, win_a);
503 Iterator inb(rhs, win_b);
504 Iterator out(dst, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100505
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100506 const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
507
508 const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
509
Michele Di Giorgio53832b22021-06-21 14:45:44 +0100510 // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with CpuGemmInterleave4x4 and CpuGemmTranspose1xW
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100511 // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
512 // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100513 execute_window_loop(window, [&](const Coordinates & id)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100514 {
515 auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
516 auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
517 auto mtx_b1 = mtx_b0 + in_b_stride;
518
519 float32x4_t acc00 = vdupq_n_f32(0.f);
520 float32x4_t acc10 = vdupq_n_f32(0.f);
521 float32x4_t acc20 = vdupq_n_f32(0.f);
522 float32x4_t acc30 = vdupq_n_f32(0.f);
523
524 float32x4_t acc01 = vdupq_n_f32(0.f);
525 float32x4_t acc11 = vdupq_n_f32(0.f);
526 float32x4_t acc21 = vdupq_n_f32(0.f);
527 float32x4_t acc31 = vdupq_n_f32(0.f);
528
529#if __arm__
530 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
531 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
532 asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
Anthony Barbierac69aa12017-07-03 17:39:37 +0100533#endif /* __arm__ */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100534
535 auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
536 for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
537 {
538 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
539 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
540 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
541 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
542
543 float32x4_t b00 = vld1q_f32(mtx_b0);
544 float32x4_t b10 = vld1q_f32(mtx_b1);
545 float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
546 float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
547
548#if __arm__
549 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
550 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
551 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
Anthony Barbierac69aa12017-07-03 17:39:37 +0100552#endif /* __arm__ */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100553
554 // 4x4 block 0
555 acc00 = vmlaq_f32(acc00, b00, a0);
556 acc10 = vmlaq_f32(acc10, b00, a1);
557 acc20 = vmlaq_f32(acc20, b00, a2);
558 acc30 = vmlaq_f32(acc30, b00, a3);
559
560 float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
561 float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
562 float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
563 float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
564
565 // 4x4 block 1
566 acc01 = vmlaq_f32(acc01, b10, a0);
567 acc11 = vmlaq_f32(acc11, b10, a1);
568 acc21 = vmlaq_f32(acc21, b10, a2);
569 acc31 = vmlaq_f32(acc31, b10, a3);
570
571 // 4x4 block 0
572 acc00 = vmlaq_f32(acc00, b01, a4);
573 acc10 = vmlaq_f32(acc10, b01, a5);
574 acc20 = vmlaq_f32(acc20, b01, a6);
575 acc30 = vmlaq_f32(acc30, b01, a7);
576
577 // 4x4 block 1
578 acc01 = vmlaq_f32(acc01, b11, a4);
579 acc11 = vmlaq_f32(acc11, b11, a5);
580 acc21 = vmlaq_f32(acc21, b11, a6);
581 acc31 = vmlaq_f32(acc31, b11, a7);
582
583 mtx_a0 += 8;
584 mtx_b0 += 8;
585 mtx_b1 += 8;
586
587 a0 = vld1q_dup_f32(mtx_a0 + 0);
588 a1 = vld1q_dup_f32(mtx_a0 + 1);
589 a2 = vld1q_dup_f32(mtx_a0 + 2);
590 a3 = vld1q_dup_f32(mtx_a0 + 3);
591
592 b00 = vld1q_f32(mtx_b0);
593 b10 = vld1q_f32(mtx_b1);
594 b01 = vld1q_f32(mtx_b0 + 4);
595 b11 = vld1q_f32(mtx_b1 + 4);
596
597 // 4x4 block 0
598 acc00 = vmlaq_f32(acc00, b00, a0);
599 acc10 = vmlaq_f32(acc10, b00, a1);
600 acc20 = vmlaq_f32(acc20, b00, a2);
601 acc30 = vmlaq_f32(acc30, b00, a3);
602
603 a4 = vld1q_dup_f32(mtx_a0 + 4);
604 a5 = vld1q_dup_f32(mtx_a0 + 5);
605 a6 = vld1q_dup_f32(mtx_a0 + 6);
606 a7 = vld1q_dup_f32(mtx_a0 + 7);
607
608 // 4x4 block 1
609 acc01 = vmlaq_f32(acc01, b10, a0);
610 acc11 = vmlaq_f32(acc11, b10, a1);
611 acc21 = vmlaq_f32(acc21, b10, a2);
612 acc31 = vmlaq_f32(acc31, b10, a3);
613
614 // 4x4 block 0
615 acc00 = vmlaq_f32(acc00, b01, a4);
616 acc10 = vmlaq_f32(acc10, b01, a5);
617 acc20 = vmlaq_f32(acc20, b01, a6);
618 acc30 = vmlaq_f32(acc30, b01, a7);
619
620 // 4x4 block 1
621 acc01 = vmlaq_f32(acc01, b11, a4);
622 acc11 = vmlaq_f32(acc11, b11, a5);
623 acc21 = vmlaq_f32(acc21, b11, a6);
624 acc31 = vmlaq_f32(acc31, b11, a7);
625
626 mtx_a0 += 8;
627 mtx_b0 += 8;
628 mtx_b1 += 8;
629
630 a0 = vld1q_dup_f32(mtx_a0 + 0);
631 a1 = vld1q_dup_f32(mtx_a0 + 1);
632 a2 = vld1q_dup_f32(mtx_a0 + 2);
633 a3 = vld1q_dup_f32(mtx_a0 + 3);
634 b00 = vld1q_f32(mtx_b0);
635 b10 = vld1q_f32(mtx_b1);
636 b01 = vld1q_f32(mtx_b0 + 4);
637 b11 = vld1q_f32(mtx_b1 + 4);
638
639#if __arm__
640 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
641 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
642 asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
Anthony Barbierac69aa12017-07-03 17:39:37 +0100643#endif /* __arm__ */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100644
645 // 4x4 block 0
646 acc00 = vmlaq_f32(acc00, b00, a0);
647 acc10 = vmlaq_f32(acc10, b00, a1);
648 acc20 = vmlaq_f32(acc20, b00, a2);
649 acc30 = vmlaq_f32(acc30, b00, a3);
650
651 a4 = vld1q_dup_f32(mtx_a0 + 4);
652 a5 = vld1q_dup_f32(mtx_a0 + 5);
653 a6 = vld1q_dup_f32(mtx_a0 + 6);
654 a7 = vld1q_dup_f32(mtx_a0 + 7);
655
656 // 4x4 block 1
657 acc01 = vmlaq_f32(acc01, b10, a0);
658 acc11 = vmlaq_f32(acc11, b10, a1);
659 acc21 = vmlaq_f32(acc21, b10, a2);
660 acc31 = vmlaq_f32(acc31, b10, a3);
661
662 // 4x4 block 0
663 acc00 = vmlaq_f32(acc00, b01, a4);
664 acc10 = vmlaq_f32(acc10, b01, a5);
665 acc20 = vmlaq_f32(acc20, b01, a6);
666 acc30 = vmlaq_f32(acc30, b01, a7);
667
668 // 4x4 block 1
669 acc01 = vmlaq_f32(acc01, b11, a4);
670 acc11 = vmlaq_f32(acc11, b11, a5);
671 acc21 = vmlaq_f32(acc21, b11, a6);
672 acc31 = vmlaq_f32(acc31, b11, a7);
673
674 mtx_a0 += 8;
675 mtx_b0 += 8;
676 mtx_b1 += 8;
677
678 a0 = vld1q_dup_f32(mtx_a0 + 0);
679 a1 = vld1q_dup_f32(mtx_a0 + 1);
680 a2 = vld1q_dup_f32(mtx_a0 + 2);
681 a3 = vld1q_dup_f32(mtx_a0 + 3);
682 b00 = vld1q_f32(mtx_b0);
683 b10 = vld1q_f32(mtx_b1);
684 b01 = vld1q_f32(mtx_b0 + 4);
685 b11 = vld1q_f32(mtx_b1 + 4);
686
687 // 4x4 block 0
688 acc00 = vmlaq_f32(acc00, b00, a0);
689 acc10 = vmlaq_f32(acc10, b00, a1);
690 acc20 = vmlaq_f32(acc20, b00, a2);
691 acc30 = vmlaq_f32(acc30, b00, a3);
692
693 a4 = vld1q_dup_f32(mtx_a0 + 4);
694 a5 = vld1q_dup_f32(mtx_a0 + 5);
695 a6 = vld1q_dup_f32(mtx_a0 + 6);
696 a7 = vld1q_dup_f32(mtx_a0 + 7);
697
698 // 4x4 block 1
699 acc01 = vmlaq_f32(acc01, b10, a0);
700 acc11 = vmlaq_f32(acc11, b10, a1);
701 acc21 = vmlaq_f32(acc21, b10, a2);
702 acc31 = vmlaq_f32(acc31, b10, a3);
703
704 // 4x4 block 0
705 acc00 = vmlaq_f32(acc00, b01, a4);
706 acc10 = vmlaq_f32(acc10, b01, a5);
707 acc20 = vmlaq_f32(acc20, b01, a6);
708 acc30 = vmlaq_f32(acc30, b01, a7);
709
710 // 4x4 block 1
711 acc01 = vmlaq_f32(acc01, b11, a4);
712 acc11 = vmlaq_f32(acc11, b11, a5);
713 acc21 = vmlaq_f32(acc21, b11, a6);
714 acc31 = vmlaq_f32(acc31, b11, a7);
715
716 mtx_a0 += 8;
717 mtx_b0 += 8;
718 mtx_b1 += 8;
719 }
720
721 for(; mtx_b0 < mtx_b0_end_addr;)
722 {
723 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
724 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
725 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
726 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
727 float32x4_t b00 = vld1q_f32(mtx_b0);
728 float32x4_t b10 = vld1q_f32(mtx_b1);
729
730#if __arm__
731 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
732 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
733 asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
Anthony Barbierac69aa12017-07-03 17:39:37 +0100734#endif /* __arm__ */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100735 // 4x4 block 0
736 acc00 = vmlaq_f32(acc00, b00, a0);
737 acc10 = vmlaq_f32(acc10, b00, a1);
738 acc20 = vmlaq_f32(acc20, b00, a2);
739 acc30 = vmlaq_f32(acc30, b00, a3);
740
741 // 4x4 block 1
742 acc01 = vmlaq_f32(acc01, b10, a0);
743 acc11 = vmlaq_f32(acc11, b10, a1);
744 acc21 = vmlaq_f32(acc21, b10, a2);
745 acc31 = vmlaq_f32(acc31, b10, a3);
746
747 mtx_a0 += 4;
748 mtx_b0 += 4;
749 mtx_b1 += 4;
750 }
751
752 // Multiply by the weight of matrix product (alpha)
753 if(multiply_alpha)
754 {
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100755 acc00 = vmulq_f32(acc00, alpha_f32);
756 acc10 = vmulq_f32(acc10, alpha_f32);
757 acc20 = vmulq_f32(acc20, alpha_f32);
758 acc30 = vmulq_f32(acc30, alpha_f32);
759 acc01 = vmulq_f32(acc01, alpha_f32);
760 acc11 = vmulq_f32(acc11, alpha_f32);
761 acc21 = vmulq_f32(acc21, alpha_f32);
762 acc31 = vmulq_f32(acc31, alpha_f32);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100763 }
764
765 const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
766 const auto mtx_out1 = mtx_out0 + 4;
767
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +0100768 if(id.x() < (out_width - 8))
769 {
770 vst1q_f32(mtx_out0, acc00);
771 vst1q_f32(mtx_out1, acc01);
772 if(id.y() + 1 < out_height)
773 {
774 vst1q_f32(mtx_out0 + out_stride1, acc10);
775 vst1q_f32(mtx_out1 + out_stride1, acc11);
776 if(id.y() + 2 < out_height)
777 {
778 vst1q_f32(mtx_out0 + out_stride2, acc20);
779 vst1q_f32(mtx_out1 + out_stride2, acc21);
780 if(id.y() + 3 < out_height)
781 {
782 vst1q_f32(mtx_out0 + out_stride3, acc30);
783 vst1q_f32(mtx_out1 + out_stride3, acc31);
784 }
785 }
786 }
787 }
788 else if(id.x() < (out_width - 4))
789 {
790 vst1q_f32(mtx_out0, acc00);
791 if(id.y() + 1 < out_height)
792 {
793 vst1q_f32(mtx_out0 + out_stride1, acc10);
794 if(id.y() + 2 < out_height)
795 {
796 vst1q_f32(mtx_out0 + out_stride2, acc20);
797 if(id.y() + 3 < out_height)
798 {
799 vst1q_f32(mtx_out0 + out_stride3, acc30);
800 }
801 }
802 }
803 // Left-over columns
804 const int columns_left = out_width - id.x() - 4;
805 for(auto x = 0; x < columns_left; ++x)
806 {
807 *(mtx_out1 + x) = acc01[x];
808 if(id.y() + 1 < out_height)
809 {
810 *(mtx_out1 + x + out_stride1) = acc11[x];
811 if(id.y() + 2 < out_height)
812 {
813 *(mtx_out1 + x + out_stride2) = acc21[x];
814 if(id.y() + 3 < out_height)
815 {
816 *(mtx_out1 + x + out_stride3) = acc31[x];
817 }
818 }
819 }
820 }
821 }
822 else
823 {
824 // Left-over columns
825 const int columns_left = out_width - id.x();
826 for(int x = 0; x < columns_left; ++x)
827 {
828 *(mtx_out0 + x) = acc00[x];
829 if(id.y() + 1 < out_height)
830 {
831 *(mtx_out0 + x + out_stride1) = acc10[x];
832 if(id.y() + 2 < out_height)
833 {
834 *(mtx_out0 + x + out_stride2) = acc20[x];
835 if(id.y() + 3 < out_height)
836 {
837 *(mtx_out0 + x + out_stride3) = acc30[x];
838 }
839 }
840 }
841 }
842 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100843 },
844 ina, inb, out);
845}
846
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** FP16 matrix-matrix multiplication: dst = alpha * (lhs * rhs).
 *
 * Operates on reshaped inputs: @p lhs is the 4x4-interleaved Matrix A and @p rhs
 * is the 1x8-transposed Matrix B (see the in-body diagram). Each window step
 * produces a 4 (rows) x 8 (columns) block of @p dst.
 *
 * @param[in]  lhs    Interleaved input Matrix A.
 * @param[in]  rhs    Transposed input Matrix B.
 * @param[out] dst    Output tensor.
 * @param[in]  window Execution window (expected to step 8 along X and 4 along Y of dst).
 * @param[in]  info   Thread info (unused).
 * @param[in]  alpha  Scalar weight of the matrix product.
 */
void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    ARM_COMPUTE_UNUSED(info);
    const int out_width  = static_cast<int>(dst->info()->dimension(0));
    const int out_height = static_cast<int>(dst->info()->dimension(1));
    // Row strides converted from bytes to number of fp16 elements.
    const size_t in_b_stride          = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
    const size_t out_stride           = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
    const int    num_elems_matrix_b_x = rhs->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the dst matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix B has 8 times less the cols of the dst matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 0));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, window);

    // Skip the final scaling pass entirely when alpha == 1.
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
        // Accumulators for the 4x8 output block; c.val[r] holds output row r.
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        /*
        This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
        |a00 a01 a02 a03 | a04 a05 a06 a07|
        |a10 a11 a12 a13 | a14 a15 a16 a17|
        |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
        |a30 a31 a32 a33 | a34 a35 a36 a37|   | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
        |a40 a41 a42 a43 | a44 a45 a46 a47|
        |a50 a51 a52 a53 | a54 a55 a56 a57|
        |a60 a61 a62 a63 | a64 a65 a66 a67|
        |a70 a71 a72 a73 | a74 a75 a76 a77|

        After this operation, the dst matrix will have the following shape: [ height * 4, width / 4 ]

        B Matrix has been transposed as shown below

        |b00 b01 b02 b03 b04 b05 b06 b07|
        |b10 b11 b12 b13 b14 b15 b16 b17|
        |b20 b21 b22 b23 b24 b25 b26 b27|
        |b30 b31 b32 b33 b34 b35 b36 b37|
        ------------------->

        |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|

        c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
        c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31

        The size of the dst tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
        */
        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

        // Main loop: 4 accumulation steps (k += 4) per iteration, consuming
        // 16 interleaved A values and 32 transposed B values each time.
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            // p00/p02: 16 interleaved A values (4 rows x 4 k-steps).
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            // q00..q06: 4 rows of 8 B columns.
            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            // k-step 0: broadcast one A lane per output row against B row q00.
            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            // k-step 1
            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            // k-step 2
            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            // k-step 3
            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }

        // Leftover loop: one accumulation step (k += 1) per iteration.
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }

        // Store the 4x8 block, guarding against the right (X) and bottom (Y) edges of dst.
        if(id.x() < (out_width - 8))
        {
            // Full-width store: all 8 columns fit.
            vst1q_f16(mtx_out, c.val[0]);
            if(id.y() + 1 < out_height)
            {
                vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
                    }
                }
            }
        }
        else
        {
            // Left-over columns: scalar stores, one lane at a time.
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out + x) = c.val[0][x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out + x + 1 * out_stride) = c.val[1][x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out + x + 2 * out_stride) = c.val[2][x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out + x + 3 * out_stride) = c.val[3][x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001025
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001026inline Status validate_arguments(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001027{
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001028 ARM_COMPUTE_UNUSED(alpha);
1029
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001030 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
1031 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F16, DataType::F32);
1032 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001033
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001034 if(!is_interleaved)
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001035 {
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001036 ARM_COMPUTE_RETURN_ERROR_ON(lhs->dimension(0) != rhs->dimension(1));
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001037
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001038 if(dst->total_size() != 0)
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001039 {
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001040 ARM_COMPUTE_RETURN_ERROR_ON(rhs->dimension(0) != dst->dimension(0));
1041 ARM_COMPUTE_RETURN_ERROR_ON(lhs->dimension(1) != dst->dimension(1));
1042 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001043 }
1044 }
1045 else
1046 {
1047 const int m = reshape_info.m();
1048 const int n = reshape_info.n();
1049 const int k = reshape_info.k();
1050 const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
1051 const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
1052
1053 /* Interleave */
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001054 TensorShape tensor_shape0{ lhs->tensor_shape() };
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001055 tensor_shape0.set(0, k);
1056 tensor_shape0.set(1, m);
1057
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001058 const TensorInfo tensor_info0 = lhs->clone()->set_tensor_shape(tensor_shape0);
1059 const TensorInfo tensor_info_reshaped0 = lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
1060 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lhs, &tensor_info_reshaped0);
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001061
1062 if(n != 0) /* Transpose */
1063 {
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001064 TensorShape tensor_shape1{ rhs->tensor_shape() };
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001065 tensor_shape1.set(0, n);
1066 tensor_shape1.set(1, k);
1067
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001068 const TensorInfo tensor_info1 = rhs->clone()->set_tensor_shape(tensor_shape1);
1069 const TensorInfo tensor_info_reshaped1 = rhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
1070 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(rhs, &tensor_info_reshaped1);
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001071 }
1072
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001073 if(dst->total_size() != 0)
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001074 {
1075 if(n != 0)
1076 {
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001077 ARM_COMPUTE_RETURN_ERROR_ON(dst->dimension(0) != static_cast<size_t>(n));
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001078 }
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001079 ARM_COMPUTE_RETURN_ERROR_ON(dst->dimension(1) != static_cast<size_t>(m));
1080 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001081 }
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001082 }
1083
1084 return Status{};
1085}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001086} // namespace
1087
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001088void CpuGemmMatrixMultiplyKernel::configure(const ITensorInfo *lhs, const ITensorInfo *rhs, ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001089{
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001090 ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001091
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001092 // dst tensor auto inizialitation if not yet initialized
1093 TensorShape tensor_shape{ lhs->tensor_shape() };
1094 tensor_shape.set(0, is_interleaved ? reshape_info.n() : rhs->dimension(0));
1095 tensor_shape.set(1, is_interleaved ? reshape_info.m() : lhs->dimension(1));
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001096
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001097 auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(tensor_shape));
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +00001098
1099 // Perform validate step
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001100 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(lhs, rhs, dst, alpha, is_interleaved, reshape_info));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001101
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001102 _alpha = alpha;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001103
Giorgio Arena7c23ad02017-11-30 15:08:38 +00001104 // Configure kernel window
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001105 Window win{};
1106
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001107 // Check if the dst tensor is a vector. If so,the kernel runs the vector-matrix multiplication
1108 const bool is_dst_vector = (dst->dimension(1) == 1);
1109 if(is_dst_vector)
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001110 {
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001111 const unsigned int num_elems_processed_per_iteration_x = (lhs->data_type() == DataType::F32) ? 16 : 32;
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001112
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001113 win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001114 }
1115 else
1116 {
1117 constexpr unsigned int num_elems_processed_per_iteration_x = 8;
1118 constexpr unsigned int num_elems_processed_per_iteration_y = 4;
1119
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001120 win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001121 }
1122
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001123 switch(lhs->data_type())
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001124 {
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001125 case DataType::F32:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001126 {
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001127 _func = (is_dst_vector) ? vector_matrix_multiply_f32 : matrix_matrix_multiply_f32;
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001128 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001129 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001130#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001131 case DataType::F16:
1132 {
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001133 _func = (is_dst_vector) ? vector_matrix_multiply_f16 : matrix_matrix_multiply_f16;
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001134 break;
1135 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001136#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001137 default:
1138 {
1139 ARM_COMPUTE_ERROR("Data type not supported");
1140 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001141 }
1142 }
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001143 ICPPKernel::configure(win);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001144}
Michele Di Giorgio53832b22021-06-21 14:45:44 +01001145
1146Status CpuGemmMatrixMultiplyKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved,
1147 const GEMMReshapeInfo &reshape_info)
1148{
1149 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(lhs, rhs, dst, alpha, is_interleaved, reshape_info));
1150
1151 return Status{};
1152}
1153
1154void CpuGemmMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
1155{
1156 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1157 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
1158 ARM_COMPUTE_ERROR_ON(tensors.empty());
1159 ARM_COMPUTE_ERROR_ON(_func == nullptr);
1160
1161 const ITensor *lhs = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1162 const ITensor *rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1163 ITensor *dst = tensors.get_tensor(TensorType::ACL_DST);
1164
1165 (*_func)(lhs, rhs, dst, window, info, _alpha);
1166}
1167
1168const char *CpuGemmMatrixMultiplyKernel::name() const
1169{
1170 return "CpuGemmMatrixMultiplyKernel";
1171}
1172} // namespace kernels
1173} // namespace cpu
Michele Di Giorgiocf9e29e2020-10-08 11:54:42 +01001174} // namespace arm_compute