/*
 * Copyright (c) 2019 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "arm_gemm.hpp"

#include <arm_neon.h>

#include <algorithm>
#include <cstring>

namespace arm_gemm {

namespace {

/* Requantize a block of data, using the requantize parameters in 'qp'.
 *
 * row_bias and col_bias are assumed to be precomputed values which include
 * any externally supplied bias, plus the row/column contribution sums, plus
 * the overall constant offset (A_offset * B_offset * depth).
 *
 * Note that this function works equally well for uint8_t output: just set
 * minval/maxval appropriately and cast the output pointer.  It is the
 * caller's responsibility to ensure that minval/maxval are representable in
 * the target type - the downcast to (u)int8_t is done by simply extracting
 * the LSB.
 *
 * The 'do_shift_correction' template parameter turns on the correction
 * applied to negative values being shifted right to make sure they round
 * properly - if negative values are never output (e.g. fused ReLU) this is
 * unnecessary.
 */
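/* Illustrative only: the per-element operation implemented with the NEON
 * intrinsics below is roughly the following scalar computation (the helper
 * names here are for exposition, not real functions):
 *
 *   int32_t v = *in + row_bias[row] + col_bias[col];
 *   v = sqrdmulh(v, qp.requant_mul);             // rounding doubling multiply-high
 *   if (do_shift_correction && v < 0) { v--; }   // fix rounding of negative values
 *   v = rounding_shift(v, qp.requant_shift);     // right shift, as the shift is negative
 *   v += qp.c_offset;
 *   v = std::max(qp.minval, std::min(qp.maxval, v));
 *   *out = static_cast<int8_t>(v);               // keep the LSB only
 */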
template<bool do_shift_correction>
void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                             const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                             const int32_t *row_bias, const int32_t *col_bias) {
    const int32x4_t v_mul = vdupq_n_s32(qp.requant_mul);
    const int32x4_t v_shift = vdupq_n_s32(qp.requant_shift);
    const int32x4_t v_minval = vdupq_n_s32(qp.minval);
    const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
    const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);

    /* To make sure we have plenty of accumulators, compute two rows at a
     * time.  If the number of rows is odd, compute the bottom row twice to
     * avoid needing a duplicate codepath. */
    for (unsigned int row=0; row<height; row+=2) {
        /* Prefer to do 4 vectors (16 values) at once as this collapses
         * neatly to a single vector of output, failing that a vector at a
         * time and then the odd ones out at the end. */
        unsigned int blocks=(width / 16);
        unsigned int regs=(width % 16) / 4;
        unsigned int odds=(width % 4);

        const int32_t *colptr = col_bias;

        const int32_t *in_ptr = input + (row * in_stride);
        int8_t *out_ptr = output + (row * out_stride);
        int32_t row_sum = row_bias[row];

        const int32_t *in_ptr1;
        int8_t *out_ptr1;
        int32_t row_sum1;

        if (row == height-1) {
            in_ptr1 = in_ptr;
            out_ptr1 = out_ptr;
            row_sum1 = row_sum;
        } else {
            in_ptr1 = in_ptr + in_stride;
            out_ptr1 = out_ptr + out_stride;
            row_sum1 = row_bias[row+1];
        }

        const int32x4_t v_row_sum = vdupq_n_s32(row_sum);
        const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);

        while (blocks--) {
            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            int32x4_t v_col3 = vld1q_s32(colptr + 12);
            colptr += 16;

            // Load input data (row 0);
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            int32x4_t v_in03 = vld1q_s32(in_ptr + 12);
            in_ptr += 16;

            // Load input data (row 1);
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12);
            in_ptr1 += 16;

            // Add on row bias and column bias
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);
            v_in03 = vaddq_s32(v_in03, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);
            v_in13 = vaddq_s32(v_in13, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);
            v_in03 = vaddq_s32(v_in03, v_col3);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            v_in13 = vaddq_s32(v_in13, v_col3);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul);
            v_in03 = vqrdmulhq_s32(v_in03, v_mul);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul);
            v_in13 = vqrdmulhq_s32(v_in13, v_mul);

            // Compute and add on corrective offset
            if (do_shift_correction) {
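                /* requant_shift is negative for a right shift, so its sign bit
                 * is set and the AND below preserves each value's sign bit; the
                 * arithmetic shift right by 31 then yields -1 for negative
                 * values and 0 otherwise, which is saturating-added back in so
                 * that negative values round correctly in the rounding shift
                 * that follows. */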
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shift);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shift);
                int32x4_t v_temp03 = vandq_s32(v_in03, v_shift);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shift);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shift);
                int32x4_t v_temp13 = vandq_s32(v_in13, v_shift);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);
                v_temp03 = vshrq_n_s32(v_temp03, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);
                v_temp13 = vshrq_n_s32(v_temp13, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);
                v_in03 = vqaddq_s32(v_in03, v_temp03);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
                v_in13 = vqaddq_s32(v_in13, v_temp13);
            }

            v_in00 = vrshlq_s32(v_in00, v_shift);
            v_in01 = vrshlq_s32(v_in01, v_shift);
            v_in02 = vrshlq_s32(v_in02, v_shift);
            v_in03 = vrshlq_s32(v_in03, v_shift);

            v_in10 = vrshlq_s32(v_in10, v_shift);
            v_in11 = vrshlq_s32(v_in11, v_shift);
            v_in12 = vrshlq_s32(v_in12, v_shift);
            v_in13 = vrshlq_s32(v_in13, v_shift);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);
            v_in03 = vaddq_s32(v_in03, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);
            v_in13 = vaddq_s32(v_in13, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);
            v_in03 = vmaxq_s32(v_in03, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);
            v_in13 = vmaxq_s32(v_in13, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);
            v_in03 = vminq_s32(v_in03, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            v_in13 = vminq_s32(v_in13, v_maxval);

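            // Narrow 32->8 bit by keeping the even-indexed 16-bit and then
            // 8-bit elements, i.e. the low byte of each (already clamped)
            // 32-bit value.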
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_s8(out_ptr, v_uz0);
            out_ptr += 16;
            vst1q_s8(out_ptr1, v_uz1);
            out_ptr1 += 16;
        }

        while (regs--) {
            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            colptr += 4;

            // Load input data (row 0);
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            in_ptr += 4;

            // Load input data (row 1);
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            in_ptr1 += 4;

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shift);

            v_in10 = vrshlq_s32(v_in10, v_shift);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

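            // After the two unzips, 32-bit lane 0 of v_uz0 holds the four
            // output bytes for row 0 and lane 1 holds the four bytes for
            // row 1, so a single lane store per row writes this group of
            // four columns.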
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00));

            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), vreinterpretq_s32_s8(v_uz0), 0);
            out_ptr += 4;
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1);
            out_ptr1 += 4;
        }

        if (odds) {
            int32x4_t v_col0 = vdupq_n_s32(0);
            int32x4_t v_in00 = vdupq_n_s32(0);
            int32x4_t v_in10 = vdupq_n_s32(0);

            do {
                v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
                if (odds == 1) { break; }

                v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
                if (odds == 2) { break; }

                v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
            } while (0);

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shift);

            v_in10 = vrshlq_s32(v_in10, v_shift);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

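            // Store the low byte of each 32-bit lane - byte lanes 0, 4 and 8
            // correspond to 32-bit lanes 0, 1 and 2 respectively.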
            do {
                vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0);
                vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0);

                if (odds==1) { break; }

                vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4);
                vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4);

                if (odds==2) { break; }

                vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8);
                vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8);
            } while(0);
        }
    }
}

} // anonymous namespace

template<typename Tin, typename Tout>
void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                         const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias) {
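    /* If the clamping range cannot produce anything below c_offset, any value
     * that was negative before c_offset was added is clamped to minval
     * regardless of how it was rounded, so the cheaper variant without the
     * shift correction gives identical results. */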
    if (qp.minval >= qp.c_offset) {
        requantize_block_32_int<false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                       reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
    } else {
        requantize_block_32_int<true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                      reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
    }
}

template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                                  const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias);

template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                                  const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias);

/*
 * Routine (and helpers) to compute row sums needed for offset correction.
 *
 * This is often needed for a lot of short rows (e.g. Syrax 5 - 6400 rows
 * of length 27), therefore it's important not to sacrifice performance on
 * odd length rows.
 *
 * To minimize performance loss in these cases, this routine will overread
 * by up to 7 bytes.
 *
 * This is handled via "mask" and "mask mode" parameters to the inner
 * routines; mask mode == 1 indicates that there are between 1 and 8 bytes
 * (inclusive) needed at the end; in these cases we always read 8 bytes.
 * mask mode == 2 indicates that there are between 9 and 15 bytes needed at
 * the end, and in this case we always read 16 bytes.  In both cases the
 * 'mask' vector is set up so that the read value can be masked off to clear
 * the overread lanes.  This is handled by 'accumulate_masked_8' and
 * 'accumulate_masked_16' below.
 *
 * This routine is templated on the type to be accumulated, because the
 * innermost instruction used needs to be of the correct signedness.
 * However, beyond this point we always use signed values in both cases.
 * The instructions that need to be different are therefore wrapped in
 * helper functions below.
 */
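/* These helpers are used together: compute_row_sums() fills row_bias with the
 * per-row sums scaled by -b_offset, compute_col_sums() fills col_bias with the
 * a_offset/b_offset constant term, the scaled column sums and any supplied
 * bias, and requantize_block_32() above consumes both arrays. */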

namespace {
    struct row_sum_helpers {
        const ARequantizeLayer32 &qp;

        /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
        template<typename T>
        inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

        /* Load a full 16 byte vector, but mask before accumulation (see above). */
        template<typename T>
        inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* Load 8 bytes and mask before accumulation. */
        template<typename T>
        inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* This function does the actual work for up to 4 rows at a time.
         * It's pulled out so we can template on the row count to generate
         * the 4 different cases.  4 rows are computed at a time as this
         * reduces to a single vector write. */
        template<unsigned int rows, typename T>
        void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
            int16x8_t sums[rows];
            int32x4_t finalsums[rows];

            for (unsigned int i=0; i<rows; i++) {
                sums[i] = vdupq_n_s16(0);
            }

            for (unsigned int i=0; i<blocks; i++) {
                for (unsigned int r=0; r<rows; r++) {
                    sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]);
                }
            }

            /* Handle the final masked read if needed. */
            if (mask_mode > 0) {
                for (unsigned int r=0; r<rows; r++) {
                    if (mask_mode == 1) {
                        sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    } else {
                        sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    }
                }
            }

            for (unsigned int i=0; i<rows; i++) {
                finalsums[i] = vpaddlq_s16(sums[i]);
            }

            int32x4_t t0, t1;
            int32x2_t t2;

            /* Result writeback - need to write back one value per row
             * processed.  Multiply all the final totals by -b_offset so
             * that the terms can simply be added in the requantize code.
             */
            switch (rows) {
                case 1:
                    /* If we only have one output, just use ADDV.  Multiply
                     * the offset into all four components separately so it
                     * can stay in the SIMD register file. */
                    t0 = vmulq_s32(finalsums[0], offset_mul);
                    *row_bias = vaddvq_s32(t0);
                    break;

                case 2:
                    /* For two outputs, two rounds of pairwise adds will
                     * generate the result in a 2-vector we can store in one
                     * go. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t0 = vpaddq_s32(t0, t0);
                    t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul));
                    vst1_s32(row_bias, t2);
                    break;

                case 3:
                    /* Three rows - need to store the low two words plus the odd value from lane 2 */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[2]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1_s32(row_bias, vget_low_s32(t0));
                    row_bias[2] = vgetq_lane_s32(t0, 2);
                    break;

                case 4:
                    /* Four rows (most common case) - reduce to a single
                     * vector with pairwise adds. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[3]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1q_s32(row_bias, t0);
                    break;
                default:
                    break;
            }
        }

        row_sum_helpers(const ARequantizeLayer32 &qp) : qp(qp) { }
    };

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) {
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr)));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) {
        return vpadalq_s8(sum, vld1q_s8(ptr));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
        v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
        v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }
}

template<typename T>
void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    row_sum_helpers thehelpers(qp);

    const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);

    /* Work out how many full vectors of 16 bytes we will read, and how many
     * odd bytes at the end */
    unsigned int blocks = (width / 16);
    const unsigned int odds = width % 16;

    /* Generate a mask to use on the last iteration, if necessary. */
    uint64x2_t mask;
    unsigned int mask_mode = 0;

    if (odds > 0 && odds <= 8) {
        /* 1-8 odds: mask in the low lane, 0 in the top */
        uint64_t maskval = (~0ULL) >> (8 * (8-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);

        mask_mode = 1;
    } else if (odds > 8) {
        /* 9-15 odds: mask in the top lane, all 1s in the bottom. */
        uint64_t maskval = (~0ULL) >> (8 * (16-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);

        mask_mode = 2;
    }
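
    /* For example, with width=27 (rows of length 27 as mentioned above):
     * blocks=1 and odds=11, so mask_mode=2 and maskval = ~0ULL >> 40, i.e.
     * only the low 3 bytes of the top 64-bit lane survive the mask
     * (8 + 3 = 11 bytes kept from the final 16-byte read). */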

    for (unsigned int row=0; row<height; row+=4) {
        switch(height-row) {
            default:
            case 4:
                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 3:
                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 2:
                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 1:
                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
        }
    }
}

/* Instantiate the two versions for uint8_t and int8_t. */
template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);

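/* add_block: accumulate a tile of up to 4 rows x 16 columns of 8-bit data into
 * the 16 32-bit per-column accumulators at 'output', treating rows beyond
 * 'active_rows' as zero.  The specializations below handle the unsigned and
 * signed input types with the appropriate widening adds. */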
template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);

template<unsigned int active_rows>
inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) {
    uint8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_u8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_u8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1])));
    sums_16b[1]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3])));
    // Two adds for the high pairs
    sums_16b[2]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1]));
    sums_16b[3]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3]));

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

template<unsigned int active_rows>
inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) {
    int8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_s8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_s8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1]));
    sums_16b[1]=vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3]));
    // Two adds for the high pairs
    sums_16b[2]=vaddl_high_s8(inputs[0], inputs[1]);
    sums_16b[3]=vaddl_high_s8(inputs[2], inputs[3]);

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}


/* "first_col" parameter is used to offset the read into the qp.bias array,
 * in cases where we are not computing the first columns of the output (i.e.
 * in multithreaded cases where we divide columns across threads) */
template<typename T>
void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col) {
    memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));

    for (unsigned int row=0; row<height; row+=4) {
        unsigned int numrows=std::min(height-row, 4u);

        for (unsigned int col=0; col<width; col+=16) {
            unsigned int numcols=std::min(width-col, 16u);

            if (numcols==16) {
                switch(numrows) {
                    case 1:
                        add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                        break;

                    case 2:
                        add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
                        break;

                    case 3:
                        add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
                        break;

                    case 4:
                        add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                        break;
                    default:
                        break;
                }
            } else {
                for (; col<width; col++) {
                    int32_t sum=0;
                    for (unsigned int r=0; r<numrows; r++) {
                        sum += input[(row + r)*in_stride + col];
                    }
                    col_bias[col] += sum;
                }
            }
        }
    }

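    /* Turn the raw column sums into the final column bias: the constant term
     * a_offset * b_offset * depth, minus the column sum scaled by a_offset,
     * plus any externally supplied bias.  Combined with the row bias (which
     * holds -b_offset * row_sum), this gives the full offset correction that
     * requantize_block_32 adds back on. */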
    for (unsigned int col=0; col<width; col++) {
        int32_t result = col_bias[col];

        result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset);

        if (qp.bias != nullptr) {
            result += qp.bias[col + first_col];
        }

        col_bias[col] = result;
    }
}

template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col);
template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col);

} // namespace arm_gemm