/*
 * Copyright (c) 2016-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NENonMaximaSuppression3x3Kernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <arm_neon.h>
#include <cstddef>

using namespace arm_compute;

namespace arm_compute
{
class Coordinates;
} // namespace arm_compute

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
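// FP16 fast path: the F32 variant narrows the rows to float16x8_t vectors so
// that eight scores are compared per instruction; the U8 variant already
// operates on sixteen bytes at a time. Only compiled when the toolchain
// provides FP16 vector arithmetic.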
namespace fp16
{
inline void mask_top(const float16x8_t &vc, const float16x8_t &in0, const float16x8_t &in1, uint16x8_t &mask)
{
    // vc >= nc.val[0], vc >= nc.val[1], vc >= nc.val[2]
    mask = vandq_u16(mask, vcgeq_f16(vc, in0));
    mask = vandq_u16(mask, vcgeq_f16(vc, vextq_f16(in0, in1, 1)));
    mask = vandq_u16(mask, vcgeq_f16(vc, vextq_f16(in0, in1, 2)));
}

inline void mask_middle(const float16x8_t &vc, const float16x8_t &in0, const float16x8_t &in1, uint16x8_t &mask)
{
    // vc >= nc.val[0], vc > nc.val[2]
    mask = vandq_u16(mask, vcgeq_f16(vc, in0));
    mask = vandq_u16(mask, vcgtq_f16(vc, vextq_f16(in0, in1, 2)));
}

inline void mask_bottom(const float16x8_t &vc, const float16x8_t &in0, const float16x8_t &in1, uint16x8_t &mask)
{
    // vc > nc.val[0], vc > nc.val[1], vc > nc.val[2]
    mask = vandq_u16(mask, vcgtq_f16(vc, in0));
    mask = vandq_u16(mask, vcgtq_f16(vc, vextq_f16(in0, in1, 1)));
    mask = vandq_u16(mask, vcgtq_f16(vc, vextq_f16(in0, in1, 2)));
}

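// Suppress each centre score against its 3x3 neighbourhood. Neighbours that
// precede the centre in raster order are compared with >=, neighbours that
// follow it with >, so a run of equal maxima keeps only its first pixel.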
inline void non_maxima_suppression3x3_F32_F32(const void *__restrict in_ptr, void *__restrict out_ptr, const uint32_t in_stride)
{
    auto       in  = static_cast<const float *__restrict>(in_ptr) - 1;
    const auto out = static_cast<float *__restrict>(out_ptr);

    // Get centre scores
    const float16x8x2_t vc =
    {
        vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 1)), vcvt_f16_f32(vld1q_f32(in + 5))),
        vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 9)), vcvt_f16_f32(vld1q_f32(in + 13)))
    };

    // Neighboring pixels
    in -= in_stride;

    static const float16x4_t  zero_f16x4 = vdup_n_f16(0);
    static const uint16x8_t   zero_u16   = vdupq_n_u16(0);
    static const uint16x8_t   true_mask  = vceqq_u16(zero_u16, zero_u16);
    static const uint16x8x2_t true_mask_x2 =
    {
        true_mask,
        true_mask
    };

    uint16x8x2_t mask = true_mask_x2;

    // Top row
    const float16x8_t tmp_top0 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in)), vcvt_f16_f32(vld1q_f32(in + 4)));
    const float16x8_t tmp_top1 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 8)), vcvt_f16_f32(vld1q_f32(in + 12)));
    const float16x8_t tmp_top2 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 16)), zero_f16x4);

    // vc >= nc.val[0], vc >= nc.val[1], vc >= nc.val[2]
    mask_top(vc.val[0], tmp_top0, tmp_top1, mask.val[0]);
    mask_top(vc.val[1], tmp_top1, tmp_top2, mask.val[1]);

    in += in_stride;

    // Middle row
    const float16x8_t tmp_mid0 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in)), vcvt_f16_f32(vld1q_f32(in + 4)));
    const float16x8_t tmp_mid1 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 8)), vcvt_f16_f32(vld1q_f32(in + 12)));
    const float16x8_t tmp_mid2 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 16)), zero_f16x4);

    // vc >= nc.val[0], vc > nc.val[2]
    mask_middle(vc.val[0], tmp_mid0, tmp_mid1, mask.val[0]);
    mask_middle(vc.val[1], tmp_mid1, tmp_mid2, mask.val[1]);

    in += in_stride;

    // Bottom row
    const float16x8_t tmp_bot0 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in)), vcvt_f16_f32(vld1q_f32(in + 4)));
    const float16x8_t tmp_bot1 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 8)), vcvt_f16_f32(vld1q_f32(in + 12)));
    const float16x8_t tmp_bot2 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 16)), zero_f16x4);

    // vc > nc.val[0], vc > nc.val[1], vc > nc.val[2]
    mask_bottom(vc.val[0], tmp_bot0, tmp_bot1, mask.val[0]);
    mask_bottom(vc.val[1], tmp_bot1, tmp_bot2, mask.val[1]);

    // Store
    static const float16x8_t zero_f16x8 = vdupq_n_f16(0);

    const float16x8_t suppressed0 = vbslq_f16(mask.val[0], vc.val[0], zero_f16x8);
    vst1q_f32(out + 0, vcvt_f32_f16(vget_low_f16(suppressed0)));
    vst1q_f32(out + 4, vcvt_f32_f16(vget_high_f16(suppressed0)));

    const float16x8_t suppressed1 = vbslq_f16(mask.val[1], vc.val[1], zero_f16x8);
    vst1q_f32(out + 8, vcvt_f32_f16(vget_low_f16(suppressed1)));
    vst1q_f32(out + 12, vcvt_f32_f16(vget_high_f16(suppressed1)));
}

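// U8 variant of the suppression: bytes are compared directly, no conversion
// needed, sixteen pixels per call.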
inline void non_maxima_suppression3x3_U8_U8(const void *__restrict in_ptr, void *__restrict out_ptr, const uint32_t in_stride)
{
    auto       in  = static_cast<const uint8_t *__restrict>(in_ptr) - 1;
    const auto out = static_cast<uint8_t *__restrict>(out_ptr);

    // Get centre scores
    const uint8x16_t vc = vld1q_u8(in + 1);

    // Neighboring pixels
    in -= in_stride;

    // Top row
    const uint8x16_t l_nc_0 = vld1q_u8(in);
    const uint8x16_t m_nc_0 = vld1q_u8(in + 1);
    const uint8x16_t r_nc_0 = vld1q_u8(in + 2);

    // Keep center scores if ...
    // vc >= l_nc_0, vc >= m_nc_0, vc >= r_nc_0
    uint8x16_t mask = vcgeq_u8(vc, l_nc_0);
    mask            = vandq_u8(mask, vcgeq_u8(vc, m_nc_0));
    mask            = vandq_u8(mask, vcgeq_u8(vc, r_nc_0));

    in += in_stride;

    // Middle row
    const uint8x16_t l_nc_1 = vld1q_u8(in);
    const uint8x16_t r_nc_1 = vld1q_u8(in + 2);

    // ... and ...
    // vc >= l_nc_1, vc > r_nc_1
    mask = vandq_u8(mask, vcgeq_u8(vc, l_nc_1));
    mask = vandq_u8(mask, vcgtq_u8(vc, r_nc_1));

    in += in_stride;

    // Bottom row
    const uint8x16_t l_nc_2 = vld1q_u8(in);
    const uint8x16_t m_nc_2 = vld1q_u8(in + 1);
    const uint8x16_t r_nc_2 = vld1q_u8(in + 2);

    // ... and ...
    // vc > l_nc_2, vc > m_nc_2, vc > r_nc_2
    mask = vandq_u8(mask, vcgtq_u8(vc, l_nc_2));
    mask = vandq_u8(mask, vcgtq_u8(vc, m_nc_2));
    mask = vandq_u8(mask, vcgtq_u8(vc, r_nc_2));

    // Store
    static const uint8x16_t zero = vdupq_n_u8(0);
    vst1q_u8(out, vbslq_u8(mask, vc, zero));
}
} // namespace fp16

void NENonMaximaSuppression3x3FP16Kernel::configure(const ITensor *input, ITensor *output, bool border_undefined)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::F32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);

    _input  = input;
    _output = output;

    switch(input->info()->data_type())
    {
        case DataType::U8:
            _func = &fp16::non_maxima_suppression3x3_U8_U8;
            break;
        default:
            _func = &fp16::non_maxima_suppression3x3_F32_F32;
            break;
    }

    constexpr unsigned int num_elems_processed_per_iteration = 16;
    const unsigned int     num_elems_read_per_iteration      = 16 + 2 * border_size().left + (input->info()->data_type() == DataType::U8 ? 0 : 3);
    constexpr unsigned int num_elems_written_per_iteration   = 16;
    constexpr unsigned int num_rows_read_per_iteration       = 3;

    // Configure kernel window
    Window                 win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration), border_undefined, border_size());
    AccessWindowHorizontal output_access(output->info(), 0, num_elems_written_per_iteration);

    update_window_and_padding(win,
                              AccessWindowRectangle(input->info(), -border_size().left, -border_size().top, num_elems_read_per_iteration, num_rows_read_per_iteration),
                              output_access);

    output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());

    INEKernel::configure(win);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

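// Generic NEON implementation, always compiled and used when the FP16 fast
// path is not selected. The F32 variant covers the 16 pixels of an iteration
// with four float32x4_t lanes.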
namespace
{
inline void non_maxima_suppression3x3_FLOAT_FLOAT(const void *__restrict input_ptr, void *__restrict output_ptr, const uint32_t input_stride)
{
    auto       input  = static_cast<const float *__restrict>(input_ptr) - 1;
    const auto output = static_cast<float *__restrict>(output_ptr);

    // Get centre scores
    const float32x4x4_t vc =
    {
        {
            vld1q_f32(input + 1),
            vld1q_f32(input + 5),
            vld1q_f32(input + 9),
            vld1q_f32(input + 13)
        }
    };

    // Neighboring pixels
    float32x4x4_t l_nc{ {} };
    float32x4x4_t m_nc{ {} };
    float32x4x4_t r_nc{ {} };

    input -= input_stride;

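    // Each row is loaded once as overlapping 4-element vectors; the middle and
    // right neighbour vectors are then derived with vextq_f32 rather than
    // re-loading at shifted offsets.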
256 // Row0 - Low part
257 float32x4_t tmp_low = vld1q_f32(input);
258 float32x4_t tmp_high = vld1q_f32(input + 4);
259 float32x4_t tmp_high1 = vld1q_f32(input + 8);
260
261 l_nc.val[0] = tmp_low;
262 m_nc.val[0] = vextq_f32(tmp_low, tmp_high, 1);
263 r_nc.val[0] = vextq_f32(tmp_low, tmp_high, 2);
264
265 tmp_low = tmp_high;
266 tmp_high = tmp_high1;
267
268 l_nc.val[1] = tmp_low;
269 m_nc.val[1] = vextq_f32(tmp_low, tmp_high, 1);
270 r_nc.val[1] = vextq_f32(tmp_low, tmp_high, 2);
271
272 // Row0 - High part
273 tmp_low = tmp_high1;
274 tmp_high = vld1q_f32(input + 12);
275 tmp_high1 = vld1q_f32(input + 16);
276
277 l_nc.val[2] = tmp_low;
278 m_nc.val[2] = vextq_f32(tmp_low, tmp_high, 1);
279 r_nc.val[2] = vextq_f32(tmp_low, tmp_high, 2);
280
281 tmp_low = tmp_high;
282 tmp_high = tmp_high1;
283
284 l_nc.val[3] = tmp_low;
285 m_nc.val[3] = vextq_f32(tmp_low, tmp_high, 1);
286 r_nc.val[3] = vextq_f32(tmp_low, tmp_high, 2);
287
288 // mc >= nc.val[0], mc >= nc.val[1], mc >= nc.val[2]
289 uint32x4x4_t mask{ {} };
290 mask.val[0] = vcgeq_f32(vc.val[0], l_nc.val[0]);
291 mask.val[0] = vandq_u32(mask.val[0], vcgeq_f32(vc.val[0], m_nc.val[0]));
292 mask.val[0] = vandq_u32(mask.val[0], vcgeq_f32(vc.val[0], r_nc.val[0]));
293 mask.val[1] = vcgeq_f32(vc.val[1], l_nc.val[1]);
294 mask.val[1] = vandq_u32(mask.val[1], vcgeq_f32(vc.val[1], m_nc.val[1]));
295 mask.val[1] = vandq_u32(mask.val[1], vcgeq_f32(vc.val[1], r_nc.val[1]));
296 mask.val[2] = vcgeq_f32(vc.val[2], l_nc.val[2]);
297 mask.val[2] = vandq_u32(mask.val[2], vcgeq_f32(vc.val[2], m_nc.val[2]));
298 mask.val[2] = vandq_u32(mask.val[2], vcgeq_f32(vc.val[2], r_nc.val[2]));
299 mask.val[3] = vcgeq_f32(vc.val[3], l_nc.val[3]);
300 mask.val[3] = vandq_u32(mask.val[3], vcgeq_f32(vc.val[3], m_nc.val[3]));
301 mask.val[3] = vandq_u32(mask.val[3], vcgeq_f32(vc.val[3], r_nc.val[3]));
302
303 input += input_stride;
304
305 // Row1 - Low part
306 tmp_low = vld1q_f32(input);
307 tmp_high = vld1q_f32(input + 4);
308 tmp_high1 = vld1q_f32(input + 8);
309
310 l_nc.val[0] = tmp_low;
311 r_nc.val[0] = vextq_f32(tmp_low, tmp_high, 2);
312
313 tmp_low = tmp_high;
314 tmp_high = tmp_high1;
315
316 l_nc.val[1] = tmp_low;
317 r_nc.val[1] = vextq_f32(tmp_low, tmp_high, 2);
318
319 // Row1 - High part
320 tmp_low = tmp_high1;
321 tmp_high = vld1q_f32(input + 12);
322 tmp_high1 = vld1q_f32(input + 16);
323
324 l_nc.val[2] = tmp_low;
325 r_nc.val[2] = vextq_f32(tmp_low, tmp_high, 2);
326
327 tmp_low = tmp_high;
328 tmp_high = tmp_high1;
329
330 l_nc.val[3] = tmp_low;
331 r_nc.val[3] = vextq_f32(tmp_low, tmp_high, 2);
332
333 // mc >= nc.val[0], mc > nc.val[2]
334 mask.val[0] = vandq_u32(mask.val[0], vcgeq_f32(vc.val[0], l_nc.val[0]));
335 mask.val[0] = vandq_u32(mask.val[0], vcgtq_f32(vc.val[0], r_nc.val[0]));
336 mask.val[1] = vandq_u32(mask.val[1], vcgeq_f32(vc.val[1], l_nc.val[1]));
337 mask.val[1] = vandq_u32(mask.val[1], vcgtq_f32(vc.val[1], r_nc.val[1]));
338 mask.val[2] = vandq_u32(mask.val[2], vcgeq_f32(vc.val[2], l_nc.val[2]));
339 mask.val[2] = vandq_u32(mask.val[2], vcgtq_f32(vc.val[2], r_nc.val[2]));
340 mask.val[3] = vandq_u32(mask.val[3], vcgeq_f32(vc.val[3], l_nc.val[3]));
341 mask.val[3] = vandq_u32(mask.val[3], vcgtq_f32(vc.val[3], r_nc.val[3]));
342
343 input += input_stride;
344
345 // Row2 - Low part
346 tmp_low = vld1q_f32(input);
347 tmp_high = vld1q_f32(input + 4);
348 tmp_high1 = vld1q_f32(input + 8);
349
350 l_nc.val[0] = tmp_low;
351 m_nc.val[0] = vextq_f32(tmp_low, tmp_high, 1);
352 r_nc.val[0] = vextq_f32(tmp_low, tmp_high, 2);
353
354 tmp_low = tmp_high;
355 tmp_high = tmp_high1;
356
357 l_nc.val[1] = tmp_low;
358 m_nc.val[1] = vextq_f32(tmp_low, tmp_high, 1);
359 r_nc.val[1] = vextq_f32(tmp_low, tmp_high, 2);
360
361 // Row2 - High part
362 tmp_low = tmp_high1;
363 tmp_high = vld1q_f32(input + 12);
364 tmp_high1 = vld1q_f32(input + 16);
365
366 l_nc.val[2] = tmp_low;
367 m_nc.val[2] = vextq_f32(tmp_low, tmp_high, 1);
368 r_nc.val[2] = vextq_f32(tmp_low, tmp_high, 2);
369
370 tmp_low = tmp_high;
371 tmp_high = tmp_high1;
372
373 l_nc.val[3] = tmp_low;
374 m_nc.val[3] = vextq_f32(tmp_low, tmp_high, 1);
375 r_nc.val[3] = vextq_f32(tmp_low, tmp_high, 2);
376
377 // mc > nc.val[0], mc > nc.val[1], mc > nc.val[2]
378 mask.val[0] = vandq_u32(mask.val[0], vcgtq_f32(vc.val[0], l_nc.val[0]));
379 mask.val[0] = vandq_u32(mask.val[0], vcgtq_f32(vc.val[0], m_nc.val[0]));
380 mask.val[0] = vandq_u32(mask.val[0], vcgtq_f32(vc.val[0], r_nc.val[0]));
381 mask.val[1] = vandq_u32(mask.val[1], vcgtq_f32(vc.val[1], l_nc.val[1]));
382 mask.val[1] = vandq_u32(mask.val[1], vcgtq_f32(vc.val[1], m_nc.val[1]));
383 mask.val[1] = vandq_u32(mask.val[1], vcgtq_f32(vc.val[1], r_nc.val[1]));
384 mask.val[2] = vandq_u32(mask.val[2], vcgtq_f32(vc.val[2], l_nc.val[2]));
385 mask.val[2] = vandq_u32(mask.val[2], vcgtq_f32(vc.val[2], m_nc.val[2]));
386 mask.val[2] = vandq_u32(mask.val[2], vcgtq_f32(vc.val[2], r_nc.val[2]));
387 mask.val[3] = vandq_u32(mask.val[3], vcgtq_f32(vc.val[3], l_nc.val[3]));
388 mask.val[3] = vandq_u32(mask.val[3], vcgtq_f32(vc.val[3], m_nc.val[3]));
389 mask.val[3] = vandq_u32(mask.val[3], vcgtq_f32(vc.val[3], r_nc.val[3]));
390
391 static const float32x4_t zero = vdupq_n_f32(0.f);
392
393 // Store
394 vst1q_f32(output + 0, vbslq_f32(mask.val[0], vc.val[0], zero));
395 vst1q_f32(output + 4, vbslq_f32(mask.val[1], vc.val[1], zero));
396 vst1q_f32(output + 8, vbslq_f32(mask.val[2], vc.val[2], zero));
397 vst1q_f32(output + 12, vbslq_f32(mask.val[3], vc.val[3], zero));
398}
399
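// U8 variant: same suppression logic on sixteen 8-bit scores per call; the
// shifted neighbour rows are simply re-loaded at input + 1 and input + 2.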
inline void non_maxima_suppression3x3_U8_U8(const void *__restrict input_ptr, void *__restrict output_ptr, const uint32_t input_stride)
{
    auto       input  = static_cast<const uint8_t *__restrict>(input_ptr) - 1;
    const auto output = static_cast<uint8_t *__restrict>(output_ptr);

    // Get centre scores
    const uint8x16_t vc = vld1q_u8(input + 1);

    // Neighboring pixels
    uint8x16_t l_nc{};
    uint8x16_t m_nc{};
    uint8x16_t r_nc{};

    input -= input_stride;

    // Row0
    l_nc = vld1q_u8(input);
    m_nc = vld1q_u8(input + 1);
    r_nc = vld1q_u8(input + 2);

    // vc >= l_nc, vc >= m_nc, vc >= r_nc
    uint8x16_t mask = vcgeq_u8(vc, l_nc);
    mask            = vandq_u8(mask, vcgeq_u8(vc, m_nc));
    mask            = vandq_u8(mask, vcgeq_u8(vc, r_nc));

    input += input_stride;

    // Row1
    l_nc = vld1q_u8(input);
    r_nc = vld1q_u8(input + 2);

    // vc >= l_nc, vc > r_nc
    mask = vandq_u8(mask, vcgeq_u8(vc, l_nc));
    mask = vandq_u8(mask, vcgtq_u8(vc, r_nc));

    input += input_stride;

    // Row2
    l_nc = vld1q_u8(input);
    m_nc = vld1q_u8(input + 1);
    r_nc = vld1q_u8(input + 2);

    // vc > l_nc, vc > m_nc, vc > r_nc
    mask = vandq_u8(mask, vcgtq_u8(vc, l_nc));
    mask = vandq_u8(mask, vcgtq_u8(vc, m_nc));
    mask = vandq_u8(mask, vcgtq_u8(vc, r_nc));

    static const uint8x16_t zero = vdupq_n_u8(0);

    // Store
    vst1q_u8(output, vbslq_u8(mask, vc, zero));
}
} // namespace

NENonMaximaSuppression3x3Kernel::NENonMaximaSuppression3x3Kernel()
    : _func(nullptr), _input(nullptr), _output(nullptr)
{
}

BorderSize NENonMaximaSuppression3x3Kernel::border_size() const
{
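    // A 3x3 neighbourhood needs one pixel of border on each side of the window.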
    return BorderSize(1);
}

void NENonMaximaSuppression3x3Kernel::configure(const ITensor *input, ITensor *output, bool border_undefined)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::F32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);

    _input  = input;
    _output = output;

    if(input->info()->data_type() == DataType::U8)
    {
        _func = &non_maxima_suppression3x3_U8_U8;
    }
    else
    {
        _func = &non_maxima_suppression3x3_FLOAT_FLOAT;
    }

    constexpr unsigned int num_elems_processed_per_iteration = 16;
    const unsigned int     num_elems_read_per_iteration      = 16 + 2 * border_size().left + (input->info()->data_type() == DataType::U8 ? 0 : 3);
    constexpr unsigned int num_elems_written_per_iteration   = 16;
    constexpr unsigned int num_rows_read_per_iteration       = 3;
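    // Reads cover the 16 processed elements plus the left/right border columns;
    // the F32 path asks for a few extra columns of padding because its trailing
    // vld1q_f32(input + 16) loads a full 4-element vector past the last
    // processed pixel.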

    // Configure kernel window
    Window                 win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration), border_undefined, border_size());
    AccessWindowHorizontal output_access(output->info(), 0, num_elems_written_per_iteration);

    update_window_and_padding(win,
                              AccessWindowRectangle(input->info(), -border_size().left, -border_size().top, num_elems_read_per_iteration, num_rows_read_per_iteration),
                              output_access);

    output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());

    INEKernel::configure(win);
}

void NENonMaximaSuppression3x3Kernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);
    Iterator input(_input, window);
    Iterator output(_output, window);

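    // Strides are stored in bytes; convert to an element stride, since the
    // suppression functions index their pointers in elements.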
    const size_t input_stride = _input->info()->strides_in_bytes()[1] / element_size_from_data_type(_input->info()->data_type());

    execute_window_loop(window, [&](const Coordinates &)
    {
        _func(input.ptr(), output.ptr(), input_stride);
    },
    input, output);
}
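
// Minimal usage sketch (illustrative only, not part of this file): `src` and
// `dst` are assumed to be allocated U8 or F32 tensors of matching shape, with
// the input border already filled (e.g. by NEFillBorderKernel) when
// border_undefined is false.
//
//   NENonMaximaSuppression3x3Kernel kernel;
//   kernel.configure(&src, &dst, /* border_undefined = */ false);
//   NEScheduler::get().schedule(&kernel, Window::DimY);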