/*
 * Copyright (c) 2019, 2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"
#include "utils.hpp"

#include <arm_neon.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>

namespace arm_gemm {

namespace {

/* Requantize a block of data, using the requantize parameters in 'qp'.
 *
 * row_bias and col_bias are assumed to be precomputed values which include
 * any externally supplied bias, plus the row/column contribution sums, plus
 * the overall constant offset (A_offset * B_offset * depth).
 *
 * Note that this function works equally well for uint8_t output: just set
 * minval/maxval appropriately and cast the output pointer. It is the
 * caller's responsibility to ensure that minval/maxval are representable in
 * the target type - the downcast to (u)int8_t is done by simply extracting
 * the LSB.
 *
 * The 'do_shift_correction' template parameter turns on the correction
 * applied to negative values being shifted right to make sure they round
 * properly - if negative values are never output (e.g. fused ReLU) this is
 * unnecessary.
 *
 * The 'per_channel' template parameter selects between per channel and per
 * layer requantization - in the former case we need to load vectors of
 * shifts and multipliers for each column. A separate vector for each
 * column is set up in any case (and it is hoped that the compiler can elide
 * the needless movs in the per-layer case).
 *
 * The 'do_left_shift' template parameter enables the optional rounding left
 * shift which, when in use, is applied before the multiply.
 *
 * A scalar model of the per-element arithmetic is sketched after this
 * function for reference.
 */
template<bool do_shift_correction, bool per_channel, bool do_left_shift>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                             const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                             const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
    const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
    const int32x4_t v_right_shift = vdupq_n_s32(qp.per_layer_right_shift);
    const int32x4_t v_left_shift = vdupq_n_s32(qp.per_layer_left_shift);
    const int32x4_t v_minval = vdupq_n_s32(qp.minval);
    const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
    const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);

    /* To make sure we have plenty of accumulators, compute two rows at a
     * time. If the number of rows is odd, compute the bottom row twice to
     * avoid needing a duplicate codepath. */
    for (unsigned int row=0; row<height; row+=2) {
        /* Prefer to do 4 vectors (16 values) at once as this collapses
         * neatly to a single vector of output, failing that a vector at a
         * time and then the odd ones out at the end. */
        unsigned int blocks=(width / 16);
        unsigned int regs=(width % 16) / 4;
        unsigned int odds=(width % 4);

        const int32_t *colptr = col_bias;
        const int32_t *perch_mul_ptr = qp.per_channel_muls + start_col;
        const int32_t *perch_shift_ptr = qp.per_channel_right_shifts + start_col;
        const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;

        const int32_t *in_ptr = input + (row * in_stride);
        int8_t *out_ptr = output + (row * out_stride);
        int32_t row_sum = row_bias[row];

        const int32_t *in_ptr1;
        int8_t *out_ptr1;
        int32_t row_sum1;

        if (row == height-1) {
            in_ptr1 = in_ptr;
            out_ptr1 = out_ptr;
            row_sum1 = row_sum;
        } else {
            in_ptr1 = in_ptr + in_stride;
            out_ptr1 = out_ptr + out_stride;
            row_sum1 = row_bias[row+1];
        }

        const int32x4_t v_row_sum = vdupq_n_s32(row_sum);
        const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);

        while (blocks--) {
            int32x4_t v_mul0;
            int32x4_t v_mul1;
            int32x4_t v_mul2;
            int32x4_t v_mul3;

            int32x4_t v_shf0;
            int32x4_t v_shf1;
            int32x4_t v_shf2;
            int32x4_t v_shf3;

            int32x4_t v_shf0l;
            int32x4_t v_shf1l;
            int32x4_t v_shf2l;
            int32x4_t v_shf3l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
                v_mul3 = vld1q_s32(perch_mul_ptr + 12);
                perch_mul_ptr += 16;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                perch_shift_ptr += 16;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
                    v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
                    perch_shiftl_ptr += 16;
                }
            } else {
                v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
                v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
                v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
            }

            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            int32x4_t v_col3 = vld1q_s32(colptr + 12);
            colptr += 16;

            // Load input data (row 0);
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            int32x4_t v_in03 = vld1q_s32(in_ptr + 12);
            in_ptr += 16;

            // Load input data (row 1);
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12);
            in_ptr1 += 16;

            // Add on row bias and column bias
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);
            v_in03 = vaddq_s32(v_in03, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);
            v_in13 = vaddq_s32(v_in13, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);
            v_in03 = vaddq_s32(v_in03, v_col3);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            v_in13 = vaddq_s32(v_in13, v_col3);

            // Quantize

            // If a left shift is needed it needs to happen first.
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in01 = vrshlq_s32(v_in01, v_shf1l);
                v_in02 = vrshlq_s32(v_in02, v_shf2l);
                v_in03 = vrshlq_s32(v_in03, v_shf3l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
                v_in11 = vrshlq_s32(v_in11, v_shf1l);
                v_in12 = vrshlq_s32(v_in12, v_shf2l);
                v_in13 = vrshlq_s32(v_in13, v_shf3l);
            }

            // Multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
            v_in03 = vqrdmulhq_s32(v_in03, v_mul3);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
            v_in13 = vqrdmulhq_s32(v_in13, v_mul3);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
                int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
                int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);
                v_temp03 = vshrq_n_s32(v_temp03, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);
                v_temp13 = vshrq_n_s32(v_temp13, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);
                v_in03 = vqaddq_s32(v_in03, v_temp03);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
                v_in13 = vqaddq_s32(v_in13, v_temp13);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in01 = vrshlq_s32(v_in01, v_shf1);
            v_in02 = vrshlq_s32(v_in02, v_shf2);
            v_in03 = vrshlq_s32(v_in03, v_shf3);

            v_in10 = vrshlq_s32(v_in10, v_shf0);
            v_in11 = vrshlq_s32(v_in11, v_shf1);
            v_in12 = vrshlq_s32(v_in12, v_shf2);
            v_in13 = vrshlq_s32(v_in13, v_shf3);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);
            v_in03 = vaddq_s32(v_in03, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);
            v_in13 = vaddq_s32(v_in13, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);
            v_in03 = vmaxq_s32(v_in03, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);
            v_in13 = vmaxq_s32(v_in13, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);
            v_in03 = vminq_s32(v_in03, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            v_in13 = vminq_s32(v_in13, v_maxval);

            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_s8(out_ptr, v_uz0);
            out_ptr += 16;
            vst1q_s8(out_ptr1, v_uz1);
            out_ptr1 += 16;
        }
303
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000304 // We are often quantizing one block of interleaved kernel output at a time - these are three registers
305 // wide. Special case that here.
306 if (regs==3) {
307 regs -= 3;
308
309 int32x4_t v_mul0;
310 int32x4_t v_mul1;
311 int32x4_t v_mul2;
312
313 int32x4_t v_shf0;
314 int32x4_t v_shf1;
315 int32x4_t v_shf2;
316
317 int32x4_t v_shf0l;
318 int32x4_t v_shf1l;
319 int32x4_t v_shf2l;
320
321 if (per_channel) {
322 v_mul0 = vld1q_s32(perch_mul_ptr);
323 v_mul1 = vld1q_s32(perch_mul_ptr + 4);
324 v_mul2 = vld1q_s32(perch_mul_ptr + 8);
325 perch_mul_ptr += 12;
326
327 v_shf0 = vld1q_s32(perch_shift_ptr);
328 v_shf1 = vld1q_s32(perch_shift_ptr + 4);
329 v_shf2 = vld1q_s32(perch_shift_ptr + 8);
330 perch_shift_ptr += 12;
331
332 if (do_left_shift) {
333 v_shf0l = vld1q_s32(perch_shiftl_ptr);
334 v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
335 v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
336 perch_shiftl_ptr += 12;
337 }
338 } else {
339 v_mul0=v_mul1=v_mul2=v_mul;
340 v_shf0=v_shf1=v_shf2=v_right_shift;
341 v_shf0l=v_shf1l=v_shf2l=v_left_shift;
342 }
343
344 // Load column pointers
345 int32x4_t v_col0 = vld1q_s32(colptr);
346 int32x4_t v_col1 = vld1q_s32(colptr + 4);
347 int32x4_t v_col2 = vld1q_s32(colptr + 8);
348 colptr += 12;
349
350 // Load input data (row 0);
351 int32x4_t v_in00 = vld1q_s32(in_ptr);
352 int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
353 int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
354 in_ptr += 12;
355
356 // Load input data (row 1);
357 int32x4_t v_in10 = vld1q_s32(in_ptr1);
358 int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
359 int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
360 in_ptr1 += 12;
361
362 // Add on row bias and column bias
363 v_in00 = vaddq_s32(v_in00, v_row_sum);
364 v_in01 = vaddq_s32(v_in01, v_row_sum);
365 v_in02 = vaddq_s32(v_in02, v_row_sum);
366
367 v_in10 = vaddq_s32(v_in10, v_row_sum1);
368 v_in11 = vaddq_s32(v_in11, v_row_sum1);
369 v_in12 = vaddq_s32(v_in12, v_row_sum1);
370
371 v_in00 = vaddq_s32(v_in00, v_col0);
372 v_in01 = vaddq_s32(v_in01, v_col1);
373 v_in02 = vaddq_s32(v_in02, v_col2);
374
375 v_in10 = vaddq_s32(v_in10, v_col0);
376 v_in11 = vaddq_s32(v_in11, v_col1);
377 v_in12 = vaddq_s32(v_in12, v_col2);
378
379 // Quantize
380
381 // If a left shift is needed it needs to happen first.
382 if (do_left_shift) {
383 v_in00 = vrshlq_s32(v_in00, v_shf0l);
384 v_in01 = vrshlq_s32(v_in01, v_shf1l);
385 v_in02 = vrshlq_s32(v_in02, v_shf2l);
386
387 v_in10 = vrshlq_s32(v_in10, v_shf0l);
388 v_in11 = vrshlq_s32(v_in11, v_shf1l);
389 v_in12 = vrshlq_s32(v_in12, v_shf2l);
390 }
391
392 // Multiply
393 v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
394 v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
395 v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
396
397 v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
398 v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
399 v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
400
401 // Compute and add on corrective offset
402 if (do_shift_correction) {
403 int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
404 int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
405 int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
406
407 int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
408 int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
409 int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
410
411 v_temp00 = vshrq_n_s32(v_temp00, 31);
412 v_temp01 = vshrq_n_s32(v_temp01, 31);
413 v_temp02 = vshrq_n_s32(v_temp02, 31);
414
415 v_temp10 = vshrq_n_s32(v_temp10, 31);
416 v_temp11 = vshrq_n_s32(v_temp11, 31);
417 v_temp12 = vshrq_n_s32(v_temp12, 31);
418
419 v_in00 = vqaddq_s32(v_in00, v_temp00);
420 v_in01 = vqaddq_s32(v_in01, v_temp01);
421 v_in02 = vqaddq_s32(v_in02, v_temp02);
422
423 v_in10 = vqaddq_s32(v_in10, v_temp10);
424 v_in11 = vqaddq_s32(v_in11, v_temp11);
425 v_in12 = vqaddq_s32(v_in12, v_temp12);
426 }
427
428 v_in00 = vrshlq_s32(v_in00, v_shf0);
429 v_in01 = vrshlq_s32(v_in01, v_shf1);
430 v_in02 = vrshlq_s32(v_in02, v_shf2);
431
432 v_in10 = vrshlq_s32(v_in10, v_shf0);
433 v_in11 = vrshlq_s32(v_in11, v_shf1);
434 v_in12 = vrshlq_s32(v_in12, v_shf2);
435
436 v_in00 = vaddq_s32(v_in00, v_c_offset);
437 v_in01 = vaddq_s32(v_in01, v_c_offset);
438 v_in02 = vaddq_s32(v_in02, v_c_offset);
439
440 v_in10 = vaddq_s32(v_in10, v_c_offset);
441 v_in11 = vaddq_s32(v_in11, v_c_offset);
442 v_in12 = vaddq_s32(v_in12, v_c_offset);
443
444 v_in00 = vmaxq_s32(v_in00, v_minval);
445 v_in01 = vmaxq_s32(v_in01, v_minval);
446 v_in02 = vmaxq_s32(v_in02, v_minval);
447
448 v_in10 = vmaxq_s32(v_in10, v_minval);
449 v_in11 = vmaxq_s32(v_in11, v_minval);
450 v_in12 = vmaxq_s32(v_in12, v_minval);
451
452 v_in00 = vminq_s32(v_in00, v_maxval);
453 v_in01 = vminq_s32(v_in01, v_maxval);
454 v_in02 = vminq_s32(v_in02, v_maxval);
455
456 v_in10 = vminq_s32(v_in10, v_maxval);
457 v_in11 = vminq_s32(v_in11, v_maxval);
458 v_in12 = vminq_s32(v_in12, v_maxval);
459
460 int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
461 int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in02));
462
463 int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
464 int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in12));
465
466 int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
467 int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));
468
469 vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr), vreinterpretq_s64_s8(v_uz0), 0);
470 vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr + 8), vreinterpretq_s32_s8(v_uz0), 2);
471 out_ptr += 12;
472 vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr1), vreinterpretq_s64_s8(v_uz1), 0);
473 vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1 + 8), vreinterpretq_s32_s8(v_uz1), 2);
474 out_ptr1 += 12;
475 }
476
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100477 while (regs--) {
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000478 int32x4_t v_mul0;
479 int32x4_t v_shf0;
morgolock0bc80da2020-08-10 16:44:18 +0100480 int32x4_t v_shf0l;
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000481
482 if (per_channel) {
483 v_mul0 = vld1q_s32(perch_mul_ptr);
484 perch_mul_ptr += 4;
485
486 v_shf0 = vld1q_s32(perch_shift_ptr);
487 perch_shift_ptr += 4;
morgolock0bc80da2020-08-10 16:44:18 +0100488
489 if (do_left_shift) {
490 v_shf0l = vld1q_s32(perch_shiftl_ptr);
491 perch_shiftl_ptr += 4;
492 }
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000493 } else {
494 v_mul0=v_mul;
morgolock0bc80da2020-08-10 16:44:18 +0100495 v_shf0=v_right_shift;
496 v_shf0l=v_left_shift;
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000497 }
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100498 // Load column pointers
499 int32x4_t v_col0 = vld1q_s32(colptr);
500 colptr += 4;
501
502 // Load input data (row 0);
503 int32x4_t v_in00 = vld1q_s32(in_ptr);
504 in_ptr += 4;
505
506 // Load input data (row 1);
507 int32x4_t v_in10 = vld1q_s32(in_ptr1);
508 in_ptr1 += 4;
509
510 // Add on row sum and bias constant
511 v_in00 = vaddq_s32(v_in00, v_row_sum);
512
513 v_in10 = vaddq_s32(v_in10, v_row_sum1);
514
515 // Subtract col sum * a_offset
516 v_in00 = vaddq_s32(v_in00, v_col0);
517
518 v_in10 = vaddq_s32(v_in10, v_col0);
519
morgolock0bc80da2020-08-10 16:44:18 +0100520 // Quantize - start with (optional) left shift
521 if (do_left_shift) {
522 v_in00 = vrshlq_s32(v_in00, v_shf0l);
523
524 v_in10 = vrshlq_s32(v_in10, v_shf0l);
525 }
526
527 // Then multiply
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000528 v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100529
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000530 v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100531
532 // Compute and add on corrective offset
533 if (do_shift_correction) {
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000534 int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100535
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000536 int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100537
538 v_temp00 = vshrq_n_s32(v_temp00, 31);
539
540 v_temp10 = vshrq_n_s32(v_temp10, 31);
541
542 v_in00 = vqaddq_s32(v_in00, v_temp00);
543
544 v_in10 = vqaddq_s32(v_in10, v_temp10);
545 }
546
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000547 v_in00 = vrshlq_s32(v_in00, v_shf0);
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100548
Michalis Spyrou71ac9032019-11-14 14:31:44 +0000549 v_in10 = vrshlq_s32(v_in10, v_shf0);
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100550
551 v_in00 = vaddq_s32(v_in00, v_c_offset);
552
553 v_in10 = vaddq_s32(v_in10, v_c_offset);
554
555 v_in00 = vmaxq_s32(v_in00, v_minval);
556
557 v_in10 = vmaxq_s32(v_in10, v_minval);
558
559 v_in00 = vminq_s32(v_in00, v_maxval);
560
561 v_in10 = vminq_s32(v_in10, v_maxval);
562
563 int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10));
564
565 int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00));
566
567 vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), vreinterpretq_s32_s8(v_uz0), 0);
568 out_ptr += 4;
569 vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1);
570 out_ptr1 += 4;
571 }

        if (odds) {
            int32x4_t v_col0 = vdupq_n_s32(0);
            int32x4_t v_in00 = vdupq_n_s32(0);
            int32x4_t v_in10 = vdupq_n_s32(0);
            int32x4_t v_mul0 = vdupq_n_s32(0);
            int32x4_t v_shf0 = vdupq_n_s32(0);
            int32x4_t v_shf0l = vdupq_n_s32(0);

            if (!per_channel) {
                v_mul0 = v_mul;
                v_shf0 = v_right_shift;
                v_shf0l = v_left_shift;
            }

            do {
                v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
                    }
                }
                if (odds == 1) { break; }

                v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
                    }
                }
                if (odds == 2) { break; }

                v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
                    }
                }
            } while (0);

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with (optional) left shift
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }

            // Then multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            do {
                vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0);
                vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0);

                if (odds==1) { break; }

                vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4);
                vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4);

                if (odds==2) { break; }

                vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8);
                vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8);
            } while(0);
        }
    }
}
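
/* Illustrative scalar model of the arithmetic the vectorized kernel above
 * applies to each element - a sketch for documentation purposes only (it is
 * not called anywhere, the function name is hypothetical, and it assumes the
 * usual arithmetic behaviour of '>>' on negative values). 'value' is the
 * int32 accumulator with row_bias and col_bias already added; following the
 * SRSHL-based convention used above, 'right_shift' is stored as a
 * non-positive number and 'left_shift' as a non-negative one. Saturation
 * corner cases (e.g. INT32_MIN) are glossed over. */
inline int8_t requantize_scalar_sketch(int32_t value, int32_t mul, int32_t left_shift, int32_t right_shift,
                                       int32_t c_offset, int32_t minval, int32_t maxval,
                                       bool do_shift_correction) {
    // Optional left shift, applied before the multiply (as in the kernel).
    if (left_shift > 0) {
        value = static_cast<int32_t>(static_cast<int64_t>(value) << left_shift);
    }

    // Rounding doubling multiply-high: the scalar equivalent of SQRDMULH.
    const int64_t prod = 2 * static_cast<int64_t>(value) * static_cast<int64_t>(mul) + (int64_t(1) << 30);
    value = static_cast<int32_t>(prod >> 31);

    const int32_t shift = -right_shift;

    // The vector code derives this from the sign bits of the value and the
    // (negative) shift: negative values get 1 subtracted so that the rounding
    // shift below rounds halves away from zero rather than towards +infinity.
    if (do_shift_correction && shift > 0 && value < 0) {
        value -= 1;
    }

    // Rounding right shift (the SRSHL behaviour for a negative shift operand).
    if (shift > 0) {
        value = static_cast<int32_t>((static_cast<int64_t>(value) + (int64_t(1) << (shift - 1))) >> shift);
    }

    // Add the output offset and clamp to the representable range.
    value += c_offset;
    value = std::max(minval, std::min(maxval, value));

    // The downcast to (u)int8_t simply keeps the least significant byte.
    return static_cast<int8_t>(value);
}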

} // anonymous namespace

template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                         const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
    if (qp.per_channel_requant) {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                        reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                        reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                        reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                        reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    } else {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                        reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                        reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                        reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                        reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    }
}

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);

/*
 * Routine (and helpers) to compute row sums needed for offset correction.
 *
 * This is often needed for a lot of short rows (e.g. Syrax 5 - 6400 rows
 * of length 27), therefore it's important not to sacrifice performance on
 * odd length rows.
 *
 * To minimize performance loss in these cases, this routine will overread
 * by up to 7 bytes.
 *
 * This is handled via "mask" and "mask mode" parameters to the inner
 * routines; mask mode == 1 indicates that there are between 1 and 8 bytes
 * (inclusive) needed at the end; in these cases we always read 8 bytes.
 * mask mode == 2 indicates that there are between 9 and 15 bytes needed at
 * the end, and in this case we always read 16 bytes. In both cases the
 * 'mask' vector is set up so that the read value can be masked off to clear
 * the overread lanes. This is handled by 'accumulate_masked_8' and
 * 'accumulate_masked_16' below.
 *
 * This routine is templated on the type to be accumulated, because the
 * innermost instruction used needs to be of the correct signedness.
 * However, beyond this point we always use signed values in both cases.
 * The instructions that need to be different are therefore wrapped in
 * helper functions below.
 *
 * The general strategy used is to load vectors of 16 bytes and accumulate
 * (using uadalp/sadalp or AArch32 equivalents) into 8x16-bit accumulators.
 * These are then reduced (using uadalp/sadalp again) into 4x32-bit
 * accumulators. The 4 accumulators for up to 4 rows being processed are
 * then added together into a single output vector using pairwise adds.
 *
 * This reduction from the 8x16-bit into the 4x32-bit accumulators needs to
 * occur before the 16-bit accumulators can overflow - which is every 32
 * iterations (512 total bytes processed). This is explained more below.
 */
namespace {
    struct row_sum_helpers {
        const Requantize32 &qp;

        /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
        template<typename T>
        inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

        /* Load a full 16 byte vector, but mask before accumulation (see above). */
        template<typename T>
        inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* Load 8 bytes and mask before accumulation. */
        template<typename T>
        inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* This function does the actual work for up to 4 rows at a time.
         * It's pulled out so we can template on the row count to generate
         * the 4 different cases. 4 rows are computed at a time as this
         * reduces to a single vector write. */
        template<unsigned int rows, typename T>
        void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
            int16x8_t sums[rows];
            int32x4_t finalsums[rows];

            for (unsigned int i=0; i<rows; i++) {
                sums[i] = vdupq_n_s16(0);
                finalsums[i] = vdupq_n_s32(0);
            }

            for (unsigned int i=0; i<blocks; i++) {
                for (unsigned int r=0; r<rows; r++) {
                    /* If we add too many blocks together, we run the risk
                     * of overflowing the intermediate 16-bit accumulators,
                     * especially in the unsigned case where we later treat
                     * the accumulator as signed.
                     *
                     * In that case, the maximum (signed) value is 16383,
                     * which is safe for 64 (unsigned) accumulations (255*64
                     * = 16,320).
                     *
                     * Each invocation of pairwise add adds 2 values to the
                     * accumulator - so in the unsigned case we can do 32
                     * adds before we need to reset the 16-bit accumulator
                     * by adding into the 32-bit 'finalsums'.
                     *
                     * We could do 64 adds in the signed case, but that
                     * optimization is not worth the complexity.  (A
                     * standalone sketch of this accumulate-and-spill cadence
                     * follows after these helpers.)
                     */
                    if (i > 0 && ((i & 31) == 0)) {
                        finalsums[r] = vpadalq_s16(finalsums[r], sums[r]);
                        sums[r] = vdupq_n_s16(0);
                    }
                    sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]);
                }
            }

            /* Handle the final masked read if needed. */
            if (mask_mode > 0) {
                for (unsigned int r=0; r<rows; r++) {
                    if (mask_mode == 1) {
                        sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    } else {
                        sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    }
                }
            }

            for (unsigned int i=0; i<rows; i++) {
                finalsums[i] = vpadalq_s16(finalsums[i], sums[i]);
            }

            int32x4_t t0, t1;
            int32x2_t t2;

            /* Result writeback - need to write back one value per row
             * processed. Multiply all the final totals by -b_offset so
             * that the terms can simply be added in the requantize code.
             * */
            switch (rows) {
                case 1:
                    /* If we only have one output, just use ADDV. Multiply
                     * the offset into all four components separately so it
                     * can stay in the SIMD register file. */
                    t0 = vmulq_s32(finalsums[0], offset_mul);
                    *row_bias = vaddvq_s32(t0);
                    break;

                case 2:
                    /* For two outputs, two rounds of pairwise adds will
                     * generate the result in a 2-vector we can store in one
                     * go. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t0 = vpaddq_s32(t0, t0);
                    t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul));
                    vst1_s32(row_bias, t2);
                    break;

                case 3:
                    /* Three rows - need to store the low two words plus the odd value from lane 2 */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[2]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1_s32(row_bias, vget_low_s32(t0));
                    row_bias[2] = vgetq_lane_s32(t0, 2);
                    break;

                case 4:
                    /* Four rows (most common case) - reduce to a single
                     * vector with pairwise adds. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[3]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1q_s32(row_bias, t0);
                    break;

                default:
                    UNREACHABLE("Impossible.");
            }
        }

        row_sum_helpers(const Requantize32 &qp) : qp(qp) { }
    };

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) {
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr)));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) {
        return vpadalq_s8(sum, vld1q_s8(ptr));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
        v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
        v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }
}
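
/* Illustrative sketch of the accumulate-and-spill cadence used in
 * compute_some_rows() above, reduced to a single uint8_t row - documentation
 * only, with a hypothetical name. Each vpadalq_u8 adds two bytes (at most
 * 2 * 255 = 510) into every 16-bit lane, so the 16-bit accumulator is folded
 * into a 32-bit accumulator every 32 blocks to stay comfortably inside the
 * signed 16-bit range. */
inline int32_t row_sum_spill_sketch(const uint8_t *row, unsigned int blocks) {
    int16x8_t sum16 = vdupq_n_s16(0);
    int32x4_t sum32 = vdupq_n_s32(0);

    for (unsigned int i = 0; i < blocks; i++) {
        if (i > 0 && ((i & 31) == 0)) {
            // Spill the 16-bit lanes into the 32-bit accumulator and restart them.
            sum32 = vpadalq_s16(sum32, sum16);
            sum16 = vdupq_n_s16(0);
        }

        // Widening pairwise accumulate of one 16-byte block.
        sum16 = vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum16), vld1q_u8(row + i * 16)));
    }

    // Fold in whatever is left in the 16-bit lanes, then reduce across lanes.
    sum32 = vpadalq_s16(sum32, sum16);
    return vaddvq_s32(sum32);
}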

template<typename T>
void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    /* If the 'b' offset is zero, just skip this entirely. */
    if (qp.b_offset == 0) {
        memset(row_bias, 0, height * sizeof(int32_t));
        return;
    }

    row_sum_helpers thehelpers(qp);

    const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);

    /* Work out how many full vectors of 16 bytes we will read, and how many
     * odd bytes at the end */
    unsigned int blocks = (width / 16);
    const unsigned int odds = width % 16;

    /* Generate a mask to use on the last iteration, if necessary. */
    uint64x2_t mask;
    unsigned int mask_mode = 0;

    if (odds > 0 && odds <= 8) {
        /* 1-8 odds: mask in the low lane, 0 in the top */
        uint64_t maskval = (~0ULL) >> (8 * (8-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);

        mask_mode = 1;
    } else if (odds > 8) {
        /* 9-15 odds: mask in the top lane, all 1s in the bottom. */
        uint64_t maskval = (~0ULL) >> (8 * (16-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);

        mask_mode = 2;
    }

    for (unsigned int row=0; row<height; row+=4) {
        switch(height-row) {
            default:
            case 4:
                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 3:
                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 2:
                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 1:
                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
        }
    }
}

/* Instantiate the two versions for uint8_t and int8_t. */
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);

template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);

template<unsigned int active_rows>
inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) {
    uint8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_u8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_u8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1])));
    sums_16b[1]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3])));
    // Two adds for the high pairs
    sums_16b[2]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1]));
    sums_16b[3]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3]));

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

template<unsigned int active_rows>
inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) {
    int8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_s8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_s8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1]));
    sums_16b[1]=vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3]));
    // Two adds for the high pairs
    sums_16b[2]=vaddl_high_s8(inputs[0], inputs[1]);
    sums_16b[3]=vaddl_high_s8(inputs[2], inputs[3]);

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

/* "first_col" parameter is used to offset the read into the qp.bias array,
 * in cases where we are not computing the first columns of the output (i.e.
 * in multithreaded cases where we divide columns across threads) */
template<typename T>
void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) {
    /* Only actually add up the columns if a_offset is non-zero. */
    if (qp.a_offset != 0) {
        memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));

        for (unsigned int row=0; row<height; row+=4) {
            unsigned int numrows=std::min(height-row, 4u);

            for (unsigned int col=0; col<width; col+=16) {
                unsigned int numcols=std::min(width-col, 16u);

                if (numcols==16) {
                    switch(numrows) {
                        case 1:
                            add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 2:
                            add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 3:
                            add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 4:
                            add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        default:
                            UNREACHABLE("Impossible.");
                    }
                } else {
                    for (; col<width; col++) {
                        int32_t sum=0;
                        for (unsigned int r=0; r<numrows; r++) {
                            sum += input[(row + r)*in_stride + col];
                        }
                        col_bias[col] += sum;
                    }
                }
            }
        }
    }

    for (unsigned int col=0; col<width; col++) {
        int32_t result = col_bias[col];

        /* Fold in the constant offset correction; together with the
         * -b_offset * row_sum term from compute_row_sums() this completes
         * the zero-point expansion (see the sketch after the instantiations
         * below). */
        result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset);

        if (qp.bias != nullptr) {
            result += qp.bias[multi * qp.bias_multi_stride + col + first_col];
        }

        col_bias[col] = result;
    }
}

template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
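
/* Illustrative sketch of why the row and column corrections above reproduce
 * the full zero-point correction. With the sign convention implied by the
 * code (offsets subtracted from the stored quantized values), for one output
 * element:
 *
 *   sum_k (a[k] - a_offset) * (b[k] - b_offset)
 *     = sum_k a[k]*b[k] - a_offset*col_sum - b_offset*row_sum
 *       + a_offset*b_offset*depth
 *
 * which is the raw int32 accumulator plus exactly the col_bias and row_bias
 * terms computed by compute_col_sums() and compute_row_sums(). The helper
 * below is documentation only (hypothetical name, plain scalar loops) and
 * simply checks the identity for one row/column pair. */
inline bool offset_correction_identity_sketch(const int32_t *a, const int32_t *b, unsigned int depth,
                                              int32_t a_offset, int32_t b_offset) {
    int32_t direct = 0;    // offset-corrected dot product
    int32_t raw = 0;       // accumulator over the stored values
    int32_t row_sum = 0;   // sum of this row of A
    int32_t col_sum = 0;   // sum of this column of B

    for (unsigned int k = 0; k < depth; k++) {
        direct += (a[k] - a_offset) * (b[k] - b_offset);
        raw += a[k] * b[k];
        row_sum += a[k];
        col_sum += b[k];
    }

    const int32_t row_bias = -b_offset * row_sum;
    const int32_t col_bias = (a_offset * b_offset * static_cast<int32_t>(depth)) - (a_offset * col_sum);

    return direct == (raw + row_bias + col_bias);
}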

void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height,
                         const int32_t* in_ptr, unsigned int in_stride, float *out_ptr, unsigned int out_stride,
                         const float* bias_ptr, bool accumulate, const Activation &act)
{
    const float32x4_t vscale = vdupq_n_f32(qp.scale);
    float maxval = std::numeric_limits<float>::infinity();
    float minval = -std::numeric_limits<float>::infinity();

    switch(act.type) {
        default:
        case Activation::Type::None:
            break;
        case Activation::Type::BoundedReLU:
            maxval = static_cast<float>(act.param1);
            /* fall through */
        case Activation::Type::ReLU:
            minval = 0;
            break;
    }

    const float32x4_t vmin = vdupq_n_f32(minval);
    const float32x4_t vmax = vdupq_n_f32(maxval);

    for(unsigned int row=0; row<height; row++) {
        auto row_in_ptr = in_ptr + (row * in_stride);
        auto row_out_ptr = out_ptr + (row * out_stride);
        unsigned int col=0;
        if (width >= 4) {
            for(; col <= (width - 4); col+= 4) {
                const int32x4_t vin = vld1q_s32(row_in_ptr + col);
                float32x4_t vdeq = vmulq_f32(vcvtq_f32_s32(vin), vscale);
                if(bias_ptr) {
                    const float32x4_t bin = vld1q_f32(bias_ptr + col);
                    vdeq = vaddq_f32(vdeq, bin);
                }
                if(accumulate) {
                    vdeq = vaddq_f32(vdeq, vld1q_f32(row_out_ptr + col));
                }
                vdeq = vminq_f32(vmaxq_f32(vdeq, vmin), vmax);
                vst1q_f32(reinterpret_cast<float *>(row_out_ptr + col), vdeq);
            }
        }
        // left-over elements
        for(; col < width; ++col) {
            const int32_t val = *(row_in_ptr + col);
            float res = static_cast<float>(val * qp.scale);
            if(bias_ptr) {
                res += static_cast<float>(*(bias_ptr + col));
            }
            if(accumulate) {
                res += *(row_out_ptr + col);
            }
            res = std::min(std::max(res, minval), maxval);
            *(row_out_ptr + col) = res;
        }
    }
}

} // namespace arm_gemm

#endif // __aarch64__