/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/AccessWindowStatic.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace
{
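// Vector (1 x width_a) by matrix (width_a x width_b) multiplication with 32-bit
// accumulation, for unsigned 8-bit inputs. Scalar reference semantics of the
// micro-kernel below, as an illustrative sketch (not the executed path):
//
//   for each window position x (in steps of 16 output columns):
//       for(int k = 0; k < width_a; ++k)
//           for(int j = 0; j < 16; ++j)
//               out[x + j] += int32_t(a[k]) * int32_t(b[k * stride_b + x + j]);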
inline void vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        // Window x positions beyond the width of matrix B have nothing to compute
        if(id.x() > width_b)
        {
            return;
        }

        // Note: Since the inputs are all non-negative, we can use uint32_t accumulators
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations per iteration
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const uint8x8_t  a00_u8 = vld1_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
            const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
            const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
            const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
            const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
            const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
            const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
            const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);

            // Convert a00_u8 to uint16_t and split into the lower and higher halves
            const uint16x4x2_t a00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(a00_u8)),
                    vget_high_u16(vmovl_u8(a00_u8))
                }
            };

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            const uint16x4x4_t b10_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
                }
            };

            const uint16x4x4_t b20_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
                }
            };

            const uint16x4x4_t b30_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
                }
            };

            const uint16x4x4_t b40_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
                }
            };

            const uint16x4x4_t b50_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
                }
            };

            const uint16x4x4_t b60_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
                }
            };

            const uint16x4x4_t b70_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
        }
        else
        {
            // Store the results one scalar at a time when fewer than 16 output elements remain
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

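// Signed 8-bit variant of the vector-by-matrix kernel above: same structure and
// reference semantics, but widening through int16_t into int32_t accumulators.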
inline void vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        // Window x positions beyond the width of matrix B have nothing to compute
        if(id.x() > width_b)
        {
            return;
        }

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations per iteration
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const int8x8_t  a00_s8 = vld1_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
            const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
            const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
            const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
            const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
            const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
            const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
            const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);

            // Convert a00_s8 to int16_t and split into the lower and higher halves
            const int16x4x2_t a00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(a00_s8)),
                    vget_high_s16(vmovl_s8(a00_s8))
                }
            };

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            const int16x4x4_t b10_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
                }
            };

            const int16x4x4_t b20_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
                }
            };

            const int16x4x4_t b30_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
                }
            };

            const int16x4x4_t b40_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
                }
            };

            const int16x4x4_t b50_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
                }
            };

            const int16x4x4_t b60_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
                }
            };

            const int16x4x4_t b70_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b);

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, c0.val[0]);
            vst1q_s32(vec_out + 4, c0.val[1]);
            vst1q_s32(vec_out + 8, c0.val[2]);
            vst1q_s32(vec_out + 12, c0.val[3]);
        }
        else
        {
            // Store the results one scalar at a time when fewer than 16 output elements remain
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

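// Matrix-by-matrix multiplication with 32-bit accumulation for unsigned 8-bit
// inputs. As with the signed variant further below, matrix A is assumed to be
// interleaved with NEGEMMInterleave4x4 and matrix B transposed with
// NEGEMMTranspose1xW, so each iteration of the k loop reads 4 values of A and
// 16 values of B from consecutive memory and updates a 4x16 output block.
// Per-iteration update, as an illustrative sketch:
//
//   c[r][j] += int32_t(a[r]) * int32_t(b[j])   for r in [0, 4), j in [0, 16)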
inline void matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8_t *mtx_a0 = ina.ptr();
        const uint8_t *mtx_b0 = inb.ptr();

        // Note: Since the inputs are all non-negative, we can use uint32_t accumulators
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 1
        uint32x4x4_t c1 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 2
        uint32x4x4_t c2 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 3
        uint32x4x4_t c3 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
            const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Convert b00_u8 to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
            c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
            c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
            c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
            c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
            c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
            c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
            c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
            c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
            c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
        }

        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());

        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                    }
                }
            }
        }
        else if(id.y() < height_out)
        {
            // Store the partial blocks one scalar at a time when fewer than 16 output columns remain
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

inline void matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    // The implementation assumes that matrix A and matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
    // The reshaping of the matrices helps to have a cache-friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
        auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 1
        int32x4x4_t c1 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 2
        int32x4x4_t c2 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 3
        int32x4x4_t c3 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
            const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Convert b00_s8 to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
            c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
            c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
            c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
            c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
            c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
            c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
            c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
            c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
            c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
        }
        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                    }
                }
            }
        }
        else if(id.y() < height_out)
        {
            // Store the partial blocks one scalar at a time when fewer than 16 output columns remain
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
} // namespace

namespace
{
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);

    TensorShape in0_shape = input0->tensor_shape();
    TensorShape in1_shape = input1->tensor_shape();
    TensorShape out_shape = output->tensor_shape();

    // Check vector-by-matrix case
    if(out_shape[1] == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
    }
    else
    {
        in0_shape.collapse(2);
        in1_shape.collapse(2);
        out_shape.collapse(2);

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches as the input0 tensor");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches as input0 or the number of batches must be set to 1");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
    }

    return Status{};
}
} // namespace

NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
{
}

void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));

    TensorShape in1_shape = input1->info()->tensor_shape();
    in1_shape.collapse(2);

    _input0         = input0;
    _input1         = input1;
    _output         = output;
    _slide_matrix_b = in1_shape[2] != 1;

    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;

    Window win;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if((output->info()->dimension(1) == 1))
    {
        // Configure kernel window
        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));

        Coordinates coord;
        coord.set_num_dimensions(output->info()->num_dimensions());
        output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
    }
    else
    {
        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
        output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
    }

    INEKernel::configure(win);
}
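
// A minimal usage sketch (illustrative names; assumes matrix A has been
// reshaped with NEGEMMInterleave4x4, matrix B with NEGEMMTranspose1xW, and
// an S32 output tensor has been allocated):
//
//   NEGEMMLowpMatrixMultiplyKernel mm_kernel;
//   mm_kernel.configure(&a_interleaved, &b_transposed, &output_s32);
//   NEScheduler::get().schedule(&mm_kernel, Window::DimY);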

Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));

    return Status{};
}

void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((_output->info()->dimension(1) == 1))
    {
        const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
        const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
        const auto width_out      = static_cast<int>(_output->info()->dimension(0));
        const auto in_b_stride    = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));

        // The implementation computes 16 elements per iteration
        const int window_start_x = 16 * info.thread_id;
        const int window_step_x  = 16 * info.num_threads;
        // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
        const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
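
        // e.g. with num_threads = 4, thread 1 processes the 16-wide column blocks
        // starting at x = 16, 80, 144, ... so the threads cover interleaved,
        // non-overlapping column blocks of matrix B (an illustrative trace)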

        Window win_out(window);
        win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_input1->info()->num_dimensions() >= 3)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

        Iterator ina(_input0, win_a);
        Iterator inb(_input1, win_b);
        Iterator out(_output, win_out);

        switch(_input0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
    else
    {
        const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
        const int    width_b     = _input1->info()->dimension(0);

        // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));
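        // e.g. an output window covering rows [0, 8) maps to rows [0, 2) of the
        // interleaved matrix A, since each interleaved row packs 4 original rows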

        // Set step_x and step_y for matrix B. Scale the X range by a factor of 16 as the transposed input matrix B has 16 times fewer columns than the output matrix
        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_slide_matrix_b)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
        win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

        // The step x and step y for the output matrix have already been set in configure()
        Iterator ina(_input0, win_a);
        Iterator inb(_input1, win_b);
        Iterator out(_output, window);

        switch(_input0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                matrix_multiply_s8(ina, inb, out, width_b, *_output->info(), window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                matrix_multiply_u8(ina, inb, out, width_b, *_output->info(), window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
}
} // namespace arm_compute