/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace cpu
{
namespace kernels
{
namespace
{
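// This anonymous namespace implements two NEON paths: a vector-by-matrix path
// (vector_matrix_multiply_u8/_s8) used when the output has a single row, and a
// matrix-by-matrix path (matrix_multiply_u8/_s8) that works on reshaped inputs.
// Both widen the 8-bit operands to 16 bits with vmovl and accumulate the
// 16-bit x 16-bit products into 32-bit lanes with vmlal_lane.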
void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Note: Since the inputs are all positive, we can use uint32_t
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations
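        // Each iteration consumes 8 values of A and 8 rows of B, 16 columns at a time.
        // Scalar sketch of one iteration (illustrative only):
        //   for(int i = 0; i < 8; ++i)
        //       for(int j = 0; j < 16; ++j)
        //           acc[j] += uint32_t(vec_a[i]) * uint32_t(matrix_b[i * stride_b + j]);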
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const uint8x8_t  a00_u8 = vld1_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
            const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
            const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
            const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
            const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
            const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
            const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
            const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);

            // Widen a00_u8 to uint16_t and split it into its lower and upper halves
            const uint16x4x2_t a00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(a00_u8)),
                    vget_high_u16(vmovl_u8(a00_u8))
                }
            };

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            const uint16x4x4_t b10_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
                }
            };

            const uint16x4x4_t b20_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
                }
            };

            const uint16x4x4_t b30_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
                }
            };

            const uint16x4x4_t b40_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
                }
            };

            const uint16x4x4_t b50_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
                }
            };

            const uint16x4x4_t b60_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
                }
            };

            const uint16x4x4_t b70_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations
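        // One value of A at a time: vld1_dup_u8 broadcasts the single A value to all
        // lanes, so the same vmlal_lane pattern as above can be reused for one row of B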
        for(; vec_a < vec_a_end_addr;)
        {
            const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
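        // The vector path stores 16 values at once, so it is taken only while more than
        // 16 output elements remain; the tail is written with scalar stores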
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const int8x8_t  a00_s8 = vld1_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
            const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
            const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
            const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
            const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
            const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
            const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
            const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);

            // Widen a00_s8 to int16_t and split it into its lower and upper halves
            const int16x4x2_t a00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(a00_s8)),
                    vget_high_s16(vmovl_s8(a00_s8))
                }
            };

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            const int16x4x4_t b10_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
                }
            };

            const int16x4x4_t b20_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
                }
            };

            const int16x4x4_t b30_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
                }
            };

            const int16x4x4_t b40_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
                }
            };

            const int16x4x4_t b50_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
                }
            };

            const int16x4x4_t b60_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
                }
            };

            const int16x4x4_t b70_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b);

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, c0.val[0]);
            vst1q_s32(vec_out + 4, c0.val[1]);
            vst1q_s32(vec_out + 8, c0.val[2]);
            vst1q_s32(vec_out + 12, c0.val[3]);
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
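    // As in matrix_multiply_s8 below, matrix A is expected to have been reshaped with
    // CpuGemmInterleave4x4 and matrix B with CpuGemmTranspose1xW, so all the values
    // needed for a 4x16 output block are read from consecutive memory positions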
    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8_t *mtx_a0 = ina.ptr();
        const uint8_t *mtx_b0 = inb.ptr();

        // Note: Since the inputs are all positive, we can use uint32_t
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 1
        uint32x4x4_t c1 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 2
        uint32x4x4_t c2 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 3
        uint32x4x4_t c3 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
            const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Convert b00_u8 to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
            c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
            c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
            c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
            c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
            c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
            c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
            c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
            c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
            c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
        }

        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());

        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                    }
                }
            }
        }
        else
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    // The implementation assumes that matrix A and matrix B have been reshaped with CpuGemmInterleave4x4 and CpuGemmTranspose1xW respectively
    // The reshaping makes the implementation cache friendly and avoids the data re-arrangements needed for computing 16x4 elements per iteration
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
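    // Scalar sketch of the computation (illustrative only): each step of the k loop
    // reads 4 interleaved values of A and 16 transposed values of B and updates a
    // 4x16 block of S32 accumulators:
    //   for(int k = 0; k < width_b; k += 16, a += 4, b += 16)
    //       for(int r = 0; r < 4; ++r)
    //           for(int c = 0; c < 16; ++c)
    //               acc[r][c] += int32_t(a[r]) * int32_t(b[c]);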
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
        auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 1
        int32x4x4_t c1 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 2
        int32x4x4_t c2 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 3
        int32x4x4_t c3 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
            const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Convert b00_s8 to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
            c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
            c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
            c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
            c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
            c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
            c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
            c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
            c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
            c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
        }
        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                    }
                }
            }
        }
        else if(id.y() < height_out)
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);

    TensorShape in0_shape = src0->tensor_shape();
    TensorShape in1_shape = src1->tensor_shape();
    TensorShape out_shape = dst->tensor_shape();

    // Check vector-by-matrix case
    if(out_shape[1] == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
    }
    else
    {
        in0_shape.collapse(2);
        in1_shape.collapse(2);
        out_shape.collapse(2);

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches as input0 tensor");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches as input0 or the number of batches must be set to 1");
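        // For 8-bit data CpuGemmTranspose1xW packs blocks of 16 elements per row,
        // hence the width requirement below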
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
    }

    return Status{};
}
} // namespace

void CpuGemmLowpMatrixMultiplyKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
    ARM_COMPUTE_UNUSED(src0);
    ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, dst));

    TensorShape in1_shape = src1->tensor_shape();
    in1_shape.collapse(2);

    _slide_matrix_b = in1_shape[2] != 1;

    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;

    Window win;
    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if((dst->dimension(1) == 1))
    {
        // Configure kernel window
        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
    }
    else
    {
        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
    }
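    // With these steps each window iteration covers a 16x4 block of dst in the
    // matrix case and a 16x1 block in the vector case, matching the kernels above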

    ICpuKernel::configure(win);
}

Status CpuGemmLowpMatrixMultiplyKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, dst));
    return Status{};
}
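
// Hypothetical usage sketch (not part of this file): a caller is expected to check
// validate() on the tensor infos before configuring the kernel, e.g.:
//   ARM_COMPUTE_ERROR_THROW_ON(kernels::CpuGemmLowpMatrixMultiplyKernel::validate(a, b, dst));
//   kernel.configure(a, b, dst);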

void CpuGemmLowpMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);

    auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto dst  = tensors.get_tensor(TensorType::ACL_DST);

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((dst->info()->dimension(1) == 1))
    {
        const auto width_matrix_a = static_cast<int>(src0->info()->dimension(0));
        const auto width_matrix_b = static_cast<int>(src1->info()->dimension(0));
        const auto width_out      = static_cast<int>(dst->info()->dimension(0));
        const auto in_b_stride    = static_cast<int>(src1->info()->strides_in_bytes()[1] / data_size_from_type(src1->info()->data_type()));

        // The implementation computes 16 elements per iteration
        const int window_start_x = 16 * info.thread_id;
        const int window_step_x  = 16 * info.num_threads;
        // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
        const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
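        // Threads interleave along X: thread t starts at column 16 * t and advances by
        // 16 * num_threads, distributing the 16-wide output chunks round-robin across threads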

        Window win_out(window);
        win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(src1->info()->num_dimensions() >= 3)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

        Iterator ina(src0, win_a);
        Iterator inb(src1, win_b);
        Iterator out(dst, win_out);

        switch(src0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
    else
    {
        const size_t in_b_stride = src1->info()->strides_in_bytes()[1];
        const int    width_b     = src1->info()->dimension(0);

        // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));

        // Set step_x and step_y for matrix B. Scale the X range by a factor of 16 as the transposed input matrix B has 16 times fewer columns than the output matrix
        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_slide_matrix_b)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
        win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

        // The step x and step y for the output matrix have already been set in configure()
        Iterator ina(src0, win_a);
        Iterator inb(src1, win_b);
        Iterator out(dst, window);

        switch(src0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                matrix_multiply_s8(ina, inb, out, width_b, *dst->info(), window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                matrix_multiply_u8(ina, inb, out, width_b, *dst->info(), window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
}

const char *CpuGemmLowpMatrixMultiplyKernel::name() const
{
    return "CpuGemmLowpMatrixMultiplyKernel";
}
} // namespace kernels
} // namespace cpu
} // namespace arm_compute