/*
 * Copyright (c) 2019 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"

#include <arm_neon.h>

#include <algorithm>
#include <cstring>

namespace arm_gemm {

namespace {

/* Requantize a block of data, using the requantize parameters in 'qp'.
 *
 * row_bias and col_bias are assumed to be precomputed values which include
 * any externally supplied bias, plus the row/column contribution sums, plus
 * the overall constant offset (A_offset * B_offset * depth).
 *
 * Note that this function works equally well for uint8_t output: just set
 * minval/maxval appropriately and cast the output pointer.  It is the
 * caller's responsibility to ensure that minval/maxval are representable in
 * the target type - the downcast to (u)int8_t is done by simply extracting
 * the low byte.
 *
 * The 'do_shift_correction' template parameter turns on the correction
 * applied to negative values being shifted right to make sure they round
 * properly - if negative values are never output (e.g. fused ReLU) this is
 * unnecessary.
 */
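/* For reference, a scalar sketch of the per-element arithmetic performed by
 * the NEON code below (illustrative only: 'round_doubling_high_mul' and
 * 'rounding_shift_right' are stand-ins for the VQRDMULH and VRSHL behaviour,
 * not functions defined in this file, and qp.requant_shift is assumed to be
 * stored as a non-positive value, as the use of vrshlq below implies):
 *
 *   int32_t v = input[i] + row_bias[row] + col_bias[col];
 *   v = round_doubling_high_mul(v, qp.requant_mul);   // (2*v*mul + (1<<31)) >> 32, saturated
 *   if (do_shift_correction && v < 0) v--;            // see the correction blocks below
 *   v = rounding_shift_right(v, -qp.requant_shift);   // rounding right shift
 *   v += qp.c_offset;
 *   v = std::max(qp.minval, std::min(qp.maxval, v));
 *   output[i] = static_cast<int8_t>(v);               // keep the low byte
 */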
template<bool do_shift_correction>
void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                             const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                             const int32_t *row_bias, const int32_t *col_bias) {
    const int32x4_t v_mul      = vdupq_n_s32(qp.requant_mul);
    const int32x4_t v_shift    = vdupq_n_s32(qp.requant_shift);
    const int32x4_t v_minval   = vdupq_n_s32(qp.minval);
    const int32x4_t v_maxval   = vdupq_n_s32(qp.maxval);
    const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);

    /* To make sure we have plenty of accumulators, compute two rows at a
     * time.  If the number of rows is odd, compute the bottom row twice to
     * avoid needing a duplicate codepath. */
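    /* (With an odd number of rows, the final iteration points both row
     * pointers at the same row, so the duplicated pass simply rewrites the
     * same output bytes instead of needing a separate code path.) */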
    for (unsigned int row=0; row<height; row+=2) {
        /* Prefer to do 4 vectors (16 values) at once as this collapses
         * neatly to a single vector of output, failing that a vector at a
         * time and then the odd ones out at the end. */
        unsigned int blocks = (width / 16);
        unsigned int regs   = (width % 16) / 4;
        unsigned int odds   = (width % 4);
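        /* Worked example: width==23 gives blocks=1 (16 values), regs=1
         * (4 more values) and odds=3 for the last three columns. */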

        const int32_t *colptr = col_bias;

        const int32_t *in_ptr = input + (row * in_stride);
        int8_t *out_ptr = output + (row * out_stride);
        int32_t row_sum = row_bias[row];

        const int32_t *in_ptr1;
        int8_t *out_ptr1;
        int32_t row_sum1;

        if (row == height-1) {
            in_ptr1  = in_ptr;
            out_ptr1 = out_ptr;
            row_sum1 = row_sum;
        } else {
            in_ptr1  = in_ptr + in_stride;
            out_ptr1 = out_ptr + out_stride;
            row_sum1 = row_bias[row+1];
        }

        const int32x4_t v_row_sum  = vdupq_n_s32(row_sum);
        const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);

        while (blocks--) {
            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            int32x4_t v_col3 = vld1q_s32(colptr + 12);
            colptr += 16;

            // Load input data (row 0)
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            int32x4_t v_in03 = vld1q_s32(in_ptr + 12);
            in_ptr += 16;

            // Load input data (row 1)
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12);
            in_ptr1 += 16;

            // Add on row bias and column bias
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);
            v_in03 = vaddq_s32(v_in03, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);
            v_in13 = vaddq_s32(v_in13, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);
            v_in03 = vaddq_s32(v_in03, v_col3);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            v_in13 = vaddq_s32(v_in13, v_col3);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul);
            v_in03 = vqrdmulhq_s32(v_in03, v_mul);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul);
            v_in13 = vqrdmulhq_s32(v_in13, v_mul);

            // Compute and add on corrective offset
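            // Note on the correction trick: qp.requant_shift is expected to be
            // non-positive (vrshlq with a negative shift performs a rounding
            // right shift), so v_shift has its sign bit set whenever the shift
            // is non-zero.  (v_in & v_shift) therefore has its sign bit set
            // exactly when v_in is negative, and arithmetic-shifting that by 31
            // gives -1 for negative inputs and 0 otherwise; the saturating add
            // nudges negative values down by one before the rounding shift so
            // that they round symmetrically with positive values.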
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shift);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shift);
                int32x4_t v_temp03 = vandq_s32(v_in03, v_shift);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shift);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shift);
                int32x4_t v_temp13 = vandq_s32(v_in13, v_shift);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);
                v_temp03 = vshrq_n_s32(v_temp03, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);
                v_temp13 = vshrq_n_s32(v_temp13, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);
                v_in03 = vqaddq_s32(v_in03, v_temp03);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
                v_in13 = vqaddq_s32(v_in13, v_temp13);
            }

            v_in00 = vrshlq_s32(v_in00, v_shift);
            v_in01 = vrshlq_s32(v_in01, v_shift);
            v_in02 = vrshlq_s32(v_in02, v_shift);
            v_in03 = vrshlq_s32(v_in03, v_shift);

            v_in10 = vrshlq_s32(v_in10, v_shift);
            v_in11 = vrshlq_s32(v_in11, v_shift);
            v_in12 = vrshlq_s32(v_in12, v_shift);
            v_in13 = vrshlq_s32(v_in13, v_shift);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);
            v_in03 = vaddq_s32(v_in03, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);
            v_in13 = vaddq_s32(v_in13, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);
            v_in03 = vmaxq_s32(v_in03, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);
            v_in13 = vmaxq_s32(v_in13, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);
            v_in03 = vminq_s32(v_in03, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            v_in13 = vminq_s32(v_in13, v_maxval);

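            // Narrow 32-bit results to bytes by repeatedly taking the even
            // (low) elements with UZP1: first to 16 bits, then to 8 bits.
            // This keeps only the low byte of each clamped value, as noted in
            // the comment at the top of this function.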
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_s8(out_ptr, v_uz0);
            out_ptr += 16;
            vst1q_s8(out_ptr1, v_uz1);
            out_ptr1 += 16;
        }

        while (regs--) {
            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            colptr += 4;

            // Load input data (row 0)
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            in_ptr += 4;

            // Load input data (row 1)
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            in_ptr1 += 4;

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Add on the precomputed column bias (effectively subtracting col_sum * a_offset)
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shift);

            v_in10 = vrshlq_s32(v_in10, v_shift);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

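            // As in the main loop, UZP1 keeps the low parts; after the second
            // UZP1 bytes 0-3 hold row 0's results and bytes 4-7 hold row 1's,
            // so each row is written with a single 32-bit lane store below.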
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00));

            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), vreinterpretq_s32_s8(v_uz0), 0);
            out_ptr += 4;
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1);
            out_ptr1 += 4;
        }

        if (odds) {
            int32x4_t v_col0 = vdupq_n_s32(0);
            int32x4_t v_in00 = vdupq_n_s32(0);
            int32x4_t v_in10 = vdupq_n_s32(0);

            do {
                v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
                if (odds == 1) { break; }

                v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
                if (odds == 2) { break; }

                v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
            } while (0);

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Add on the precomputed column bias (effectively subtracting col_sum * a_offset)
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shift);

            v_in10 = vrshlq_s32(v_in10, v_shift);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            do {
                vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0);
                vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0);

                if (odds==1) { break; }

                vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4);
                vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4);

                if (odds==2) { break; }

                vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8);
                vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8);
            } while(0);
        }
    }
}

} // anonymous namespace

template<typename Tin, typename Tout>
void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                         const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias) {
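    /* Note: the shift correction only affects values which are still negative
     * after the multiply and shift, i.e. before c_offset is added; when
     * qp.minval >= qp.c_offset any such value gets clamped up to minval
     * regardless, so the cheaper variant without the correction suffices. */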
    if (qp.minval >= qp.c_offset) {
        requantize_block_32_int<false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                       reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
    } else {
        requantize_block_32_int<true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                      reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
    }
}

template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                                  const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias);

template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                                  const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias);

/*
 * Routine (and helpers) to compute row sums needed for offset correction.
 *
 * This is often needed for a lot of short rows (e.g. Syrax 5 - 6400 rows
 * of length 27), therefore it's important not to sacrifice performance on
 * odd length rows.
 *
 * To minimize performance loss in these cases, this routine will overread
 * by up to 7 bytes.
 *
 * This is handled via "mask" and "mask mode" parameters to the inner
 * routines; mask mode == 1 indicates that there are between 1 and 8 bytes
 * (inclusive) needed at the end; in these cases we always read 8 bytes.
 * mask mode == 2 indicates that there are between 9 and 15 bytes needed at
 * the end, and in this case we always read 16 bytes.  In both cases the
 * 'mask' vector is set up so that the read value can be masked off to clear
 * the overread lanes.  This is handled by 'accumulate_masked_8' and
 * 'accumulate_masked_16' below.
 *
 * This routine is templated on the type to be accumulated, because the
 * innermost instruction used needs to be of the correct signedness.
 * However, beyond this point we always use signed values in both cases.
 * The instructions that need to be different are therefore wrapped in
 * helper functions below.
 *
 * The general strategy used is to load vectors of 16 bytes and accumulate
 * (using uadalp/sadalp or AArch32 equivalents) into 8x16-bit accumulators.
 * These are then reduced (using uadalp/sadalp again) into 4x32-bit
 * accumulators.  The 4 accumulators for up to 4 rows being processed are
 * then added together into a single output vector using pairwise adds.
 *
 * This reduction from the 8x16-bit into the 4x32-bit accumulators needs to
 * occur before the 16-bit accumulators can overflow - which is every 32
 * iterations (512 total bytes processed).  This is explained more below.
 */
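/* In scalar terms, the value computed for each row below is simply:
 *
 *   row_bias[r] = -qp.b_offset * (sum of the 'width' input values in row r)
 */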
namespace {
    struct row_sum_helpers {
        const ARequantizeLayer32 &qp;

        /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
        template<typename T>
        inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

        /* Load a full 16 byte vector, but mask before accumulation (see above). */
        template<typename T>
        inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* Load 8 bytes and mask before accumulation. */
        template<typename T>
        inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* This function does the actual work for up to 4 rows at a time.
         * It's pulled out so we can template on the row count to generate
         * the 4 different cases.  4 rows are computed at a time as this
         * reduces to a single vector write. */
        template<unsigned int rows, typename T>
        void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
            int16x8_t sums[rows];
            int32x4_t finalsums[rows];

            for (unsigned int i=0; i<rows; i++) {
                sums[i]      = vdupq_n_s16(0);
                finalsums[i] = vdupq_n_s32(0);
            }

            for (unsigned int i=0; i<blocks; i++) {
                for (unsigned int r=0; r<rows; r++) {
                    /* If we add too many blocks together, we run the risk
                     * of overflowing the intermediate 16-bit accumulators,
                     * especially in the unsigned case where we later treat
                     * the accumulator as signed.
                     *
                     * In that case, the maximum (signed) value is 16383,
                     * which is safe for 64 (unsigned) accumulations (255*64
                     * = 16,320).
                     *
                     * Each invocation of pairwise add adds 2 values to the
                     * accumulator - so in the unsigned case we can do 32
                     * adds before we need to reset the 16-bit accumulator
                     * by adding into the 32-bit 'finalsums'.
                     *
                     * We could do 64 adds in the signed case, but that
                     * optimization is not worth the complexity.
                     */
                    if (i > 0 && ((i & 31) == 0)) {
                        finalsums[r] = vpadalq_s16(finalsums[r], sums[r]);
                        sums[r] = vdupq_n_s16(0);
                    }
                    sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]);
                }
            }

            /* Handle the final masked read if needed. */
            if (mask_mode > 0) {
                for (unsigned int r=0; r<rows; r++) {
                    if (mask_mode == 1) {
                        sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    } else {
                        sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    }
                }
            }

            for (unsigned int i=0; i<rows; i++) {
                finalsums[i] = vpadalq_s16(finalsums[i], sums[i]);
            }

            int32x4_t t0, t1;
            int32x2_t t2;

            /* Result writeback - need to write back one value per row
             * processed.  Multiply all the final totals by -b_offset so
             * that the terms can simply be added in the requantize code. */
            switch (rows) {
                case 1:
                    /* If we only have one output, just use ADDV.  Multiply
                     * the offset into all four components separately so it
                     * can stay in the SIMD register file. */
                    t0 = vmulq_s32(finalsums[0], offset_mul);
                    *row_bias = vaddvq_s32(t0);
                    break;

                case 2:
                    /* For two outputs, two rounds of pairwise adds will
                     * generate the result in a 2-vector we can store in one
                     * go. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t0 = vpaddq_s32(t0, t0);
                    t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul));
                    vst1_s32(row_bias, t2);
                    break;

                case 3:
                    /* Three rows - need to store the low two words plus the odd value from lane 2 */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[2]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1_s32(row_bias, vget_low_s32(t0));
                    row_bias[2] = vgetq_lane_s32(t0, 2);
                    break;

                case 4:
                    /* Four rows (most common case) - reduce to a single
                     * vector with pairwise adds. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[3]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1q_s32(row_bias, t0);
                    break;

                default:
                    break;
            }
        }

        row_sum_helpers(const ARequantizeLayer32 &qp) : qp(qp) { }
    };

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) {
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr)));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) {
        return vpadalq_s8(sum, vld1q_s8(ptr));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
        v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
        v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }
} // anonymous namespace

template<typename T>
void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    row_sum_helpers thehelpers(qp);

    const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);

    /* Work out how many full vectors of 16 bytes we will read, and how many
     * odd bytes at the end */
    unsigned int blocks = (width / 16);
    const unsigned int odds = width % 16;

    /* Generate a mask to use on the last iteration, if necessary. */
    uint64x2_t mask;
    unsigned int mask_mode = 0;

    if (odds > 0 && odds <= 8) {
        /* 1-8 odds: mask in the low lane, 0 in the top */
        uint64_t maskval = (~0ULL) >> (8 * (8-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);

        mask_mode = 1;
    } else if (odds > 8) {
        /* 9-15 odds: mask in the top lane, all 1s in the bottom. */
        uint64_t maskval = (~0ULL) >> (8 * (16-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);

        mask_mode = 2;
    }
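    /* Worked example: for width==27 (the case mentioned above), blocks==1 and
     * odds==11, so mask_mode==2 is selected; maskval covers 11-8 = 3 bytes,
     * giving all-ones in the low 64-bit lane and 3 bytes of ones in the high
     * lane - the final 16-byte read keeps 11 bytes and clears the other 5. */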

    for (unsigned int row=0; row<height; row+=4) {
        switch(height-row) {
            default:
            case 4:
                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 3:
                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 2:
                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 1:
                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
        }
    }
}

/* Instantiate the two versions for uint8_t and int8_t. */
template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);

template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);

template<unsigned int active_rows>
inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) {
    uint8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_u8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_u8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0] = vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1])));
    sums_16b[1] = vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3])));
    // Two adds for the high pairs
    sums_16b[2] = vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1]));
    sums_16b[3] = vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3]));

    int32x4_t sums_32b[4];

    sums_32b[0] = vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1] = vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2] = vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3] = vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

template<unsigned int active_rows>
inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) {
    int8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_s8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_s8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0] = vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1]));
    sums_16b[1] = vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3]));
    // Two adds for the high pairs
    sums_16b[2] = vaddl_high_s8(inputs[0], inputs[1]);
    sums_16b[3] = vaddl_high_s8(inputs[2], inputs[3]);

    int32x4_t sums_32b[4];

    sums_32b[0] = vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1] = vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2] = vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3] = vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

/* The "first_col" parameter is used to offset the read into the qp.bias array,
 * in cases where we are not computing the first columns of the output (i.e.
 * in multithreaded cases where we divide columns across threads). */
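/* In scalar terms, after this routine each entry holds:
 *
 *   col_bias[c] = (qp.a_offset * qp.b_offset * depth)
 *               - qp.a_offset * (sum of the 'height' input values in column c)
 *               + (qp.bias ? qp.bias[c + first_col] : 0)
 *
 * i.e. the precomputed per-column term consumed by requantize_block_32 above. */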
template<typename T>
void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col) {
    memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));

    for (unsigned int row=0; row<height; row+=4) {
        unsigned int numrows = std::min(height-row, 4u);

        for (unsigned int col=0; col<width; col+=16) {
            unsigned int numcols = std::min(width-col, 16u);

            if (numcols==16) {
                switch(numrows) {
                    case 1:
                        add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                        break;

                    case 2:
                        add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
                        break;

                    case 3:
                        add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
                        break;

                    case 4:
                        add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                        break;

                    default:
                        break;
                }
            } else {
                for (; col<width; col++) {
                    int32_t sum=0;
                    for (unsigned int r=0; r<numrows; r++) {
                        sum += input[(row + r)*in_stride + col];
                    }
                    col_bias[col] += sum;
                }
            }
        }
    }

    for (unsigned int col=0; col<width; col++) {
        int32_t result = col_bias[col];

        result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset);

        if (qp.bias != nullptr) {
            result += qp.bias[col + first_col];
        }

        col_bias[col] = result;
    }
}

template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col);
template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col);

} // namespace arm_gemm

#endif // __aarch64__