/*
 * Copyright (c) 2019 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"
#include "utils.hpp"

#include <arm_neon.h>

#include <algorithm>
#include <cstring>

namespace arm_gemm {

namespace {

/* Requantize a block of data, using the requantize parameters in 'qp'.
 *
 * row_bias and col_bias are assumed to be precomputed values which include
 * any externally supplied bias, plus the row/column contribution sums, plus
 * the overall constant offset (A_offset * B_offset * depth).
 *
 * Note that this function works equally well for uint8_t output: just set
 * minval/maxval appropriately and cast the output pointer.  It is the
 * caller's responsibility to ensure that minval/maxval are representable in
 * the target type - the downcast to (u)int8_t is done by simply extracting
 * the LSB.
 *
 * The 'do_shift_correction' template parameter turns on the correction
 * applied to negative values being shifted right to make sure they round
 * properly - if negative values are never output (e.g. fused ReLU) this is
 * unnecessary.
 *
 * The 'per_channel' template parameter selects between per channel and per
 * layer requantization - in the former case we need to load vectors of
 * shifts and multipliers for each column.  A separate vector for each
 * column is set up in any case (and it is hoped that the compiler can
 * elide the needless movs in the per-layer case).
 */
template<bool do_shift_correction, bool per_channel, bool do_left_shift>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                             const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                             const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
    const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
    const int32x4_t v_right_shift = vdupq_n_s32(qp.per_layer_right_shift);
    const int32x4_t v_left_shift = vdupq_n_s32(qp.per_layer_left_shift);
    const int32x4_t v_minval = vdupq_n_s32(qp.minval);
    const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
    const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);

    /* To make sure we have plenty of accumulators, compute two rows at a
     * time.  If the number of rows is odd, compute the bottom row twice to
     * avoid needing a duplicate codepath. */
    for (unsigned int row=0; row<height; row+=2) {
        /* Prefer to do 4 vectors (16 values) at once as this collapses
         * neatly to a single vector of output, failing that a vector at a
         * time and then the odd ones out at the end. */
        unsigned int blocks=(width / 16);
        unsigned int regs=(width % 16) / 4;
        unsigned int odds=(width % 4);

        const int32_t *colptr = col_bias;
        const int32_t *perch_mul_ptr = qp.per_channel_muls + start_col;
        const int32_t *perch_shift_ptr = qp.per_channel_right_shifts + start_col;
        const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;

        const int32_t *in_ptr = input + (row * in_stride);
        int8_t *out_ptr = output + (row * out_stride);
        int32_t row_sum = row_bias[row];

        const int32_t *in_ptr1;
        int8_t *out_ptr1;
        int32_t row_sum1;

        if (row == height-1) {
            in_ptr1 = in_ptr;
            out_ptr1 = out_ptr;
            row_sum1 = row_sum;
        } else {
            in_ptr1 = in_ptr + in_stride;
            out_ptr1 = out_ptr + out_stride;
            row_sum1 = row_bias[row+1];
        }

        const int32x4_t v_row_sum = vdupq_n_s32(row_sum);
        const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);

        while (blocks--) {
            int32x4_t v_mul0;
            int32x4_t v_mul1;
            int32x4_t v_mul2;
            int32x4_t v_mul3;

            int32x4_t v_shf0;
            int32x4_t v_shf1;
            int32x4_t v_shf2;
            int32x4_t v_shf3;

            int32x4_t v_shf0l;
            int32x4_t v_shf1l;
            int32x4_t v_shf2l;
            int32x4_t v_shf3l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
                v_mul3 = vld1q_s32(perch_mul_ptr + 12);
                perch_mul_ptr += 16;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                perch_shift_ptr += 16;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
                    v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
                    perch_shiftl_ptr += 16;
                }
            } else {
                v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
                v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
                v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
            }

            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            int32x4_t v_col3 = vld1q_s32(colptr + 12);
            colptr += 16;

            // Load input data (row 0);
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            int32x4_t v_in03 = vld1q_s32(in_ptr + 12);
            in_ptr += 16;

            // Load input data (row 1);
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12);
            in_ptr1 += 16;

            // Add on row bias and column bias
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);
            v_in03 = vaddq_s32(v_in03, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);
            v_in13 = vaddq_s32(v_in13, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);
            v_in03 = vaddq_s32(v_in03, v_col3);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            v_in13 = vaddq_s32(v_in13, v_col3);

            // Quantize

            // If a left shift is needed it needs to happen first.
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in01 = vrshlq_s32(v_in01, v_shf1l);
                v_in02 = vrshlq_s32(v_in02, v_shf2l);
                v_in03 = vrshlq_s32(v_in03, v_shf3l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
                v_in11 = vrshlq_s32(v_in11, v_shf1l);
                v_in12 = vrshlq_s32(v_in12, v_shf2l);
                v_in13 = vrshlq_s32(v_in13, v_shf3l);
            }

            // Multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
            v_in03 = vqrdmulhq_s32(v_in03, v_mul3);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
            v_in13 = vqrdmulhq_s32(v_in13, v_mul3);

            // Compute and add on corrective offset
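            // The shift amounts in v_shf0-3 are negative (right shifts), so
            // their sign bit is set whenever a shift is actually applied.
            // AND-ing each value with its shift and arithmetic-shifting the
            // result down by 31 therefore yields -1 exactly for negative
            // values that are about to be shifted (and 0 otherwise); the
            // saturating add then nudges those values down by one, so the
            // rounding right shift below rounds exact halves of negative
            // values away from zero instead of towards positive infinity.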
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
                int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
                int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);
                v_temp03 = vshrq_n_s32(v_temp03, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);
                v_temp13 = vshrq_n_s32(v_temp13, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);
                v_in03 = vqaddq_s32(v_in03, v_temp03);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
                v_in13 = vqaddq_s32(v_in13, v_temp13);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in01 = vrshlq_s32(v_in01, v_shf1);
            v_in02 = vrshlq_s32(v_in02, v_shf2);
            v_in03 = vrshlq_s32(v_in03, v_shf3);

            v_in10 = vrshlq_s32(v_in10, v_shf0);
            v_in11 = vrshlq_s32(v_in11, v_shf1);
            v_in12 = vrshlq_s32(v_in12, v_shf2);
            v_in13 = vrshlq_s32(v_in13, v_shf3);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);
            v_in03 = vaddq_s32(v_in03, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);
            v_in13 = vaddq_s32(v_in13, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);
            v_in03 = vmaxq_s32(v_in03, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);
            v_in13 = vmaxq_s32(v_in13, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);
            v_in03 = vminq_s32(v_in03, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            v_in13 = vminq_s32(v_in13, v_maxval);

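            // Narrow the clamped 32-bit results to 8 bits: each vuzp1 keeps
            // the even-numbered (low) halves of its inputs, so the two unzip
            // stages extract the least significant byte of every lane,
            // collapsing the four vectors per row into one output vector.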
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_s8(out_ptr, v_uz0);
            out_ptr += 16;
            vst1q_s8(out_ptr1, v_uz1);
            out_ptr1 += 16;
        }

        while (regs--) {
            int32x4_t v_mul0;
            int32x4_t v_shf0;
            int32x4_t v_shf0l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                perch_mul_ptr += 4;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                perch_shift_ptr += 4;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    perch_shiftl_ptr += 4;
                }
            } else {
                v_mul0=v_mul;
                v_shf0=v_right_shift;
                v_shf0l=v_left_shift;
            }
            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            colptr += 4;

            // Load input data (row 0);
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            in_ptr += 4;

            // Load input data (row 1);
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            in_ptr1 += 4;

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with (optional) left shift
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }

            // Then multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00));

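            // After the unzips, the four result bytes for row 0 occupy lane 0
            // of v_uz0 (viewed as 32-bit lanes) and the four bytes for row 1
            // occupy lane 1, so each row is written with a single lane store.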
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), vreinterpretq_s32_s8(v_uz0), 0);
            out_ptr += 4;
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1);
            out_ptr1 += 4;
        }

        if (odds) {
            int32x4_t v_col0 = vdupq_n_s32(0);
            int32x4_t v_in00 = vdupq_n_s32(0);
            int32x4_t v_in10 = vdupq_n_s32(0);
            int32x4_t v_mul0 = vdupq_n_s32(0);
            int32x4_t v_shf0 = vdupq_n_s32(0);
            int32x4_t v_shf0l = vdupq_n_s32(0);

            if (!per_channel) {
                v_mul0 = v_mul;
                v_shf0 = v_right_shift;
                v_shf0l = v_left_shift;
            }

            do {
                v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
                    }
                }
                if (odds == 1) { break; }

                v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
                    }
                }
                if (odds == 2) { break; }

                v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
                    }
                }
            } while (0);

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with (optional) left shift
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }

            // Then multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            do {
                vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0);
                vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0);

                if (odds==1) { break; }

                vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4);
                vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4);

                if (odds==2) { break; }

                vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8);
                vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8);
            } while(0);
        }
    }
}
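
/* Purely illustrative scalar sketch of the per-element arithmetic performed
 * by requantize_block_32_int above, assuming per-layer quantization with no
 * left shift.  It is not called anywhere in this file and the helper name is
 * made up for this sketch; the NEON code above is authoritative.  Note that
 * per_layer_right_shift is expected to hold a non-positive value (a negative
 * shift), which is why the vector code can feed it straight to vrshlq_s32. */
inline int8_t requantize_value_sketch(int32_t value, int32_t row_bias, int32_t col_bias,
                                      const Requantize32 &qp) {
    int32_t v = value + row_bias + col_bias;

    // vqrdmulhq_s32: high 31 bits of the product, with rounding (the
    // saturation on INT32_MIN * INT32_MIN is ignored here).
    int64_t prod = static_cast<int64_t>(v) * qp.per_layer_mul;
    v = static_cast<int32_t>((prod + (1LL << 30)) >> 31);

    // Rounding right shift (vrshlq_s32 with a negative shift), with the
    // 'shift correction' nudge so exact halves of negative values round
    // away from zero.
    int shift = -qp.per_layer_right_shift;
    if (shift > 0) {
        if (v < 0) {
            v--;
        }
        v = (v + (1 << (shift - 1))) >> shift;
    }

    // Add the output offset and clamp to the output range.
    v += qp.c_offset;
    if (v < qp.minval) { v = qp.minval; }
    if (v > qp.maxval) { v = qp.maxval; }

    return static_cast<int8_t>(v);
}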

} // anonymous namespace

template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                         const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
    if (qp.per_channel_requant) {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                          reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    } else {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                             reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    }
}

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);

/*
 * Routine (and helpers) to compute row sums needed for offset correction.
 *
 * This is often needed for a lot of short rows (e.g. Syrax 5 - 6400 rows
 * of length 27), so it's important not to sacrifice performance on
 * odd-length rows.
 *
 * To minimize performance loss in these cases, this routine will overread
 * by up to 7 bytes.
 *
 * This is handled via "mask" and "mask mode" parameters to the inner
 * routines; mask mode == 1 indicates that there are between 1 and 8 bytes
 * (inclusive) needed at the end; in these cases we always read 8 bytes.
 * mask mode == 2 indicates that there are between 9 and 15 bytes needed at
 * the end, and in this case we always read 16 bytes.  In both cases the
 * 'mask' vector is set up so that the read value can be masked off to clear
 * the overread lanes.  This is handled by 'accumulate_masked_8' and
 * 'accumulate_masked_16' below.
 *
 * This routine is templated on the type to be accumulated, because the
 * innermost instruction used needs to be of the correct signedness.
 * However, beyond this point we always use signed values in both cases.
 * The instructions that need to be different are therefore wrapped in
 * helper functions below.
 *
 * The general strategy used is to load vectors of 16 bytes and accumulate
 * (using uadalp/sadalp or AArch32 equivalents) into 8x16-bit accumulators.
 * These are then reduced (using uadalp/sadalp again) into 4x32-bit
 * accumulators.  The 4 accumulators for up to 4 rows being processed are
 * then added together into a single output vector using pairwise adds.
 *
 * This reduction from the 8x16-bit into the 4x32-bit accumulators needs to
 * occur before the 16-bit accumulators can overflow - which is every 32
 * iterations (512 total bytes processed).  This is explained more below.
 */
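
/* Illustrative scalar equivalent of what compute_row_sums below produces:
 * each row_bias entry is the plain sum of one row of the input, multiplied
 * by -b_offset so that it can simply be added on during requantization.
 * This sketch (and its name) exists only for explanation and is not used
 * by the optimized code that follows. */
template<typename T>
inline void row_sums_reference_sketch(const Requantize32 &qp, unsigned int width, unsigned int height,
                                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    for (unsigned int r = 0; r < height; r++) {
        int32_t sum = 0;
        for (unsigned int c = 0; c < width; c++) {
            sum += input[r * in_stride + c];
        }
        row_bias[r] = -qp.b_offset * sum;
    }
}
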
namespace {
    struct row_sum_helpers {
        const Requantize32 &qp;

        /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
        template<typename T>
        inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

        /* Load a full 16 byte vector, but mask before accumulation (see above). */
        template<typename T>
        inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* Load 8 bytes and mask before accumulation. */
        template<typename T>
        inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* This function does the actual work for up to 4 rows at a time.
         * It's pulled out so we can template on the row count to generate
         * the 4 different cases.  4 rows are computed at a time as this
         * reduces to a single vector write. */
        template<unsigned int rows, typename T>
        void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
            int16x8_t sums[rows];
            int32x4_t finalsums[rows];

            for (unsigned int i=0; i<rows; i++) {
                sums[i] = vdupq_n_s16(0);
                finalsums[i] = vdupq_n_s32(0);
            }

            for (unsigned int i=0; i<blocks; i++) {
                for (unsigned int r=0; r<rows; r++) {
                    /* If we add too many blocks together, we run the risk
                     * of overflowing the intermediate 16-bit accumulators,
                     * especially in the unsigned case where we later treat
                     * the accumulator as signed.
                     *
                     * In that case, the maximum (signed) value is 16383,
                     * which is safe for 64 (unsigned) accumulations (255*64
                     * = 16,320).
                     *
                     * Each invocation of pairwise add adds 2 values to the
                     * accumulator - so in the unsigned case we can do 32
                     * adds before we need to reset the 16-bit accumulator
                     * by adding into the 32-bit 'finalsums'.
                     *
                     * We could do 64 adds in the signed case, but that
                     * optimization is not worth the complexity.
                     */
                    if (i > 0 && ((i & 31) == 0)) {
                        finalsums[r] = vpadalq_s16(finalsums[r], sums[r]);
                        sums[r] = vdupq_n_s16(0);
                    }
                    sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]);
                }
            }

            /* Handle the final masked read if needed. */
            if (mask_mode > 0) {
                for (unsigned int r=0; r<rows; r++) {
                    if (mask_mode == 1) {
                        sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    } else {
                        sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    }
                }
            }

            for (unsigned int i=0; i<rows; i++) {
                finalsums[i] = vpadalq_s16(finalsums[i], sums[i]);
            }

            int32x4_t t0, t1;
            int32x2_t t2;

            /* Result writeback - need to write back one value per row
             * processed.  Multiply all the final totals by -b_offset so
             * that the terms can simply be added in the requantize code. */
            switch (rows) {
                case 1:
                    /* If we only have one output, just use ADDV.  Multiply
                     * the offset into all four components separately so it
                     * can stay in the SIMD register file. */
                    t0 = vmulq_s32(finalsums[0], offset_mul);
                    *row_bias = vaddvq_s32(t0);
                    break;

                case 2:
                    /* For two outputs, two rounds of pairwise adds will
                     * generate the result in a 2-vector we can store in one
                     * go. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t0 = vpaddq_s32(t0, t0);
                    t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul));
                    vst1_s32(row_bias, t2);
                    break;

                case 3:
                    /* Three rows - need to store the low two words plus the odd value from lane 2 */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[2]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1_s32(row_bias, vget_low_s32(t0));
                    row_bias[2] = vgetq_lane_s32(t0, 2);
                    break;

                case 4:
                    /* Four rows (most common case) - reduce to a single
                     * vector with pairwise adds. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[3]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1q_s32(row_bias, t0);
                    break;

                default:
                    UNREACHABLE("Impossible.");
            }
        }

        row_sum_helpers(const Requantize32 &qp) : qp(qp) { }
    };

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) {
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr)));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) {
        return vpadalq_s8(sum, vld1q_s8(ptr));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
        v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
        v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }
}

template<typename T>
void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    /* If the 'b' offset is zero, just skip this entirely. */
    if (qp.b_offset == 0) {
        memset(row_bias, 0, height * sizeof(int32_t));
        return;
    }

    row_sum_helpers thehelpers(qp);

    const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);

    /* Work out how many full vectors of 16 bytes we will read, and how many
     * odd bytes at the end */
    unsigned int blocks = (width / 16);
    const unsigned int odds = width % 16;

    /* Generate a mask to use on the last iteration, if necessary. */
    uint64x2_t mask;
    unsigned int mask_mode = 0;

    if (odds > 0 && odds <= 8) {
        /* 1-8 odds: mask in the low lane, 0 in the top */
        uint64_t maskval = (~0ULL) >> (8 * (8-odds));
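        /* For example, odds == 3 gives maskval == 0x0000000000ffffff,
         * keeping only the three wanted bytes of the 8-byte read. */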

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);

        mask_mode = 1;
    } else if (odds > 8) {
        /* 9-15 odds: mask in the top lane, all 1s in the bottom. */
        uint64_t maskval = (~0ULL) >> (8 * (16-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);

        mask_mode = 2;
    }

    for (unsigned int row=0; row<height; row+=4) {
        switch(height-row) {
            default:
            case 4:
                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 3:
                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 2:
                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 1:
                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
        }
    }
}

/* Instantiate the two versions for uint8_t and int8_t. */
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);

template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);

template<unsigned int active_rows>
inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) {
    uint8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_u8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_u8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1])));
    sums_16b[1]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3])));
    // Two adds for the high pairs
    sums_16b[2]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1]));
    sums_16b[3]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3]));

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

template<unsigned int active_rows>
inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) {
    int8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_s8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_s8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1]));
    sums_16b[1]=vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3]));
    // Two adds for the high pairs
    sums_16b[2]=vaddl_high_s8(inputs[0], inputs[1]);
    sums_16b[3]=vaddl_high_s8(inputs[2], inputs[3]);

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

/* "first_col" parameter is used to offset the read into the qp.bias array,
 * in cases where we are not computing the first columns of the output (i.e.
 * in multithreaded cases where we divide columns across threads) */
template<typename T>
void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) {
    /* Only actually add up the columns if a_offset is non-zero. */
    if (qp.a_offset != 0) {
        memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));

        for (unsigned int row=0; row<height; row+=4) {
            unsigned int numrows=std::min(height-row, 4u);

            for (unsigned int col=0; col<width; col+=16) {
                unsigned int numcols=std::min(width-col, 16u);

                if (numcols==16) {
                    switch(numrows) {
                        case 1:
                            add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 2:
                            add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 3:
                            add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 4:
                            add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        default:
                            UNREACHABLE("Impossible.");
                    }
                } else {
                    for (; col<width; col++) {
                        int32_t sum=0;
                        for (unsigned int r=0; r<numrows; r++) {
                            sum += input[(row + r)*in_stride + col];
                        }
                        col_bias[col] += sum;
                    }
                }
            }
        }
    }

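    /* Finish the column bias off: fold in the constant term
     * (a_offset * b_offset * depth), multiply the raw column sum by
     * -a_offset, and add any externally supplied bias, so that the
     * requantize code above only needs to add row_bias and col_bias onto
     * each accumulator. */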
    for (unsigned int col=0; col<width; col++) {
        int32_t result = col_bias[col];

        result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset);

        if (qp.bias != nullptr) {
            result += qp.bias[multi * qp.bias_multi_stride + col + first_col];
        }

        col_bias[col] = result;
    }
}

template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);

} // namespace arm_gemm

#endif // __aarch64__