/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>
#include <tuple>

using namespace arm_compute;

namespace arm_compute
{
namespace
{
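// Vector (1 x width_a) times matrix multiplication for unsigned 8-bit inputs.
// Each window iteration produces 16 int32 results: elements of the A vector are widened to
// uint16_t and broadcast with vmlal_lane_u16 against 16 widened values per row of matrix B,
// accumulating into four uint32x4_t registers that are finally stored as int32.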
void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Note: Since the inputs are all positive, we can accumulate into uint32_t
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations per iteration, consuming 8 elements of the A vector
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const uint8x8_t  a00_u8 = vld1_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
            const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
            const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
            const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
            const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
            const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
            const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
            const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);

            // Convert a00_u8 to uint16_t and split it into its lower and higher halves
            const uint16x4x2_t a00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(a00_u8)),
                    vget_high_u16(vmovl_u8(a00_u8))
                }
            };

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            const uint16x4x4_t b10_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
                }
            };

            const uint16x4x4_t b20_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
                }
            };

            const uint16x4x4_t b30_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
                }
            };

            const uint16x4x4_t b40_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
                }
            };

            const uint16x4x4_t b50_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
                }
            };

            const uint16x4x4_t b60_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
                }
            };

            const uint16x4x4_t b70_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
        vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
        vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
        vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
    },
    ina, inb, out);
}

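// Signed 8-bit counterpart of vector_matrix_multiply_u8: same traversal of A and B, but the
// values are widened to int16_t and accumulated into int32_t with vmlal_lane_s16.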
void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations per iteration, consuming 8 elements of the A vector
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const int8x8_t  a00_s8 = vld1_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
            const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
            const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
            const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
            const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
            const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
            const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
            const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);

            // Convert a00_s8 to int16_t and split it into its lower and higher halves
            const int16x4x2_t a00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(a00_s8)),
                    vget_high_s16(vmovl_s8(a00_s8))
                }
            };

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            const int16x4x4_t b10_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
                }
            };

            const int16x4x4_t b20_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
                }
            };

            const int16x4x4_t b30_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
                }
            };

            const int16x4x4_t b40_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
                }
            };

            const int16x4x4_t b50_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
                }
            };

            const int16x4x4_t b60_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
                }
            };

            const int16x4x4_t b70_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b);

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        vst1q_s32(vec_out + 0, c0.val[0]);
        vst1q_s32(vec_out + 4, c0.val[1]);
        vst1q_s32(vec_out + 8, c0.val[2]);
        vst1q_s32(vec_out + 12, c0.val[3]);
    },
    ina, inb, out);
}

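// Matrix multiplication for unsigned 8-bit inputs on reshaped operands. Each window iteration
// computes a 4x16 block of the int32 output: per step of k, 4 values of the interleaved matrix A
// are broadcast against 16 values of the transposed matrix B using widening uint16_t
// multiply-accumulates into uint32_t.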
void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates &)
    {
        const uint8_t *mtx_a0 = ina.ptr();
        const uint8_t *mtx_b0 = inb.ptr();

        // Note: Since the inputs are all positive, we can accumulate into uint32_t
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 1
        uint32x4x4_t c1 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 2
        uint32x4x4_t c2 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 3
        uint32x4x4_t c3 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
            const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Convert b00_u8 to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
            c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
            c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
            c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
            c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
            c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
            c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
            c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
            c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
            c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
        }

        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
        vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
        vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
        vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
        vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
        vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
        vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
        vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
        vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
        vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
        vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
        vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
    },
    ina, inb, out);
}

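// Signed 8-bit counterpart of matrix_multiply_u8: same 4x16 output block per iteration, widening
// to int16_t and accumulating into int32_t with vmlal_lane_s16.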
void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
{
    // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
    // Reshaping the matrices helps to have a cache-friendly implementation and avoids the data re-arrangements needed to compute 16x4 elements per iteration
    // All the values needed for computing a single 4x4 block are read from consecutive memory positions
    execute_window_loop(window, [&](const Coordinates &)
    {
        auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
        auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 1
        int32x4x4_t c1 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 2
        int32x4x4_t c2 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 3
        int32x4x4_t c3 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
            const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Convert b00_s8 to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
            c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
            c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
            c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
            c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
            c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
            c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
            c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
            c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
            c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
        }

        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
        vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
        vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
        vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
        vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
        vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
        vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
        vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
        vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
        vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
        vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
        vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
    },
    ina, inb, out);
}
} // namespace

class Coordinates;
} // namespace arm_compute

namespace
{
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);

    TensorShape in0_shape = input0->tensor_shape();
    TensorShape in1_shape = input1->tensor_shape();
    TensorShape out_shape = output->tensor_shape();

    // Check vector-by-matrix case
    if(out_shape[1] == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
    }
    else
    {
        in0_shape.collapse(2);
        in1_shape.collapse(2);
        out_shape.collapse(2);

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches of input0 tensor");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
{
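    // Each iteration processes 16 output columns; the matrix path additionally produces 4 output rows,
    // hence the 16x4 steps used below to size the window and the access patterns.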
    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;

    Window win;
    bool window_changed = false;

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((output->dimension(1) == 1))
    {
        // Configure kernel window
        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));

        // We cannot read out-of-bound elements from matrix A as we use the left-over for loop
        AccessWindowStatic     in0_access(input0, 0, 0, input0->tensor_shape().x(), 1);
        AccessWindowHorizontal in1_access(input1, 0, num_elems_processed_per_iteration_x);
        AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);

        window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);

        Coordinates coord;
        coord.set_num_dimensions(output->num_dimensions());
        output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
    }
    else
    {
        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));

        unsigned int num_k_iterations = ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x) / 16;
        // For each iteration of "k" we increment the input pointer by 4, and we load 8 elements at a time:
        // the rightmost element of matrix A accessed is therefore at offset (num_k_iterations - 1) * 4 + 8
        AccessWindowStatic    in0_access(input0, 0, 0, (num_k_iterations - 1) * 4 + 8, input0->dimension(1));
        AccessWindowStatic    in1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
        AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);

        window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);

        output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
    }

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
{
}

void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));

    TensorShape in1_shape = input1->info()->tensor_shape();
    in1_shape.collapse(2);

    _input0 = input0;
    _input1 = input1;
    _output = output;
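    // Matrix B is not sliced along the batch dimension when it has a single batch, e.g. when the
    // same reshaped weights are shared across all batches of matrix A (as in convolution).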
    _slide_matrix_b = in1_shape[2] != 1;

    // Configure kernel window
    auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}

Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);

    return Status{};
}

void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((_output->info()->dimension(1) == 1))
    {
        const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
        const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
        const auto in_b_stride    = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));

        // The implementation computes 16 elements per iteration
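        // Each thread starts at a different multiple of 16 columns and advances by 16 * num_threads,
        // so the threads cover disjoint column blocks of matrix B.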
        const int window_start_x = 16 * info.thread_id;
        const int window_step_x  = 16 * info.num_threads;
        // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
        const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

        Window win_out(window);
        win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_input1->info()->num_dimensions() >= 3)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

        Iterator ina(_input0, win_a);
        Iterator inb(_input1, win_b);
        Iterator out(_output, win_out);

        switch(_input0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
    else
    {
        const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
        const size_t out_stride  = _output->info()->strides_in_bytes()[1] / _output->info()->element_size();

        // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the input interleaved matrix A has 4 times fewer rows than the output matrix
        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));

        // Set step_x and step_y for matrix B. Scale the X range by a factor of 16 as the input transposed matrix B has 16 times fewer columns than the output matrix
        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_slide_matrix_b)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
        win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

        // The step x and step y for the output matrix have already been set in configure()
        Iterator ina(_input0, win_a);
        Iterator inb(_input1, win_b);
        Iterator out(_output, window);

        const int width_b = _input1->info()->dimension(0);
        switch(_input0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                matrix_multiply_s8(ina, inb, out, width_b, out_stride, window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                matrix_multiply_u8(ina, inb, out, width_b, out_stride, window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
}