/*
 * Copyright (c) 2019 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"

#include <arm_neon.h>

#include <algorithm>
#include <cstring>

namespace arm_gemm {

namespace {

/* Requantize a block of data, using the requantize parameters in 'qp'.
 *
 * row_bias and col_bias are assumed to be precomputed values which include
 * any externally supplied bias, plus the row/column contribution sums, plus
 * the overall constant offset (A_offset * B_offset * depth).
 *
 * Note that this function works equally well for uint8_t output: just set
 * minval/maxval appropriately and cast the output pointer. It is the
 * caller's responsibility to ensure that minval/maxval are representable in
 * the target type - the downcast to (u)int8_t is done by simply extracting
 * the LSB.
 *
 * The 'do_shift_correction' template parameter turns on the correction
 * applied to negative values being shifted right to make sure they round
 * properly - if negative values are never output (e.g. fused ReLU) this is
 * unnecessary.
 *
 * The 'per_channel' template parameter selects between per channel and per
 * layer requantization - in the former case we need to load vectors of
 * shifts and multipliers for each column. A separate vector for each
 * column is set up in any case (and it is hoped that the compiler can elide
 * the needless movs in the per-layer case).
 */
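
/* For reference, a scalar sketch of the per-element arithmetic implemented by
 * the vector loops below, assuming the per-layer case and a shift value that
 * is zero or negative (i.e. a right shift, as consumed by vrshlq_s32). This
 * illustrative helper is not part of the arm_gemm API and is not called by
 * the NEON code; it exists purely as documentation of the intended
 * computation. */
inline int8_t requantize_single_value(int32_t value, int32_t row_bias, int32_t col_bias,
                                      int32_t mul, int32_t shift, int32_t c_offset,
                                      int32_t minval, int32_t maxval) {
    // Add on the precomputed row and column corrections.
    int32_t v = value + row_bias + col_bias;

    // Rounding doubling multiply returning the high half (vqrdmulhq_s32),
    // ignoring the saturating edge case of INT32_MIN * INT32_MIN.
    int64_t prod = static_cast<int64_t>(v) * mul;
    int32_t t = static_cast<int32_t>((prod + (1LL << 30)) >> 31);

    // The 'do_shift_correction' fixup: nudge negative values down by one so
    // that halfway cases in the rounding shift below round away from zero,
    // mirroring the behaviour for positive values.
    if (shift < 0 && t < 0) {
        t--;
    }

    // Rounding shift right (vrshlq_s32 with a non-positive shift operand).
    if (shift < 0) {
        int n = -shift;
        t = static_cast<int32_t>((static_cast<int64_t>(t) + (1LL << (n - 1))) >> n);
    }

    // Add the output offset and clamp to the requested range.
    t = std::max(minval, std::min(maxval, t + c_offset));

    // The final downcast simply keeps the least significant byte.
    return static_cast<int8_t>(t);
}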
template<bool do_shift_correction, bool per_channel>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                             const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                             const int32_t *row_bias, const int32_t *col_bias) {
    const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
    const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift);
    const int32x4_t v_minval = vdupq_n_s32(qp.minval);
    const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
    const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);

    /* To make sure we have plenty of accumulators, compute two rows at a
     * time. If the number of rows is odd, compute the bottom row twice to
     * avoid needing a duplicate codepath. */
    for (unsigned int row=0; row<height; row+=2) {
        /* Prefer to do 4 vectors (16 values) at once as this collapses
         * neatly to a single vector of output, failing that a vector at a
         * time and then the odd ones out at the end. */
        unsigned int blocks=(width / 16);
        unsigned int regs=(width % 16) / 4;
        unsigned int odds=(width % 4);

        const int32_t *colptr = col_bias;
        const int32_t *perch_mul_ptr = qp.per_channel_muls;
        const int32_t *perch_shift_ptr = qp.per_channel_shifts;

        const int32_t *in_ptr = input + (row * in_stride);
        int8_t *out_ptr = output + (row * out_stride);
        int32_t row_sum = row_bias[row];

        const int32_t *in_ptr1;
        int8_t *out_ptr1;
        int32_t row_sum1;

        if (row == height-1) {
            in_ptr1 = in_ptr;
            out_ptr1 = out_ptr;
            row_sum1 = row_sum;
        } else {
            in_ptr1 = in_ptr + in_stride;
            out_ptr1 = out_ptr + out_stride;
            row_sum1 = row_bias[row+1];
        }

        const int32x4_t v_row_sum = vdupq_n_s32(row_sum);
        const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);

        while (blocks--) {
            int32x4_t v_mul0;
            int32x4_t v_mul1;
            int32x4_t v_mul2;
            int32x4_t v_mul3;

            int32x4_t v_shf0;
            int32x4_t v_shf1;
            int32x4_t v_shf2;
            int32x4_t v_shf3;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
                v_mul3 = vld1q_s32(perch_mul_ptr + 12);
                perch_mul_ptr += 16;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                perch_shift_ptr += 16;
            } else {
                v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
                v_shf0=v_shf1=v_shf2=v_shf3=v_shift;
            }

            // Load column bias values
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            int32x4_t v_col3 = vld1q_s32(colptr + 12);
            colptr += 16;

            // Load input data (row 0)
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            int32x4_t v_in03 = vld1q_s32(in_ptr + 12);
            in_ptr += 16;

            // Load input data (row 1)
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12);
            in_ptr1 += 16;

            // Add on row bias and column bias
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);
            v_in03 = vaddq_s32(v_in03, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);
            v_in13 = vaddq_s32(v_in13, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);
            v_in03 = vaddq_s32(v_in03, v_col3);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            v_in13 = vaddq_s32(v_in13, v_col3);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
            v_in03 = vqrdmulhq_s32(v_in03, v_mul3);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
            v_in13 = vqrdmulhq_s32(v_in13, v_mul3);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
                int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
                int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);
                v_temp03 = vshrq_n_s32(v_temp03, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);
                v_temp13 = vshrq_n_s32(v_temp13, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);
                v_in03 = vqaddq_s32(v_in03, v_temp03);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
                v_in13 = vqaddq_s32(v_in13, v_temp13);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in01 = vrshlq_s32(v_in01, v_shf1);
            v_in02 = vrshlq_s32(v_in02, v_shf2);
            v_in03 = vrshlq_s32(v_in03, v_shf3);

            v_in10 = vrshlq_s32(v_in10, v_shf0);
            v_in11 = vrshlq_s32(v_in11, v_shf1);
            v_in12 = vrshlq_s32(v_in12, v_shf2);
            v_in13 = vrshlq_s32(v_in13, v_shf3);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);
            v_in03 = vaddq_s32(v_in03, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);
            v_in13 = vaddq_s32(v_in13, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);
            v_in03 = vmaxq_s32(v_in03, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);
            v_in13 = vmaxq_s32(v_in13, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);
            v_in03 = vminq_s32(v_in03, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            v_in13 = vminq_s32(v_in13, v_maxval);
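
            // Narrow to 8-bit by keeping only the low byte of each 32-bit lane:
            // vuzp1 on the 16-bit views picks the low halfwords, then vuzp1 on
            // the 8-bit views picks the low bytes of those.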
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_s8(out_ptr, v_uz0);
            out_ptr += 16;
            vst1q_s8(out_ptr1, v_uz1);
            out_ptr1 += 16;
        }

        while (regs--) {
            int32x4_t v_mul0;
            int32x4_t v_shf0;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                perch_mul_ptr += 4;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                perch_shift_ptr += 4;
            } else {
                v_mul0=v_mul;
                v_shf0=v_shift;
            }

            // Load column bias values
            int32x4_t v_col0 = vld1q_s32(colptr);
            colptr += 4;

            // Load input data (row 0)
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            in_ptr += 4;

            // Load input data (row 1)
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            in_ptr1 += 4;

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00));

            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), vreinterpretq_s32_s8(v_uz0), 0);
            out_ptr += 4;
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1);
            out_ptr1 += 4;
        }

        if (odds) {
            int32x4_t v_col0 = vdupq_n_s32(0);
            int32x4_t v_in00 = vdupq_n_s32(0);
            int32x4_t v_in10 = vdupq_n_s32(0);
            int32x4_t v_mul0 = vdupq_n_s32(0);
            int32x4_t v_shf0 = vdupq_n_s32(0);

            if (!per_channel) {
                v_mul0 = v_mul;
                v_shf0 = v_shift;
            }

            do {
                v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
                }
                if (odds == 1) { break; }

                v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
                }
                if (odds == 2) { break; }

                v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
                }
            } while (0);

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            do {
                vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0);
                vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0);

                if (odds==1) { break; }

                vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4);
                vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4);

                if (odds==2) { break; }

                vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8);
                vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8);
            } while(0);
        }
    }
}

} // anonymous namespace

template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                         const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias) {
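    /* Note on the dispatch below: the rounding correction only changes
     * results which are still negative before c_offset is added, and those
     * all end up at or below c_offset and are therefore clamped to minval
     * whenever minval >= c_offset.  In that case the cheaper
     * do_shift_correction=false instantiations give identical results. */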
    if (qp.per_channel_requant) {
        if (qp.minval >= qp.c_offset) {
            requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
        } else {
            requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
        }
    } else {
        if (qp.minval >= qp.c_offset) {
            requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                  reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
        } else {
            requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
        }
    }
}

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias);

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias);

/*
 * Routine (and helpers) to compute row sums needed for offset correction.
 *
 * This is often needed for a lot of short rows (e.g. Syrax 5 - 6400 rows
 * of length 27), therefore it's important not to sacrifice performance on
 * odd length rows.
 *
 * To minimize performance loss in these cases, this routine will overread
 * by up to 7 bytes.
 *
 * This is handled via "mask" and "mask mode" parameters to the inner
 * routines; mask mode == 1 indicates that there are between 1 and 8 bytes
 * (inclusive) needed at the end; in these cases we always read 8 bytes.
 * mask mode == 2 indicates that there are between 9 and 15 bytes needed at
 * the end, and in this case we always read 16 bytes. In both cases the
 * 'mask' vector is set up so that the read value can be masked off to clear
 * the overread lanes. This is handled by 'accumulate_masked_8' and
 * 'accumulate_masked_16' below.
 *
 * This routine is templated on the type to be accumulated, because the
 * innermost instruction used needs to be of the correct signedness.
 * However, beyond this point we always use signed values in both cases.
 * The instructions that need to be different are therefore wrapped in
 * helper functions below.
 *
 * The general strategy used is to load vectors of 16 bytes and accumulate
 * (using uadalp/sadalp or AArch32 equivalents) into 8x16-bit accumulators.
 * These are then reduced (using uadalp/sadalp again) into 4x32-bit
 * accumulators. The 4 accumulators for up to 4 rows being processed are
 * then added together into a single output vector using pairwise adds.
 *
 * This reduction from the 8x16-bit into the 4x32-bit accumulators needs to
 * occur before the 16-bit accumulators can overflow - which is every 32
 * iterations (512 total bytes processed). This is explained more below.
 */
namespace {
    struct row_sum_helpers {
        const Requantize32 &qp;

        /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
        template<typename T>
        inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

        /* Load a full 16 byte vector, but mask before accumulation (see above). */
        template<typename T>
        inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* Load 8 bytes and mask before accumulation. */
        template<typename T>
        inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* This function does the actual work for up to 4 rows at a time.
         * It's pulled out so we can template on the row count to generate
         * the 4 different cases. 4 rows are computed at a time as this
         * reduces to a single vector write. */
        template<unsigned int rows, typename T>
        void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
            int16x8_t sums[rows];
            int32x4_t finalsums[rows];

            for (unsigned int i=0; i<rows; i++) {
                sums[i] = vdupq_n_s16(0);
                finalsums[i] = vdupq_n_s32(0);
            }

            for (unsigned int i=0; i<blocks; i++) {
                for (unsigned int r=0; r<rows; r++) {
                    /* If we add too many blocks together, we run the risk
                     * of overflowing the intermediate 16-bit accumulators,
                     * especially in the unsigned case where we later treat
                     * the accumulator as signed.
                     *
                     * In that case, the maximum (signed) value is 16383,
                     * which is safe for 64 (unsigned) accumulations (255*64
                     * = 16,320).
                     *
                     * Each invocation of pairwise add adds 2 values to the
                     * accumulator - so in the unsigned case we can do 32
                     * adds before we need to reset the 16-bit accumulator
                     * by adding into the 32-bit 'finalsums'.
                     *
                     * We could do 64 adds in the signed case, but that
                     * optimization is not worth the complexity.
                     */
                    if (i > 0 && ((i & 31) == 0)) {
                        finalsums[r] = vpadalq_s16(finalsums[r], sums[r]);
                        sums[r] = vdupq_n_s16(0);
                    }
                    sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]);
                }
            }

            /* Handle the final masked read if needed. */
            if (mask_mode > 0) {
                for (unsigned int r=0; r<rows; r++) {
                    if (mask_mode == 1) {
                        sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    } else {
                        sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    }
                }
            }

            for (unsigned int i=0; i<rows; i++) {
                finalsums[i] = vpadalq_s16(finalsums[i], sums[i]);
            }

            int32x4_t t0, t1;
            int32x2_t t2;

            /* Result writeback - need to write back one value per row
             * processed. Multiply all the final totals by -b_offset so
             * that the terms can simply be added in the requantize code. */
            switch (rows) {
                default:
                case 1:
                    /* If we only have one output, just use ADDV. Multiply
                     * the offset into all four components separately so it
                     * can stay in the SIMD register file. */
                    t0 = vmulq_s32(finalsums[0], offset_mul);
                    *row_bias = vaddvq_s32(t0);
                    break;

                case 2:
                    /* For two outputs, two rounds of pairwise adds will
                     * generate the result in a 2-vector we can store in one
                     * go. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t0 = vpaddq_s32(t0, t0);
                    t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul));
                    vst1_s32(row_bias, t2);
                    break;

                case 3:
                    /* Three rows - need to store the low two words plus the odd value from lane 2 */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[2]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1_s32(row_bias, vget_low_s32(t0));
                    row_bias[2] = vgetq_lane_s32(t0, 2);
                    break;

                case 4:
                    /* Four rows (most common case) - reduce to a single
                     * vector with pairwise adds. */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[3]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1q_s32(row_bias, t0);
                    break;
            }
        }

        row_sum_helpers(const Requantize32 &qp) : qp(qp) { }
    };

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) {
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr)));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) {
        return vpadalq_s8(sum, vld1q_s8(ptr));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
        v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
        v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }
}

template<typename T>
void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    /* If the 'b' offset is zero, just skip this entirely. */
    if (qp.b_offset == 0) {
        memset(row_bias, 0, height * sizeof(int32_t));
        return;
    }

    row_sum_helpers thehelpers(qp);

    const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);

    /* Work out how many full vectors of 16 bytes we will read, and how many
     * odd bytes at the end */
    unsigned int blocks = (width / 16);
    const unsigned int odds = width % 16;

    /* Generate a mask to use on the last iteration, if necessary. */
    uint64x2_t mask;
    unsigned int mask_mode = 0;

    if (odds > 0 && odds <= 8) {
        /* 1-8 odds: mask in the low lane, 0 in the top */
        uint64_t maskval = (~0ULL) >> (8 * (8-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);

        mask_mode = 1;
    } else if (odds > 8) {
        /* 9-15 odds: mask in the top lane, all 1s in the bottom. */
        uint64_t maskval = (~0ULL) >> (8 * (16-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);

        mask_mode = 2;
    }
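
    /* Worked example: for the length-27 rows mentioned in the comment above,
     * blocks=1 and odds=11, so mask_mode is 2. The final read fetches 16
     * bytes; the low 64-bit lane of 'mask' is all ones (bytes 0-7 kept) and
     * the high lane is ~0ULL >> 40 (bytes 8-10 kept), clearing the 5
     * overread bytes. */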

    for (unsigned int row=0; row<height; row+=4) {
        switch(height-row) {
            default:
            case 4:
                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 3:
                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 2:
                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 1:
                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
        }
    }
}

/* Instantiate the two versions for uint8_t and int8_t. */
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);
template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);

template<unsigned int active_rows>
inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) {
    uint8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_u8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_u8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1])));
    sums_16b[1]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3])));
    // Two adds for the high pairs
    sums_16b[2]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1]));
    sums_16b[3]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3]));

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

template<unsigned int active_rows>
inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) {
    int8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_s8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_s8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1]));
    sums_16b[1]=vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3]));
    // Two adds for the high pairs
    sums_16b[2]=vaddl_high_s8(inputs[0], inputs[1]);
    sums_16b[3]=vaddl_high_s8(inputs[2], inputs[3]);

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

/* "first_col" parameter is used to offset the read into the qp.bias array,
 * in cases where we are not computing the first columns of the output (i.e.
 * in multithreaded cases where we divide columns across threads) */
template<typename T>
void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) {
    /* Only actually add up the columns if a_offset is non-zero. */
    if (qp.a_offset != 0) {
        memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));

        for (unsigned int row=0; row<height; row+=4) {
            unsigned int numrows=std::min(height-row, 4u);

            for (unsigned int col=0; col<width; col+=16) {
                unsigned int numcols=std::min(width-col, 16u);

                if (numcols==16) {
                    switch(numrows) {
                        default:
                        case 1:
                            add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 2:
                            add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 3:
                            add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 4:
                            add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;
                    }
                } else {
                    for (; col<width; col++) {
                        int32_t sum=0;
                        for (unsigned int r=0; r<numrows; r++) {
                            sum += input[(row + r)*in_stride + col];
                        }
                        col_bias[col] += sum;
                    }
                }
            }
        }
    }
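
    /* Fold the accumulated column sums into the final per-column bias term:
     * (A_offset * B_offset * depth) - (col_sum * A_offset), plus any
     * user-supplied bias - the form requantize_block_32 expects to simply
     * add on. */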
    for (unsigned int col=0; col<width; col++) {
        int32_t result = col_bias[col];

        result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset);

        if (qp.bias != nullptr) {
            result += qp.bias[multi * qp.bias_multi_stride + col + first_col];
        }

        col_bias[col] = result;
    }
}

template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);

} // namespace arm_gemm

#endif // __aarch64__