/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEGEMMLowpFinalizeKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>

using namespace arm_compute;

namespace arm_compute
{
class Coordinates;
} // namespace arm_compute

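// Finalize the result of the low-precision GEMM. For each block of 16 output elements:
//
//   out = saturate_to_u8(((mm_result + a_offset * vector_sum_col + b_offset * vector_sum_row
//                          + a_offset * b_offset * num_mtx_a_cols + c_offset) * c_mult_int) >> shift)
//
// The template parameters select at compile time which offset contributions are added, so that
// vector_sum_col / vector_sum_row are only read when the corresponding offset is non-zero.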
template <bool add_a_offset, bool add_b_offset>
void NEGEMMLowpFinalizeKernel::finalize(const Window &window)
{
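    // Note: vshlq_s32 shifts left for positive shift amounts and right for negative ones,
    // so the final right shift is implemented by negating _shift here.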
    const int32x4_t c_offset_s32 = vdupq_n_s32(_c_offset);
    const int32x4_t shift_s32    = vdupq_n_s32(-_shift);

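    // Collapse the window from the Z dimension upwards when the shapes allow it, so that any
    // batch dimensions are traversed by the same loop.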
    Window collapsed_window = window.collapse_if_possible(IKernel::window(), Window::DimZ);

    if(add_a_offset && add_b_offset) // true, true
    {
        // Set window for vector_sum_col
        Window win_vector_sum_col(collapsed_window);
        win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
        if(!_slide_vector_sum_col)
        {
            win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
        }

        // Set window for vector_sum_row
        Window win_vector_sum_row(collapsed_window);
        win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));

        Iterator vector_sum_col(_vector_sum_col, win_vector_sum_col);
        Iterator vector_sum_row(_vector_sum_row, win_vector_sum_row);
        Iterator mm_result(_mm_result, window);
        Iterator out(_output, window);

        execute_window_loop(window, [&](const Coordinates & id)
        {
            // Compute the leftover term due to a_offset.
            int32x4x4_t a_offset_term_s32 =
            {
                {
                    vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 0),
                    vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 4),
                    vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 8),
                    vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 12)
                }
            };

            a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], _a_offset);
            a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], _a_offset);
            a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], _a_offset);
            a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], _a_offset);

            // Compute the leftover term due to b_offset.
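            // vld1q_dup_s32 loads the single row sum for row id.y() and broadcasts it to all four lanes.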
            int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr()) + id.y());
            b_offset_term_s32           = vmulq_n_s32(b_offset_term_s32, _b_offset);

            // Sum the k_offset, a_offset and b_offset contributions
            int32x4x4_t offset_term_s32 =
            {
                {
                    vdupq_n_s32(_k_offset),
                    vdupq_n_s32(_k_offset),
                    vdupq_n_s32(_k_offset),
                    vdupq_n_s32(_k_offset)
                }
            };

            offset_term_s32.val[0] = vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32));
            offset_term_s32.val[1] = vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32));
            offset_term_s32.val[2] = vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32));
            offset_term_s32.val[3] = vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32));

            // Add c_offset
            offset_term_s32.val[0] = vaddq_s32(offset_term_s32.val[0], c_offset_s32);
            offset_term_s32.val[1] = vaddq_s32(offset_term_s32.val[1], c_offset_s32);
            offset_term_s32.val[2] = vaddq_s32(offset_term_s32.val[2], c_offset_s32);
            offset_term_s32.val[3] = vaddq_s32(offset_term_s32.val[3], c_offset_s32);

            int32x4x4_t in_s32 =
            {
                {
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 0),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 4),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 8),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 12)
                }
            };

            // Add the offset terms to GEMM's result
            in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32.val[0]);
            in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32.val[1]);
            in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32.val[2]);
            in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32.val[3]);

            // Multiply by c_mult_int
            in_s32.val[0] = vmulq_n_s32(in_s32.val[0], _c_mult_int);
            in_s32.val[1] = vmulq_n_s32(in_s32.val[1], _c_mult_int);
            in_s32.val[2] = vmulq_n_s32(in_s32.val[2], _c_mult_int);
            in_s32.val[3] = vmulq_n_s32(in_s32.val[3], _c_mult_int);

            // Shift final result (a negative shift amount shifts right)
            in_s32.val[0] = vshlq_s32(in_s32.val[0], shift_s32);
            in_s32.val[1] = vshlq_s32(in_s32.val[1], shift_s32);
            in_s32.val[2] = vshlq_s32(in_s32.val[2], shift_s32);
            in_s32.val[3] = vshlq_s32(in_s32.val[3], shift_s32);

            // Convert S32 to S16 with saturation
            const int16x8x2_t in_s16 =
            {
                {
                    vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
                    vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
                }
            };

            // Convert S16 to U8 with saturation
            const uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_s16.val[0]), vqmovun_s16(in_s16.val[1]));

            vst1q_u8(out.ptr(), out_u8);
        },
        vector_sum_col, vector_sum_row, mm_result, out);
    }
    else if(!add_a_offset && add_b_offset) // false, true
    {
        // Set window for vector_sum_row
        Window win_vector_sum_row(collapsed_window);
        win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));

        Iterator vector_sum_row(_vector_sum_row, win_vector_sum_row);
        Iterator mm_result(_mm_result, window);
        Iterator out(_output, window);

        execute_window_loop(window, [&](const Coordinates & id)
        {
            // Compute the leftover term due to b_offset.
            int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr()) + id.y());
            b_offset_term_s32           = vmulq_n_s32(b_offset_term_s32, _b_offset);

            // Add b_offset_term_s32 and c_offset_s32
            int32x4_t offset_term_s32 = vaddq_s32(b_offset_term_s32, c_offset_s32);

            int32x4x4_t in_s32 =
            {
                {
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 0),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 4),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 8),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 12)
                }
            };

            // Add the offset terms to GEMM's result
            in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32);
            in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32);
            in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32);
            in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32);

            // Multiply by c_mult_int
            in_s32.val[0] = vmulq_n_s32(in_s32.val[0], _c_mult_int);
            in_s32.val[1] = vmulq_n_s32(in_s32.val[1], _c_mult_int);
            in_s32.val[2] = vmulq_n_s32(in_s32.val[2], _c_mult_int);
            in_s32.val[3] = vmulq_n_s32(in_s32.val[3], _c_mult_int);

            // Shift final result (a negative shift amount shifts right)
            in_s32.val[0] = vshlq_s32(in_s32.val[0], shift_s32);
            in_s32.val[1] = vshlq_s32(in_s32.val[1], shift_s32);
            in_s32.val[2] = vshlq_s32(in_s32.val[2], shift_s32);
            in_s32.val[3] = vshlq_s32(in_s32.val[3], shift_s32);

            // Convert S32 to S16 with saturation
            const int16x8x2_t in_s16 =
            {
                {
                    vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
                    vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
                }
            };

            // Convert S16 to U8 with saturation
            const uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_s16.val[0]), vqmovun_s16(in_s16.val[1]));

            vst1q_u8(out.ptr(), out_u8);
        },
        vector_sum_row, mm_result, out);
    }
    else if(add_a_offset && !add_b_offset) // true, false
    {
        // Set window for vector_sum_col
        Window win_vector_sum_col(collapsed_window);
        win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
        if(!_slide_vector_sum_col)
        {
            win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
        }

        Iterator vector_sum_col(_vector_sum_col, win_vector_sum_col);
        Iterator mm_result(_mm_result, window);
        Iterator out(_output, window);

        execute_window_loop(window, [&](const Coordinates & id)
        {
            // Compute the leftover term due to a_offset.
            int32x4x4_t a_offset_term_s32 =
            {
                {
                    vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 0),
                    vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 4),
                    vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 8),
                    vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 12)
                }
            };

            a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], _a_offset);
            a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], _a_offset);
            a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], _a_offset);
            a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], _a_offset);

            // Add a_offset_term_s32 and c_offset_s32
            int32x4x4_t offset_term_s32 =
            {
                {
                    vaddq_s32(c_offset_s32, a_offset_term_s32.val[0]),
                    vaddq_s32(c_offset_s32, a_offset_term_s32.val[1]),
                    vaddq_s32(c_offset_s32, a_offset_term_s32.val[2]),
                    vaddq_s32(c_offset_s32, a_offset_term_s32.val[3])
                }
            };

            int32x4x4_t in_s32 =
            {
                {
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 0),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 4),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 8),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 12)
                }
            };

            // Add the offset terms to GEMM's result
            in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32.val[0]);
            in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32.val[1]);
            in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32.val[2]);
            in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32.val[3]);

            // Multiply by c_mult_int
            in_s32.val[0] = vmulq_n_s32(in_s32.val[0], _c_mult_int);
            in_s32.val[1] = vmulq_n_s32(in_s32.val[1], _c_mult_int);
            in_s32.val[2] = vmulq_n_s32(in_s32.val[2], _c_mult_int);
            in_s32.val[3] = vmulq_n_s32(in_s32.val[3], _c_mult_int);

            // Shift final result (a negative shift amount shifts right)
            in_s32.val[0] = vshlq_s32(in_s32.val[0], shift_s32);
            in_s32.val[1] = vshlq_s32(in_s32.val[1], shift_s32);
            in_s32.val[2] = vshlq_s32(in_s32.val[2], shift_s32);
            in_s32.val[3] = vshlq_s32(in_s32.val[3], shift_s32);

            // Convert S32 to S16 with saturation
            const int16x8x2_t in_s16 =
            {
                {
                    vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
                    vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
                }
            };

            // Convert S16 to U8 with saturation
            const uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_s16.val[0]), vqmovun_s16(in_s16.val[1]));

            vst1q_u8(out.ptr(), out_u8);
        },
        vector_sum_col, mm_result, out);
    }
    else // false, false
    {
        Iterator mm_result(_mm_result, window);
        Iterator out(_output, window);

        execute_window_loop(window, [&](const Coordinates & id)
        {
            int32x4x4_t in_s32 =
            {
                {
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 0),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 4),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 8),
                    vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 12)
                }
            };

            // Add c_offset to GEMM's result
            in_s32.val[0] = vaddq_s32(in_s32.val[0], c_offset_s32);
            in_s32.val[1] = vaddq_s32(in_s32.val[1], c_offset_s32);
            in_s32.val[2] = vaddq_s32(in_s32.val[2], c_offset_s32);
            in_s32.val[3] = vaddq_s32(in_s32.val[3], c_offset_s32);

            // Multiply by c_mult_int
            in_s32.val[0] = vmulq_n_s32(in_s32.val[0], _c_mult_int);
            in_s32.val[1] = vmulq_n_s32(in_s32.val[1], _c_mult_int);
            in_s32.val[2] = vmulq_n_s32(in_s32.val[2], _c_mult_int);
            in_s32.val[3] = vmulq_n_s32(in_s32.val[3], _c_mult_int);

            // Shift final result (a negative shift amount shifts right)
            in_s32.val[0] = vshlq_s32(in_s32.val[0], shift_s32);
            in_s32.val[1] = vshlq_s32(in_s32.val[1], shift_s32);
            in_s32.val[2] = vshlq_s32(in_s32.val[2], shift_s32);
            in_s32.val[3] = vshlq_s32(in_s32.val[3], shift_s32);

            // Convert S32 to S16 with saturation
            const int16x8x2_t in_s16 =
            {
                {
                    vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
                    vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
                }
            };

            // Convert S16 to U8 with saturation
            const uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_s16.val[0]), vqmovun_s16(in_s16.val[1]));

            vst1q_u8(out.ptr(), out_u8);
        },
        mm_result, out);
    }
}

NEGEMMLowpFinalizeKernel::NEGEMMLowpFinalizeKernel()
    : _func(nullptr), _vector_sum_col(nullptr), _vector_sum_row(nullptr), _mm_result(nullptr), _output(nullptr), _a_offset(0), _b_offset(0), _c_offset(0), _k_offset(0), _c_mult_int(0), _shift(0),
      _slide_vector_sum_col(true)
{
}

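// Usage sketch (illustrative only, not taken from the library's examples): the tensor and scalar
// names below are placeholders assumed to be created, allocated and filled by the caller.
//
//   NEGEMMLowpFinalizeKernel finalize_kernel;
//   finalize_kernel.configure(&sum_col, &sum_row, &mm_result, &output,
//                             num_mtx_a_cols, a_offset, b_offset, c_offset, c_mult_int, shift);
//   NEScheduler::get().schedule(&finalize_kernel, Window::DimY);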
void NEGEMMLowpFinalizeKernel::configure(const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *mm_result, ITensor *output, int32_t num_mtx_a_cols, int32_t a_offset,
                                         int32_t b_offset,
                                         int32_t c_offset, int32_t c_mult_int, int32_t shift)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);

    TensorShape mm_result_shape = mm_result->info()->tensor_shape();
    TensorShape output_shape    = output->info()->tensor_shape();

    mm_result_shape.collapse(2);
    output_shape.collapse(2);

    ARM_COMPUTE_ERROR_ON_MSG(mm_result_shape[2] != output_shape[2], "mm_result tensor must have the same number of batches as the output tensor");

    // If a_offset == 0, vector_sum_col can be a nullptr
    if(a_offset != 0)
    {
        ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
        ARM_COMPUTE_ERROR_ON(vector_sum_col->info()->dimension(0) != mm_result->info()->dimension(0));

        TensorShape vector_sum_col_shape = vector_sum_col->info()->tensor_shape();
        vector_sum_col_shape.collapse(1);

        // Check whether vector_sum_col should be slid along the y dimension.
        // Don't slide it if it has just one batch while vector_sum_row has more than one:
        // this scenario can happen when the matrix multiplication is used to perform a convolution.
        _slide_vector_sum_col = vector_sum_col_shape[1] != 1;
    }

    // If b_offset == 0, vector_sum_row can be a nullptr
    if(b_offset != 0)
    {
        ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
        ARM_COMPUTE_ERROR_ON(vector_sum_row->info()->dimension(0) != mm_result->info()->dimension(1));

        TensorShape vector_sum_row_shape = vector_sum_row->info()->tensor_shape();
        vector_sum_row_shape.collapse(1);

        ARM_COMPUTE_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[2], "vector_sum_row tensor must have the same number of batches as the output tensor");

        if(a_offset != 0)
        {
            TensorShape vector_sum_col_shape = vector_sum_col->info()->tensor_shape();
            vector_sum_col_shape.collapse(1);

            ARM_COMPUTE_ERROR_ON_MSG(vector_sum_col_shape[1] != 1
                                     && vector_sum_col_shape[1] != vector_sum_row_shape[1],
                                     "vector_sum_col tensor must have the same number of batches as vector_sum_row, or its number of batches must be 1");
        }
    }

    _vector_sum_col = vector_sum_col;
    _vector_sum_row = vector_sum_row;
    _mm_result      = mm_result;
    _output         = output;
    _a_offset       = a_offset;
    _b_offset       = b_offset;
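    // Constant cross term of the offset expansion: it depends only on the two offsets and on the
    // number of columns of matrix A (the accumulation depth of the GEMM).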
    _k_offset       = a_offset * b_offset * num_mtx_a_cols;
    _c_offset       = c_offset;
    _c_mult_int     = c_mult_int;
    _shift          = shift;

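    // Each iteration processes 16 S32 accumulators (four int32x4_t loads) and produces 16 U8 outputs (one uint8x16_t store).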
    constexpr unsigned int num_elems_processed_per_iteration = 16;

    // Configure kernel window
    Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration));

    AccessWindowHorizontal mm_result_access(mm_result->info(), 0, num_elems_processed_per_iteration);
    AccessWindowHorizontal output_result_access(output->info(), 0, num_elems_processed_per_iteration);
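    // The access windows declared above (and below) are passed to update_window_and_padding(),
    // which adjusts the kernel window and the tensors' padding requirements for the 16-element accesses.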

    // Depending on a_offset and b_offset, there are four cases:
    //  a_offset != 0 && b_offset != 0
    //  a_offset == 0 && b_offset != 0
    //  a_offset != 0 && b_offset == 0
    //  a_offset == 0 && b_offset == 0
    if(a_offset != 0 && b_offset != 0)
    {
        // Set the function to use
        _func = &NEGEMMLowpFinalizeKernel::finalize<true, true>;

        AccessWindowStatic     vector_sum_row_access(vector_sum_row->info(), 0, 0, vector_sum_row->info()->dimension(0), 0);
        AccessWindowHorizontal vector_sum_col_access(vector_sum_col->info(), 0, num_elems_processed_per_iteration);

        update_window_and_padding(win,
                                  vector_sum_col_access,
                                  vector_sum_row_access,
                                  mm_result_access,
                                  output_result_access);
    }
    else if(a_offset == 0 && b_offset != 0)
    {
        // Set the function to use
        _func = &NEGEMMLowpFinalizeKernel::finalize<false, true>;

        AccessWindowStatic vector_sum_row_access(vector_sum_row->info(), 0, 0, vector_sum_row->info()->dimension(0), 0);

        update_window_and_padding(win,
                                  vector_sum_row_access,
                                  mm_result_access,
                                  output_result_access);
    }
    else if(a_offset != 0 && b_offset == 0)
    {
        // Set the function to use
        _func = &NEGEMMLowpFinalizeKernel::finalize<true, false>;

        AccessWindowHorizontal vector_sum_col_access(vector_sum_col->info(), 0, num_elems_processed_per_iteration);

        update_window_and_padding(win,
                                  vector_sum_col_access,
                                  mm_result_access,
                                  output_result_access);
    }
    else
    {
        // Set the function to use
        _func = &NEGEMMLowpFinalizeKernel::finalize<false, false>;

        update_window_and_padding(win,
                                  mm_result_access,
                                  output_result_access);
    }

    output_result_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));

    INEKernel::configure(win);
}

void NEGEMMLowpFinalizeKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    (this->*_func)(window);
}