/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace
{
inline void vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
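    // Note: illustrative sketch (editorial, not part of the computation) of what each window
    // iteration below does, in scalar terms, assuming matrix B rows are stride_b bytes apart:
    //
    //     for(int x = 0; x < 16; ++x)
    //     {
    //         uint32_t acc = 0;
    //         for(int i = 0; i < width_a; ++i)
    //         {
    //             acc += uint32_t(a[i]) * uint32_t(b[i][id.x() + x]);
    //         }
    //         out[id.x() + x] = acc;
    //     }
    //
    // The four uint32x4_t lanes of the accumulator c0 hold these 16 partial sums.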
    execute_window_loop(window, [&](const Coordinates & id)
    {
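        // Note: run() rounds the window end up to a multiple of the 16-element step, so a
        // thread's last iteration can start past the end of matrix B; bail out in that case.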
        if(id.x() > width_b)
        {
            return;
        }

        // Note: Since the inputs are all non-negative, we can accumulate in uint32_t
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This loop performs 8 accumulations per iteration
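        // Note: each unrolled step k below computes, for output lane x in [0, 16):
        //     c0[x] += uint32_t(vec_a[k]) * uint32_t(matrix_b[k * stride_b + x])
        // The u8 operands are widened to u16 with vmovl_u8, then vmlal_lane_u16 does the
        // widening multiply-accumulate into the u32 accumulators, one lane of vec_a at a time.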
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const uint8x8_t  a00_u8 = vld1_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
            const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
            const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
            const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
            const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
            const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
            const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
            const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);

            // Convert a00_u8 to uint16_t and split it into its lower and upper halves
            const uint16x4x2_t a00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(a00_u8)),
                    vget_high_u16(vmovl_u8(a00_u8))
                }
            };

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            const uint16x4x4_t b10_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
                }
            };

            const uint16x4x4_t b20_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
                }
            };

            const uint16x4x4_t b30_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
                }
            };

            const uint16x4x4_t b40_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
                }
            };

            const uint16x4x4_t b50_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
                }
            };

            const uint16x4x4_t b60_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
                }
            };

            const uint16x4x4_t b70_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

inline void vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
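    // Note: this mirrors vector_matrix_multiply_u8 above, with the signed s8/s16 intrinsics
    // and int32_t accumulators in place of the unsigned ones.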
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This loop performs 8 accumulations per iteration
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const int8x8_t  a00_s8 = vld1_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
            const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
            const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
            const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
            const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
            const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
            const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
            const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);

            // Convert a00_s8 to int16_t and split it into its lower and upper halves
            const int16x4x2_t a00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(a00_s8)),
                    vget_high_s16(vmovl_s8(a00_s8))
                }
            };

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            const int16x4x4_t b10_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
                }
            };

            const int16x4x4_t b20_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
                }
            };

            const int16x4x4_t b30_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
                }
            };

            const int16x4x4_t b40_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
                }
            };

            const int16x4x4_t b50_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
                }
            };

            const int16x4x4_t b60_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
                }
            };

            const int16x4x4_t b70_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b);

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, c0.val[0]);
            vst1q_s32(vec_out + 4, c0.val[1]);
            vst1q_s32(vec_out + 8, c0.val[2]);
            vst1q_s32(vec_out + 12, c0.val[3]);
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

inline void matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
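    // Note: illustrative sketch (assuming the reshaped operand layouts described in
    // matrix_multiply_s8 below): each window iteration produces a 4x16 output block,
    //     c<r>[x] = sum over k of A[id.y() + r][k] * B[k][id.x() + x]   for r in [0, 4), x in [0, 16)
    // which is why both operands can be read sequentially in the loop below.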
    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8_t *mtx_a0 = ina.ptr();
        const uint8_t *mtx_b0 = inb.ptr();

        // Note: Since the inputs are all non-negative, we can accumulate in uint32_t
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 1
        uint32x4x4_t c1 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 2
        uint32x4x4_t c2 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 3
        uint32x4x4_t c3 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
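            // Note: each step consumes 4 interleaved values of A (one lane per output row) and
            // one 1x16 chunk of reshaped B, hence mtx_a0 += 4 and mtx_b0 += 16 per iteration.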
            const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
            const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Convert b00_u8 to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
            c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
            c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
            c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
            c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
            c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
            c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
            c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
            c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
            c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
        }

        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());

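        // Store the full 4x16 block only when it fits inside the output bounds; otherwise fall
        // through to the scalar tail below, which writes the left-over columns one by one.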
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                    }
                }
            }
        }
        else
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

inline void matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW, respectively
    // The reshaping makes the implementation cache-friendly and avoids the data re-arrangements needed to compute 16x4 elements per iteration
    // All the values needed to compute a single 4x4 block are read from consecutive memory positions
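    // Note: illustrative restatement of the inner loop: the same 1x16 chunk of reshaped B is
    // scaled by each of the four interleaved A values and accumulated into one output row each,
    //     c<r>[x] += a00[r] * b00[x]   for r in [0, 4), x in [0, 16)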
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
        auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

        // Note: The inputs here are signed, so the accumulators are int32_t
        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 1
        int32x4x4_t c1 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 2
        int32x4x4_t c2 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 3
        int32x4x4_t c3 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
            const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Convert b00_s8 to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
            c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
            c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
            c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
            c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
            c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
            c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
            c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
            c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
            c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
        }
        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                    }
                }
            }
        }
        else if(id.y() < height_out)
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
} // namespace

namespace
{
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);

    TensorShape in0_shape = input0->tensor_shape();
    TensorShape in1_shape = input1->tensor_shape();
    TensorShape out_shape = output->tensor_shape();

    // Check vector-by-matrix case
    if(out_shape[1] == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
    }
    else
    {
        in0_shape.collapse(2);
        in1_shape.collapse(2);
        out_shape.collapse(2);

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches as the input0 tensor");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches as input0 or the number of batches must be set to 1");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
    }

    return Status{};
}
} // namespace

NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
{
}

void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));

    TensorShape in1_shape = input1->info()->tensor_shape();
    in1_shape.collapse(2);

    _input0         = input0;
    _input1         = input1;
    _output         = output;
    _slide_matrix_b = in1_shape[2] != 1;

    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;
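    // Note: a kernel iteration produces a 4x16 output block (16x1 in the vector path), so the
    // window advances in steps of 16 along X and 4 along Y.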

    Window win;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((output->info()->dimension(1) == 1))
    {
        // Configure kernel window
        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));

        Coordinates coord;
        coord.set_num_dimensions(output->info()->num_dimensions());
        output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
    }
    else
    {
        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
        output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
    }

    INEKernel::configure(win);
}

Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));

    return Status{};
}

void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((_output->info()->dimension(1) == 1))
    {
        const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
        const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
        const auto width_out      = static_cast<int>(_output->info()->dimension(0));
        const auto in_b_stride    = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));

        // The implementation computes 16 elements per iteration
        const int window_start_x = 16 * info.thread_id;
        const int window_step_x  = 16 * info.num_threads;
        // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
        const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
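        // Note: the output row is partitioned among threads by striding 16-element chunks:
        // thread t handles x = 16 * t, 16 * (t + num_threads), 16 * (t + 2 * num_threads), ...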

        Window win_out(window);
        win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A has more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_input1->info()->num_dimensions() >= 3)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

        Iterator ina(_input0, win_a);
        Iterator inb(_input1, win_b);
        Iterator out(_output, win_out);

        switch(_input0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
    else
    {
        const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
        const int    width_b     = _input1->info()->dimension(0);

        // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));

        // Set step_x and step_y for matrix B. Scale the X range by a factor of 16, as the transposed input matrix B has 16 times fewer columns than the output matrix
        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A has more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_slide_matrix_b)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
        win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

        // The step x and step y for the output matrix have already been set in configure()
        Iterator ina(_input0, win_a);
        Iterator inb(_input1, win_b);
        Iterator out(_output, window);

        switch(_input0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                matrix_multiply_s8(ina, inb, out, width_b, *_output->info(), window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                matrix_multiply_u8(ina, inb, out, width_b, *_output->info(), window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
}
} // namespace arm_compute