/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace
{
inline void vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
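        // The x range set up in run() is rounded up to a multiple of the per-thread
        // step, so a window can start past the end of matrix B; skip such iterations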
        if(id.x() > width_b)
        {
            return;
        }

        // Note: Since the inputs are all positive, we can use uint32_t accumulators
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This loop performs 8 accumulations per iteration
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
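            // Load 8 consecutive elements of the vector A and a 16-wide slice of
            // each of the 8 corresponding rows of matrix B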
            const uint8x8_t  a00_u8 = vld1_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
            const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
            const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
            const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
            const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
            const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
            const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
            const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);

            // Convert a00_u8 to uint16_t and split into lower and higher parts
            const uint16x4x2_t a00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(a00_u8)),
                    vget_high_u16(vmovl_u8(a00_u8))
                }
            };

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            const uint16x4x4_t b10_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
                }
            };

            const uint16x4x4_t b20_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
                }
            };

            const uint16x4x4_t b30_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
                }
            };

            const uint16x4x4_t b40_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
                }
            };

            const uint16x4x4_t b50_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
                }
            };

            const uint16x4x4_t b60_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
                }
            };

            const uint16x4x4_t b70_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
                }
            };

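            // Each vmlal_lane_u16 below multiplies a 16-wide row slice of B (held in
            // four 4-lane registers) by one element of A, selected by lane, and adds
            // the widened products to the 16 u32 accumulators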
            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This loop performs the left-over accumulations, one element of A per iteration
        for(; vec_a < vec_a_end_addr;)
        {
            const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
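        // Vector stores for a full 16-wide block; near the right edge of the
        // output the remaining elements are written with scalar stores instead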
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

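// Signed counterpart of vector_matrix_multiply_u8: identical structure, with
// s8/s16 widening into int32x4 accumulators via vmlal_lane_s16.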
inline void vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This loop performs 8 accumulations per iteration
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const int8x8_t  a00_s8 = vld1_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
            const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
            const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
            const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
            const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
            const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
            const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
            const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);

            // Convert a00_s8 to int16_t and split into lower and higher parts
            const int16x4x2_t a00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(a00_s8)),
                    vget_high_s16(vmovl_s8(a00_s8))
                }
            };

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            const int16x4x4_t b10_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
                }
            };

            const int16x4x4_t b20_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
                }
            };

            const int16x4x4_t b30_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
                }
            };

            const int16x4x4_t b40_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
                }
            };

            const int16x4x4_t b50_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
                }
            };

            const int16x4x4_t b60_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
                }
            };

            const int16x4x4_t b70_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This loop performs the left-over accumulations, one element of A per iteration
        for(; vec_a < vec_a_end_addr;)
        {
            const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b);

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, c0.val[0]);
            vst1q_s32(vec_out + 4, c0.val[1]);
            vst1q_s32(vec_out + 8, c0.val[2]);
            vst1q_s32(vec_out + 12, c0.val[3]);
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

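// Computes a 4x16 block of the output per window iteration. Matrix A is expected
// interleaved by NEGEMMInterleave4x4 (4 values per step, one from each of 4 rows)
// and matrix B transposed by NEGEMMTranspose1xW (16 consecutive values per step),
// so both operands are read from contiguous memory (see matrix_multiply_s8 below).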
inline void matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8_t *mtx_a0 = ina.ptr();
        const uint8_t *mtx_b0 = inb.ptr();

        // Note: Since the inputs are all positive, we can use uint32_t accumulators
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 1
        uint32x4x4_t c1 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 2
        uint32x4x4_t c2 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 3
        uint32x4x4_t c3 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

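        // Each iteration consumes 4 interleaved values of A (one per output row of
        // the block) and 16 transposed values of B (one per output column of the
        // block), i.e. one step along the accumulation dimension for the 4x16 block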
        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
            const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Convert b00_u8 to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
            c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
            c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
            c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
            c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
            c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
            c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
            c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
            c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
            c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
        }

        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());

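        // Write back the 4x16 block, guarding both the bottom rows (height_out) and
        // the right-hand columns (width_out); partial blocks use scalar stores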
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                    }
                }
            }
        }
        else
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

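// Signed counterpart of matrix_multiply_u8: same 4x16 blocking over interleaved A
// and transposed B, accumulating into int32x4 registers.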
inline void matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
    // The reshaping of the matrices gives a cache-friendly implementation and avoids the data re-arrangements needed for computing 16x4 elements per iteration
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
        auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 1
        int32x4x4_t c1 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 2
        int32x4x4_t c2 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 3
        int32x4x4_t c3 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
            const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Convert b00_s8 to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
            c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
            c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
            c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
            c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
            c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
            c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
            c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
            c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
            c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
        }
        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                    }
                }
            }
        }
        else if(id.y() < height_out)
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }

    },
    ina, inb, out);
}
} // namespace

namespace
{
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);

    TensorShape in0_shape = input0->tensor_shape();
    TensorShape in1_shape = input1->tensor_shape();
    TensorShape out_shape = output->tensor_shape();

    // Check vector-by-matrix case
    if(out_shape[1] == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
    }
    else
    {
        in0_shape.collapse(2);
        in1_shape.collapse(2);
        out_shape.collapse(2);

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches as the input0 tensor");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches as input0 or a number of batches equal to 1");
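        // The matrix path consumes the transposed input1 16 values at a time (see the k loop in matrix_multiply_u8/_s8), hence the width requirement below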
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
    }

    return Status{};
}
} // namespace

NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
{
}

void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));

    TensorShape in1_shape = input1->info()->tensor_shape();
    in1_shape.collapse(2);

    _input0         = input0;
    _input1         = input1;
    _output         = output;
    _slide_matrix_b = in1_shape[2] != 1;

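    // One window step covers 16 output columns; the matrix path also covers 4 output rows per step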
    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;

    Window win;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if((output->info()->dimension(1) == 1))
    {
        // Configure kernel window
        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
    }
    else
    {
        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
    }

    INEKernel::configure(win);
}

Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));

    return Status{};
}

void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((_output->info()->dimension(1) == 1))
    {
        const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
        const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
        const auto width_out      = static_cast<int>(_output->info()->dimension(0));
        const auto in_b_stride    = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));

        // The implementation computes 16 elements per iteration
        const int window_start_x = 16 * info.thread_id;
        const int window_step_x  = 16 * info.num_threads;
        // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
        const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

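        // The x range is interleaved across threads: thread t starts at column 16 * t
        // and advances by 16 * num_threads, so the per-thread windows tile the width of B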
        Window win_out(window);
        win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_input1->info()->num_dimensions() >= 3)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

        Iterator ina(_input0, win_a);
        Iterator inb(_input1, win_b);
        Iterator out(_output, win_out);

        switch(_input0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
    else
    {
        const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
        const int    width_b     = _input1->info()->dimension(0);

        // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));

        // Set step_x and step_y for matrix B. Scale the X range by a factor of 16 as the transposed input matrix B has 16 times fewer columns than the output matrix
        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_slide_matrix_b)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
        win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

        // The X and Y steps for the output matrix have already been set in configure()
        Iterator ina(_input0, win_a);
        Iterator inb(_input1, win_b);
        Iterator out(_output, window);

        switch(_input0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                matrix_multiply_s8(ina, inb, out, width_b, *_output->info(), window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                matrix_multiply_u8(ina, inb, out, width_b, *_output->info(), window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
}
} // namespace arm_compute