/*
 * Copyright (c) 2019 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"
#include "utils.hpp"

#include <arm_neon.h>

namespace arm_gemm {

namespace {

/* Requantize a block of data, using the requantize parameters in 'qp'.
 *
 * row_bias and col_bias are assumed to be precomputed values which include
 * any externally supplied bias, plus the row/column contribution sums, plus
 * the overall constant offset (A_offset * B_offset * depth).
 *
 * Note that this function works equally well for uint8_t output: just set
 * minval/maxval appropriately and cast the output pointer.  It is the
 * caller's responsibility to ensure that minval/maxval are representable in
 * the target type - the downcast to (u)int8_t is done by simply extracting
 * the LSB.
 *
 * The 'do_shift_correction' template parameter turns on the correction
 * applied to negative values being shifted right to make sure they round
 * properly - if negative values are never output (e.g. fused ReLU) this is
 * unnecessary.
 *
 * The 'per_channel' template parameter selects between per channel and per
 * layer requantization - in the former case we need to load vectors of
 * shifts and multipliers for each column.  A separate vector for each
 * column is set up in any case (and it is hoped that the compiler can
 * elide the needless movs in the per-layer case).
 */
template<bool do_shift_correction, bool per_channel, bool do_left_shift>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                             const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                             const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
    const int32x4_t v_mul         = vdupq_n_s32(qp.per_layer_mul);
    const int32x4_t v_right_shift = vdupq_n_s32(qp.per_layer_right_shift);
    const int32x4_t v_left_shift  = vdupq_n_s32(qp.per_layer_left_shift);
    const int32x4_t v_minval      = vdupq_n_s32(qp.minval);
    const int32x4_t v_maxval      = vdupq_n_s32(qp.maxval);
    const int32x4_t v_c_offset    = vdupq_n_s32(qp.c_offset);

    /* To make sure we have plenty of accumulators, compute two rows at a
     * time.  If the number of rows is odd, compute the bottom row twice to
     * avoid needing a duplicate codepath. */
    for (unsigned int row=0; row<height; row+=2) {
        /* Prefer to do 4 vectors (16 values) at once as this collapses
         * neatly to a single vector of output, failing that a vector at a
         * time and then the odd ones out at the end. */
        unsigned int blocks=(width / 16);
        unsigned int regs=(width % 16) / 4;
        unsigned int odds=(width % 4);

        const int32_t *colptr = col_bias;
        const int32_t *perch_mul_ptr    = qp.per_channel_muls + start_col;
        const int32_t *perch_shift_ptr  = qp.per_channel_right_shifts + start_col;
        const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;

        const int32_t *in_ptr = input + (row * in_stride);
        int8_t *out_ptr = output + (row * out_stride);
        int32_t row_sum = row_bias[row];

        const int32_t *in_ptr1;
        int8_t *out_ptr1;
        int32_t row_sum1;

        if (row == height-1) {
            in_ptr1  = in_ptr;
            out_ptr1 = out_ptr;
            row_sum1 = row_sum;
        } else {
            in_ptr1  = in_ptr + in_stride;
            out_ptr1 = out_ptr + out_stride;
            row_sum1 = row_bias[row+1];
        }

        const int32x4_t v_row_sum  = vdupq_n_s32(row_sum);
        const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);

        while (blocks--) {
            int32x4_t v_mul0;
            int32x4_t v_mul1;
            int32x4_t v_mul2;
            int32x4_t v_mul3;

            int32x4_t v_shf0;
            int32x4_t v_shf1;
            int32x4_t v_shf2;
            int32x4_t v_shf3;

            int32x4_t v_shf0l;
            int32x4_t v_shf1l;
            int32x4_t v_shf2l;
            int32x4_t v_shf3l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
                v_mul3 = vld1q_s32(perch_mul_ptr + 12);
                perch_mul_ptr += 16;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                perch_shift_ptr += 16;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
                    v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
                }
            } else {
                v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
                v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
                v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
            }

            // Load per-column bias values
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            int32x4_t v_col3 = vld1q_s32(colptr + 12);
            colptr += 16;

            // Load input data (row 0)
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            int32x4_t v_in03 = vld1q_s32(in_ptr + 12);
            in_ptr += 16;

            // Load input data (row 1)
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12);
            in_ptr1 += 16;

            // Add on row bias and column bias
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);
            v_in03 = vaddq_s32(v_in03, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);
            v_in13 = vaddq_s32(v_in13, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);
            v_in03 = vaddq_s32(v_in03, v_col3);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            v_in13 = vaddq_s32(v_in13, v_col3);

            // Quantize

            // If a left shift is needed it needs to happen first.
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in01 = vrshlq_s32(v_in01, v_shf1l);
                v_in02 = vrshlq_s32(v_in02, v_shf2l);
                v_in03 = vrshlq_s32(v_in03, v_shf3l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
                v_in11 = vrshlq_s32(v_in11, v_shf1l);
                v_in12 = vrshlq_s32(v_in12, v_shf2l);
                v_in13 = vrshlq_s32(v_in13, v_shf3l);
            }

            // Multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
            v_in03 = vqrdmulhq_s32(v_in03, v_mul3);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
            v_in13 = vqrdmulhq_s32(v_in13, v_mul3);

            // Compute and add on corrective offset
            if (do_shift_correction) {
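                // v_shf holds right shifts as negative values for use with the
                // rounding shift (VRSHL) below.  (v_in & v_shf) has its sign bit
                // set only when the value is negative *and* a right shift is in
                // use; arithmetic-shifting that by 31 yields -1 (else 0), and the
                // saturating add nudges those lanes down by one so that negative
                // values round ties away from zero, matching the positive case.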
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
                int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
                int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);
                v_temp03 = vshrq_n_s32(v_temp03, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);
                v_temp13 = vshrq_n_s32(v_temp13, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);
                v_in03 = vqaddq_s32(v_in03, v_temp03);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
                v_in13 = vqaddq_s32(v_in13, v_temp13);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in01 = vrshlq_s32(v_in01, v_shf1);
            v_in02 = vrshlq_s32(v_in02, v_shf2);
            v_in03 = vrshlq_s32(v_in03, v_shf3);

            v_in10 = vrshlq_s32(v_in10, v_shf0);
            v_in11 = vrshlq_s32(v_in11, v_shf1);
            v_in12 = vrshlq_s32(v_in12, v_shf2);
            v_in13 = vrshlq_s32(v_in13, v_shf3);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);
            v_in03 = vaddq_s32(v_in03, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);
            v_in13 = vaddq_s32(v_in13, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);
            v_in03 = vmaxq_s32(v_in03, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);
            v_in13 = vmaxq_s32(v_in13, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);
            v_in03 = vminq_s32(v_in03, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            v_in13 = vminq_s32(v_in13, v_maxval);

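            // Narrow the clamped 32-bit results to 8 bits: UZP1 on the 16-bit view
            // keeps the low half of each 32-bit lane, and a second UZP1 on the
            // 8-bit view keeps the low byte, packing each row into one vector.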
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_s8(out_ptr, v_uz0);
            out_ptr += 16;
            vst1q_s8(out_ptr1, v_uz1);
            out_ptr1 += 16;
        }

        while (regs--) {
            int32x4_t v_mul0;
            int32x4_t v_shf0;
            int32x4_t v_shf0l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                perch_mul_ptr += 4;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                perch_shift_ptr += 4;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    perch_shiftl_ptr += 4;
                }
            } else {
                v_mul0=v_mul;
                v_shf0=v_right_shift;
                v_shf0l=v_left_shift;
            }
            // Load per-column bias values
            int32x4_t v_col0 = vld1q_s32(colptr);
            colptr += 4;

            // Load input data (row 0)
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            in_ptr += 4;

            // Load input data (row 1)
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            in_ptr1 += 4;

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with (optional) left shift
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }

            // Then multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00));

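            // After the two UZP1 steps the four output bytes for row 0 sit in
            // 32-bit lane 0 of v_uz0 and row 1's sit in lane 1, so store one
            // 32-bit lane per row.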
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), vreinterpretq_s32_s8(v_uz0), 0);
            out_ptr += 4;
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1);
            out_ptr1 += 4;
        }

        if (odds) {
            int32x4_t v_col0 = vdupq_n_s32(0);
            int32x4_t v_in00 = vdupq_n_s32(0);
            int32x4_t v_in10 = vdupq_n_s32(0);
            int32x4_t v_mul0 = vdupq_n_s32(0);
            int32x4_t v_shf0 = vdupq_n_s32(0);
            int32x4_t v_shf0l = vdupq_n_s32(0);

            if (!per_channel) {
                v_mul0 = v_mul;
                v_shf0 = v_right_shift;
                v_shf0l = v_left_shift;
            }

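            // Load the 1-3 leftover columns one lane at a time; the do { } while (0)
            // construct lets us break out as soon as the right number of lanes has
            // been loaded.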
            do {
                v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
                    }
                }
                if (odds == 1) { break; }

                v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
                    }
                }
                if (odds == 2) { break; }

                v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
                    }
                }
            } while (0);

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with (optional) left shift
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }

            // Then multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

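            // Store the leftover bytes one at a time; byte lanes 0, 4 and 8 of the
            // 8-bit view are the low bytes of 32-bit lanes 0, 1 and 2 respectively.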
            do {
                vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0);
                vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0);

                if (odds==1) { break; }

                vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4);
                vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4);

                if (odds==2) { break; }

                vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8);
                vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8);
            } while(0);
        }
    }
}

} // anonymous namespace

template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                         const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
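    // Dispatch to the right template instantiation: per-channel vs per-layer
    // requantization, whether any left shift is needed, and whether the
    // negative-value shift correction can be skipped (when qp.minval >= qp.c_offset
    // any value that could be mis-rounded is clamped away anyway).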
    if (qp.per_channel_requant) {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    } else {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    }
}

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);

/*
 * Routine (and helpers) to compute row sums needed for offset correction.
 *
 * This is often needed for a lot of short rows (e.g. Syrax 5 - 6400 rows
 * of length 27), therefore it's important not to sacrifice performance on
 * odd length rows.
 *
 * To minimize performance loss in these cases, this routine will overread
 * by up to 7 bytes.
 *
 * This is handled via "mask" and "mask mode" parameters to the inner
 * routines; mask mode == 1 indicates that there are between 1 and 8 bytes
 * (inclusive) needed at the end; in these cases we always read 8 bytes.
 * mask mode == 2 indicates that there are between 9 and 15 bytes needed at
 * the end, and in this case we always read 16 bytes.  In both cases the
 * 'mask' vector is set up so that the read value can be masked off to clear
 * the overread lanes.  This is handled by 'accumulate_masked_8' and
 * 'accumulate_masked_16' below.
 *
 * This routine is templated on the type to be accumulated, because the
 * innermost instruction used needs to be of the correct signedness.
 * However, beyond this point we always use signed values in both cases.
 * The instructions that need to be different are therefore wrapped in
 * helper functions below.
 *
 * The general strategy used is to load vectors of 16 bytes and accumulate
 * (using uadalp/sadalp or AArch32 equivalents) into 8x16-bit accumulators.
 * These are then reduced (using uadalp/sadalp again) into 4x32-bit
 * accumulators.  The 4 accumulators for up to 4 rows being processed are
 * then added together into a single output vector using pairwise adds.
 *
 * This reduction from the 8x16-bit into the 4x32-bit accumulators needs to
 * occur before the 16-bit accumulators can overflow - which is every 32
 * iterations (512 total bytes processed).  This is explained more below.
 */
namespace {
    struct row_sum_helpers {
        const Requantize32 &qp;

        /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
        template<typename T>
        inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

        /* Load a full 16 byte vector, but mask before accumulation (see above). */
        template<typename T>
        inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* Load 8 bytes and mask before accumulation. */
        template<typename T>
        inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* This function does the actual work for up to 4 rows at a time.
         * It's pulled out so we can template on the row count to generate
         * the 4 different cases.  4 rows are computed at a time as this
         * reduces to a single vector write. */
        template<unsigned int rows, typename T>
        void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
            int16x8_t sums[rows];
            int32x4_t finalsums[rows];

            for (unsigned int i=0; i<rows; i++) {
                sums[i]      = vdupq_n_s16(0);
                finalsums[i] = vdupq_n_s32(0);
            }

            for (unsigned int i=0; i<blocks; i++) {
                for (unsigned int r=0; r<rows; r++) {
                    /* If we add too many blocks together, we run the risk
                     * of overflowing the intermediate 16-bit accumulators,
                     * especially in the unsigned case where we later treat
                     * the accumulator as signed.
                     *
                     * In that case, the maximum (signed) value is 16383,
                     * which is safe for 64 (unsigned) accumulations (255*64
                     * = 16,320).
                     *
                     * Each invocation of pairwise add adds 2 values to the
                     * accumulator - so in the unsigned case we can do 32
                     * adds before we need to reset the 16-bit accumulator
                     * by adding into the 32-bit 'finalsums'.
                     *
                     * We could do 64 adds in the signed case, but that
                     * optimization is not worth the complexity.
                     */
                    if (i > 0 && ((i & 31) == 0)) {
                        finalsums[r] = vpadalq_s16(finalsums[r], sums[r]);
                        sums[r] = vdupq_n_s16(0);
                    }
                    sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]);
                }
            }

            /* Handle the final masked read if needed. */
            if (mask_mode > 0) {
                for (unsigned int r=0; r<rows; r++) {
                    if (mask_mode == 1) {
                        sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    } else {
                        sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    }
                }
            }

            for (unsigned int i=0; i<rows; i++) {
                finalsums[i] = vpadalq_s16(finalsums[i], sums[i]);
            }

            int32x4_t t0, t1;
            int32x2_t t2;

            /* Result writeback - need to write back one value per row
             * processed.  Multiply all the final totals by -b_offset so
             * that the terms can simply be added in the requantize code. */
            switch (rows) {
                case 1:
                    /* If we only have one output, just use ADDV.  Multiply
                     * the offset into all four components separately so it
                     * can stay in the SIMD register file. */
                    t0 = vmulq_s32(finalsums[0], offset_mul);
                    *row_bias = vaddvq_s32(t0);
                    break;

                case 2:
                    /* For two outputs, two rounds of pairwise adds will
                     * generate the result in a 2-vector we can store in one
                     * go. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t0 = vpaddq_s32(t0, t0);
                    t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul));
                    vst1_s32(row_bias, t2);
                    break;

                case 3:
                    /* Three rows - need to store the low two words plus the odd value from lane 2 */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[2]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1_s32(row_bias, vget_low_s32(t0));
                    row_bias[2] = vgetq_lane_s32(t0, 2);
                    break;

                case 4:
                    /* Four rows (most common case) - reduce to a single
                     * vector with pairwise adds. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[3]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1q_s32(row_bias, t0);
                    break;

                default:
                    UNREACHABLE("Impossible.");
            }
        }

        row_sum_helpers(const Requantize32 &qp) : qp(qp) { }
    };

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) {
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr)));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) {
        return vpadalq_s8(sum, vld1q_s8(ptr));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
        v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
        v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }
}

template<typename T>
void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    /* If the 'b' offset is zero, just skip this entirely. */
    if (qp.b_offset == 0) {
        memset(row_bias, 0, height * sizeof(int32_t));
        return;
    }

    row_sum_helpers thehelpers(qp);

    const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);

    /* Work out how many full vectors of 16 bytes we will read, and how many
     * odd bytes at the end */
    unsigned int blocks = (width / 16);
    const unsigned int odds = width % 16;

    /* Generate a mask to use on the last iteration, if necessary. */
    uint64x2_t mask;
    unsigned int mask_mode = 0;

    if (odds > 0 && odds <= 8) {
        /* 1-8 odds: mask in the low lane, 0 in the top */
        uint64_t maskval = (~0ULL) >> (8 * (8-odds));
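        /* e.g. odds==3 keeps only the low 24 bits (3 bytes) set, so the
         * overread lanes are zeroed before accumulation. */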

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);

        mask_mode = 1;
    } else if (odds > 8) {
        /* 9-15 odds: mask in the top lane, all 1s in the bottom. */
        uint64_t maskval = (~0ULL) >> (8 * (16-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);

        mask_mode = 2;
    }

    for (unsigned int row=0; row<height; row+=4) {
        switch(height-row) {
            default:
            case 4:
                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 3:
                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 2:
                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 1:
                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
        }
    }
}

/* Instantiate the two versions for uint8_t and int8_t. */
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);

template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);

template<unsigned int active_rows>
inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) {
    uint8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_u8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_u8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1])));
    sums_16b[1]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3])));
    // Two adds for the high pairs
    sums_16b[2]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1]));
    sums_16b[3]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3]));

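    // Widen to 32 bits: after these adds each of the 16 columns has its own
    // 32-bit total for the (up to) 4 rows processed in this call.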
    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

template<unsigned int active_rows>
inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) {
    int8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_s8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_s8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1]));
    sums_16b[1]=vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3]));
    // Two adds for the high pairs
    sums_16b[2]=vaddl_high_s8(inputs[0], inputs[1]);
    sums_16b[3]=vaddl_high_s8(inputs[2], inputs[3]);

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

/* "first_col" parameter is used to offset the read into the qp.bias array,
 * in cases where we are not computing the first columns of the output (i.e.
 * in multithreaded cases where we divide columns across threads) */
template<typename T>
void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) {
    /* Only actually add up the columns if a_offset is non-zero. */
    if (qp.a_offset != 0) {
        memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));

        for (unsigned int row=0; row<height; row+=4) {
            unsigned int numrows=std::min(height-row, 4u);

            for (unsigned int col=0; col<width; col+=16) {
                unsigned int numcols=std::min(width-col, 16u);

                if (numcols==16) {
                    switch(numrows) {
                        case 1:
                            add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 2:
                            add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 3:
                            add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 4:
                            add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        default:
                            UNREACHABLE("Impossible.");
                    }
                } else {
                    for (; col<width; col++) {
                        int32_t sum=0;
                        for (unsigned int r=0; r<numrows; r++) {
                            sum += input[(row + r)*in_stride + col];
                        }
                        col_bias[col] += sum;
                    }
                }
            }
        }
    }

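    // Fold the column sums into the final per-column bias: the constant term
    // (a_offset * b_offset * depth) minus col_sum * a_offset, plus any
    // user-supplied bias for this column.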
    for (unsigned int col=0; col<width; col++) {
        int32_t result = col_bias[col];

        result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset);

        if (qp.bias != nullptr) {
            result += qp.bias[multi * qp.bias_multi_stride + col + first_col];
        }

        col_bias[col] = result;
    }
}

template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);

} // namespace arm_gemm

#endif // __aarch64__