/*
 * Copyright (c) 2016, 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NENonLinearFilterKernel.h"
25
26#include "arm_compute/core/Coordinates.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/TensorInfo.h"
31#include "arm_compute/core/Validate.h"
32
33#include <algorithm>
34#include <arm_neon.h>
35#include <array>
36#include <tuple>
37#include <utility>
38
39namespace arm_compute
40{
41namespace
42{
43const uint8x16_t zero_u8 = vdupq_n_u8(0);
44
45template <size_t columns>
46inline uint8x8_t min_row(uint8x16_t row_data)
47{
48 uint8x8_t min = vget_low_u8(row_data);
49
50 for(size_t c = 1; c < columns; ++c)
51 {
52 row_data = vextq_u8(row_data, zero_u8, 1);
53 min = vmin_u8(min, vget_low_u8(row_data));
54 }
55
56 return min;
57}
58
59template <size_t columns>
60inline uint8x8_t max_row(uint8x16_t row_data)
61{
62 uint8x8_t max = vget_low_u8(row_data);
63
64 for(size_t c = 1; c < columns; ++c)
65 {
66 row_data = vextq_u8(row_data, zero_u8, 1);
67 max = vmax_u8(max, vget_low_u8(row_data));
68 }
69
70 return max;
71}
72
73inline void sort(uint8x8_t &a, uint8x8_t &b)
74{
75 const uint8x8_t min = vmin_u8(a, b);
76 const uint8x8_t max = vmax_u8(a, b);
77 a = min;
78 b = max;
79}
80
81// Sorting networks below were generated using http://pages.ripco.net/~jgamble/nw.html
82// Calculations that do not affect the median were removed.
83inline void sort5(uint8x8_t &p0, uint8x8_t &p1, uint8x8_t &p2, uint8x8_t &p3, uint8x8_t &p4)
84{
85 sort(p0, p1);
86 sort(p2, p3);
87 sort(p0, p2);
88 sort(p1, p3);
89 sort(p1, p2);
90 sort(p0, p4);
91 sort(p1, p4);
92 sort(p2, p4);
93}
94
95inline void sort9(uint8x8_t &p0, uint8x8_t &p1, uint8x8_t &p2,
96 uint8x8_t &p3, uint8x8_t &p4, uint8x8_t &p5,
97 uint8x8_t &p6, uint8x8_t &p7, uint8x8_t &p8)
98{
99 sort(p1, p2);
100 sort(p4, p5);
101 sort(p7, p8);
102 sort(p0, p1);
103 sort(p3, p4);
104 sort(p6, p7);
105 sort(p1, p2);
106 sort(p4, p5);
107 sort(p7, p8);
108 sort(p0, p3);
109 sort(p5, p8);
110 sort(p4, p7);
111 sort(p3, p6);
112 sort(p1, p4);
113 sort(p2, p5);
114 sort(p4, p7);
115 sort(p4, p2);
116 sort(p6, p4);
117 sort(p4, p2);
118}
119
120inline void sort21(uint8x8_t p[21])
121{
122 sort(p[0], p[1]);
123 sort(p[2], p[3]);
124 sort(p[4], p[5]);
125 sort(p[6], p[7]);
126 sort(p[8], p[9]);
127 sort(p[10], p[11]);
128 sort(p[12], p[13]);
129 sort(p[14], p[15]);
130 sort(p[16], p[17]);
131 sort(p[18], p[19]);
132 sort(p[0], p[2]);
133 sort(p[1], p[3]);
134 sort(p[4], p[6]);
135 sort(p[5], p[7]);
136 sort(p[8], p[10]);
137 sort(p[9], p[11]);
138 sort(p[12], p[14]);
139 sort(p[13], p[15]);
140 sort(p[16], p[18]);
141 sort(p[17], p[19]);
142 sort(p[1], p[2]);
143 sort(p[5], p[6]);
144 sort(p[0], p[4]);
145 sort(p[3], p[7]);
146 sort(p[9], p[10]);
147 sort(p[13], p[14]);
148 sort(p[8], p[12]);
149 sort(p[11], p[15]);
150 sort(p[17], p[18]);
151 sort(p[16], p[20]);
152 sort(p[1], p[5]);
153 sort(p[2], p[6]);
154 sort(p[9], p[13]);
155 sort(p[10], p[14]);
156 sort(p[0], p[8]);
157 sort(p[7], p[15]);
158 sort(p[17], p[20]);
159 sort(p[1], p[4]);
160 sort(p[3], p[6]);
161 sort(p[9], p[12]);
162 sort(p[11], p[14]);
163 sort(p[18], p[20]);
164 sort(p[0], p[16]);
165 sort(p[2], p[4]);
166 sort(p[3], p[5]);
167 sort(p[10], p[12]);
168 sort(p[11], p[13]);
169 sort(p[1], p[9]);
170 sort(p[6], p[14]);
171 sort(p[19], p[20]);
172 sort(p[3], p[4]);
173 sort(p[11], p[12]);
174 sort(p[1], p[8]);
175 sort(p[2], p[10]);
176 sort(p[5], p[13]);
177 sort(p[7], p[14]);
178 sort(p[3], p[11]);
179 sort(p[2], p[8]);
180 sort(p[4], p[12]);
181 sort(p[7], p[13]);
182 sort(p[1], p[17]);
183 sort(p[3], p[10]);
184 sort(p[5], p[12]);
185 sort(p[1], p[16]);
186 sort(p[2], p[18]);
187 sort(p[3], p[9]);
188 sort(p[6], p[12]);
189 sort(p[2], p[16]);
190 sort(p[3], p[8]);
191 sort(p[7], p[12]);
192 sort(p[5], p[9]);
193 sort(p[6], p[10]);
194 sort(p[4], p[8]);
195 sort(p[7], p[11]);
196 sort(p[3], p[19]);
197 sort(p[5], p[8]);
198 sort(p[7], p[10]);
199 sort(p[3], p[18]);
200 sort(p[4], p[20]);
201 sort(p[6], p[8]);
202 sort(p[7], p[9]);
203 sort(p[3], p[17]);
204 sort(p[5], p[20]);
205 sort(p[7], p[8]);
206 sort(p[3], p[16]);
207 sort(p[6], p[20]);
208 sort(p[5], p[17]);
209 sort(p[7], p[20]);
210 sort(p[4], p[16]);
211 sort(p[6], p[18]);
212 sort(p[5], p[16]);
213 sort(p[7], p[19]);
214 sort(p[7], p[18]);
215 sort(p[6], p[16]);
216 sort(p[7], p[17]);
217 sort(p[10], p[18]);
218 sort(p[7], p[16]);
219 sort(p[9], p[17]);
220 sort(p[8], p[16]);
221 sort(p[9], p[16]);
222 sort(p[10], p[16]);
223}
224
225inline void sort25(uint8x8_t p[25])
226{
227 sort(p[1], p[2]);
228 sort(p[0], p[1]);
229 sort(p[1], p[2]);
230 sort(p[4], p[5]);
231 sort(p[3], p[4]);
232 sort(p[4], p[5]);
233 sort(p[0], p[3]);
234 sort(p[2], p[5]);
235 sort(p[2], p[3]);
236 sort(p[1], p[4]);
237 sort(p[1], p[2]);
238 sort(p[3], p[4]);
239 sort(p[7], p[8]);
240 sort(p[6], p[7]);
241 sort(p[7], p[8]);
242 sort(p[10], p[11]);
243 sort(p[9], p[10]);
244 sort(p[10], p[11]);
245 sort(p[6], p[9]);
246 sort(p[8], p[11]);
247 sort(p[8], p[9]);
248 sort(p[7], p[10]);
249 sort(p[7], p[8]);
250 sort(p[9], p[10]);
251 sort(p[0], p[6]);
252 sort(p[4], p[10]);
253 sort(p[4], p[6]);
254 sort(p[2], p[8]);
255 sort(p[2], p[4]);
256 sort(p[6], p[8]);
257 sort(p[1], p[7]);
258 sort(p[5], p[11]);
259 sort(p[5], p[7]);
260 sort(p[3], p[9]);
261 sort(p[3], p[5]);
262 sort(p[7], p[9]);
263 sort(p[1], p[2]);
264 sort(p[3], p[4]);
265 sort(p[5], p[6]);
266 sort(p[7], p[8]);
267 sort(p[9], p[10]);
268 sort(p[13], p[14]);
269 sort(p[12], p[13]);
270 sort(p[13], p[14]);
271 sort(p[16], p[17]);
272 sort(p[15], p[16]);
273 sort(p[16], p[17]);
274 sort(p[12], p[15]);
275 sort(p[14], p[17]);
276 sort(p[14], p[15]);
277 sort(p[13], p[16]);
278 sort(p[13], p[14]);
279 sort(p[15], p[16]);
280 sort(p[19], p[20]);
281 sort(p[18], p[19]);
282 sort(p[19], p[20]);
283 sort(p[21], p[22]);
284 sort(p[23], p[24]);
285 sort(p[21], p[23]);
286 sort(p[22], p[24]);
287 sort(p[22], p[23]);
288 sort(p[18], p[21]);
289 sort(p[20], p[23]);
290 sort(p[20], p[21]);
291 sort(p[19], p[22]);
292 sort(p[22], p[24]);
293 sort(p[19], p[20]);
294 sort(p[21], p[22]);
295 sort(p[23], p[24]);
296 sort(p[12], p[18]);
297 sort(p[16], p[22]);
298 sort(p[16], p[18]);
299 sort(p[14], p[20]);
300 sort(p[20], p[24]);
301 sort(p[14], p[16]);
302 sort(p[18], p[20]);
303 sort(p[22], p[24]);
304 sort(p[13], p[19]);
305 sort(p[17], p[23]);
306 sort(p[17], p[19]);
307 sort(p[15], p[21]);
308 sort(p[15], p[17]);
309 sort(p[19], p[21]);
310 sort(p[13], p[14]);
311 sort(p[15], p[16]);
312 sort(p[17], p[18]);
313 sort(p[19], p[20]);
314 sort(p[21], p[22]);
315 sort(p[23], p[24]);
316 sort(p[0], p[12]);
317 sort(p[8], p[20]);
318 sort(p[8], p[12]);
319 sort(p[4], p[16]);
320 sort(p[16], p[24]);
321 sort(p[12], p[16]);
322 sort(p[2], p[14]);
323 sort(p[10], p[22]);
324 sort(p[10], p[14]);
325 sort(p[6], p[18]);
326 sort(p[6], p[10]);
327 sort(p[10], p[12]);
328 sort(p[1], p[13]);
329 sort(p[9], p[21]);
330 sort(p[9], p[13]);
331 sort(p[5], p[17]);
332 sort(p[13], p[17]);
333 sort(p[3], p[15]);
334 sort(p[11], p[23]);
335 sort(p[11], p[15]);
336 sort(p[7], p[19]);
337 sort(p[7], p[11]);
338 sort(p[11], p[13]);
339 sort(p[11], p[12]);
340}
341} // namespace
342
343NENonLinearFilterKernel::NENonLinearFilterKernel()
344 : _border_width(0), _input(nullptr), _output(nullptr), _mask(nullptr), _pattern(MatrixPattern::BOX), _function(NonLinearFilterFunction::MIN), _func_idx(0), _border_size()
345{
346}
347
348BorderSize NENonLinearFilterKernel::border_size() const
349{
350 return _border_size;
351}
352
353void NENonLinearFilterKernel::configure(const ITensor *input, ITensor *output, NonLinearFilterFunction function, unsigned int mask_size, MatrixPattern pattern, const uint8_t *mask,
354 bool border_undefined)
355{
356 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
357 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
358 ARM_COMPUTE_ERROR_ON(3 != mask_size && 5 != mask_size);
359 ARM_COMPUTE_ERROR_ON(MatrixPattern::OTHER == pattern && nullptr == mask);
360
361 // Set class variables
362 _border_size = BorderSize(mask_size / 2);
363 _input = input;
364 _output = output;
365 _mask = mask;
366 _pattern = pattern;
367 _function = function;
368
369 // Configure kernel window
370 const unsigned int num_elems_processed_per_iteration = (MatrixPattern::OTHER == pattern) ? 1 : 8;
371 constexpr unsigned int num_elems_read_per_iteration = 16;
372
373 Window win = calculate_max_window(*input->info(), num_elems_processed_per_iteration, border_undefined, border_size());
374 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
375 update_window_and_padding(win,
376 AccessWindowRectangle(input->info(), -border_size().left, -border_size().top, num_elems_read_per_iteration, mask_size),
377 output_access);
378 output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
379
380 INEKernel::configure(win);
381
382 // Define function index
383 _func_idx = (3 == mask_size) ? 0 : 1;
384
385 if(MatrixPattern::OTHER != pattern)
386 {
387 _func_idx = (_func_idx) * 3 + static_cast<unsigned int>(function);
388 }
389}
390
391void NENonLinearFilterKernel::fill_mask(uint8_t *mask, int cols, int rows, MatrixPattern pattern)
392{
393 unsigned int v = 0;
394
395 for(int r = 0; r < rows; ++r)
396 {
397 for(int c = 0; c < cols; ++c, ++v)
398 {
399 uint8_t val = 0;
400
401 switch(pattern)
402 {
403 case MatrixPattern::BOX:
404 val = 255;
405 break;
406 case MatrixPattern::CROSS:
407 val = ((r == (rows / 2)) || (c == (cols / 2))) ? 255 : 0;
408 break;
409 case MatrixPattern::DISK:
410 val = (((r - rows / 2.0f + 0.5f) * (r - rows / 2.0f + 0.5f)) / ((rows / 2.0f) * (rows / 2.0f)) + ((c - cols / 2.0f + 0.5f) * (c - cols / 2.0f + 0.5f)) / ((cols / 2.0f) *
411 (cols / 2.0f))) <= 1.0f ? 255 : 0;
412 break;
413 default:
414 return;
415 }
416
417 mask[v] = val;
418 }
419 }
420}
421
422template <>
423void NENonLinearFilterKernel::median_filter_box<3, 3>(const Window &win)
424{
425 Iterator input(_input, win);
426 Iterator output(_output, win);
427
428 const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, -1)));
429 const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 0)));
430 const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 1)));
431
432 execute_window_loop(win, [&](const Coordinates & id)
433 {
434 const uint8x16_t top_data = vld1q_u8(input_top_ptr + input.offset());
435 const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
436 const uint8x16_t bot_data = vld1q_u8(input_bot_ptr + input.offset());
437
438 uint8x8_t p0 = vget_low_u8(top_data);
439 uint8x8_t p1 = vext_u8(vget_low_u8(top_data), vget_high_u8(top_data), 1);
440 uint8x8_t p2 = vext_u8(vget_low_u8(top_data), vget_high_u8(top_data), 2);
441 uint8x8_t p3 = vget_low_u8(mid_data);
442 uint8x8_t p4 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 1);
443 uint8x8_t p5 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 2);
444 uint8x8_t p6 = vget_low_u8(bot_data);
445 uint8x8_t p7 = vext_u8(vget_low_u8(bot_data), vget_high_u8(bot_data), 1);
446 uint8x8_t p8 = vext_u8(vget_low_u8(bot_data), vget_high_u8(bot_data), 2);
447
448 sort9(p0, p1, p2, p3, p4, p5, p6, p7, p8);
449
450 vst1_u8(output.ptr(), p4);
451 },
452 input, output);
453}
454template <>
455void NENonLinearFilterKernel::median_filter_box<5, 5>(const Window &win)
456{
457 Iterator input(_input, win);
458 Iterator output(_output, win);
459
460 const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -2)));
461 const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -1)));
462 const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
463 const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 1)));
464 const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 2)));
465
466 execute_window_loop(win, [&](const Coordinates & id)
467 {
468 const uint8x16_t top2_data = vld1q_u8(input_top2_ptr + input.offset());
469 const uint8x16_t top_data = vld1q_u8(input_top_ptr + input.offset());
470 const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
471 const uint8x16_t bot_data = vld1q_u8(input_bot_ptr + input.offset());
472 const uint8x16_t bot2_data = vld1q_u8(input_bot2_ptr + input.offset());
473
474 const uint8x8_t d[] =
475 {
476 vget_low_u8(top2_data),
477 vget_high_u8(top2_data),
478 vget_low_u8(top_data),
479 vget_high_u8(top_data),
480 vget_low_u8(mid_data),
481 vget_high_u8(mid_data),
482 vget_low_u8(bot_data),
483 vget_high_u8(bot_data),
484 vget_low_u8(bot2_data),
485 vget_high_u8(bot2_data)
486 };
487
488 uint8x8_t p[25];
489 for(unsigned int i = 0; i < 5; ++i)
490 {
491 const unsigned int idx_d = i * 2;
492 const unsigned int idx_p = i * 5;
493
494 p[idx_p] = d[idx_d];
495 p[idx_p + 1] = vext_u8(d[idx_d], d[idx_d + 1], 1);
496 p[idx_p + 2] = vext_u8(d[idx_d], d[idx_d + 1], 2);
497 p[idx_p + 3] = vext_u8(d[idx_d], d[idx_d + 1], 3);
498 p[idx_p + 4] = vext_u8(d[idx_d], d[idx_d + 1], 4);
499 }
500
501 sort25(p);
502
503 vst1_u8(output.ptr(), p[12]);
504 },
505 input, output);
506}
507
508template <int mask_w, int mask_h>
509void NENonLinearFilterKernel::min_filter_box(const Window &win)
510{
511 static_assert(mask_w > 0, "Mask size must not be 0");
512 static_assert(mask_h > 0, "Mask size must not be 0");
513
514 Iterator input(_input, win);
515 Iterator output(_output, win);
516
517 const int k_row_half = mask_h / 2;
518 const int k_col_half = mask_w / 2;
519
520 // Set row pointers
521 std::array<const unsigned char *, mask_h> input_ptrs{ {} };
522 for(int i = -k_row_half; i <= k_row_half; ++i)
523 {
524 input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, i));
525 }
526
527 execute_window_loop(win, [&](const Coordinates & id)
528 {
529 // Get min of rows
530 uint8x16_t rows_min = vld1q_u8(input_ptrs[0] + input.offset());
531
532 for(unsigned int r = 1; r < mask_h; ++r)
533 {
534 const uint8x16_t data = vld1q_u8(input_ptrs[r] + input.offset());
535 rows_min = vminq_u8(rows_min, data);
536 }
537
538 const uint8x8_t out = min_row<mask_w>(rows_min);
539
540 // Store result as U8
541 vst1_u8(output.ptr(), out);
542 },
543 input, output);
544}
545
546template <int mask_w, int mask_h>
547void NENonLinearFilterKernel::max_filter_box(const Window &win)
548{
549 static_assert(mask_w > 0, "Mask size must not be 0");
550 static_assert(mask_h > 0, "Mask size must not be 0");
551 ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
552
553 Iterator input(_input, win);
554 Iterator output(_output, win);
555
556 const int k_row_half = mask_h / 2;
557 const int k_col_half = mask_w / 2;
558
559 // Set row pointers
560 std::array<const unsigned char *, mask_h> input_ptrs{ {} };
561 for(int i = -k_row_half; i <= k_row_half; ++i)
562 {
563 input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, i));
564 }
565
566 execute_window_loop(win, [&](const Coordinates & id)
567 {
568 uint8x16_t rows_max = vld1q_u8(input_ptrs[0] + input.offset());
569
570 // Get max of rows
571 for(unsigned int r = 1; r < mask_h; ++r)
572 {
573 const uint8x16_t data = vld1q_u8(input_ptrs[r] + input.offset());
574 rows_max = vmaxq_u8(rows_max, data);
575 }
576
577 // Get max of columns
578 const uint8x8_t out = max_row<mask_w>(rows_max);
579
580 // Store result as U8
581 vst1_u8(output.ptr(), out);
582 },
583 input, output);
584}
585
586template <>
587void NENonLinearFilterKernel::median_filter_cross<3, 3>(const Window &win)
588{
589 Iterator input(_input, win);
590 Iterator output(_output, win);
591
592 const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, -1)));
593 const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 0)));
594 const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, 1)));
595
596 execute_window_loop(win, [&](const Coordinates & id)
597 {
598 const uint8x8_t top_data = vld1_u8(input_top_ptr + input.offset());
599 const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
600 const uint8x8_t bot_data = vld1_u8(input_bot_ptr + input.offset());
601
602 uint8x8_t p0 = top_data;
603 uint8x8_t p1 = vget_low_u8(mid_data);
604 uint8x8_t p2 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 1);
605 uint8x8_t p3 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 2);
606 uint8x8_t p4 = bot_data;
607
608 sort5(p0, p1, p2, p3, p4);
609
610 vst1_u8(output.ptr(), p2);
611 },
612 input, output);
613}
614
615template <>
616void NENonLinearFilterKernel::median_filter_cross<5, 5>(const Window &win)
617{
618 Iterator input(_input, win);
619 Iterator output(_output, win);
620
621 const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, -2)));
622 const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, -1)));
623 const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
624 const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, 1)));
625 const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, 2)));
626
627 execute_window_loop(win, [&](const Coordinates & id)
628 {
629 const uint8x8_t top2_data = vld1_u8(input_top2_ptr + input.offset());
630 const uint8x8_t top_data = vld1_u8(input_top_ptr + input.offset());
631 const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
632 const uint8x8_t bot_data = vld1_u8(input_bot_ptr + input.offset());
633 const uint8x8_t bot2_data = vld1_u8(input_bot2_ptr + input.offset());
634
635 uint8x8_t p0 = top2_data;
636 uint8x8_t p1 = top_data;
637 uint8x8_t p2 = vget_low_u8(mid_data);
638 uint8x8_t p3 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 1);
639 uint8x8_t p4 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 2);
640 uint8x8_t p5 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 3);
641 uint8x8_t p6 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 4);
642 uint8x8_t p7 = bot_data;
643 uint8x8_t p8 = bot2_data;
644
645 sort9(p0, p1, p2, p3, p4, p5, p6, p7, p8);
646
647 vst1_u8(output.ptr(), p4);
648 },
649 input, output);
650}
651
652template <int mask_w, int mask_h>
653void NENonLinearFilterKernel::min_filter_cross(const Window &win)
654{
655 static_assert(mask_w > 0, "Mask size must not be 0");
656 static_assert(mask_h > 0, "Mask size must not be 0");
657 ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
658
659 Iterator input(_input, win);
660 Iterator output(_output, win);
661
662 const int k_row_half = mask_h / 2;
663 const int k_col_half = mask_w / 2;
664
665 const unsigned char *mid_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, 0));
666
667 // Set row pointers
668 std::array<const unsigned char *, mask_h> input_ptrs{ {} };
669 for(int i = -k_row_half; i <= k_row_half; ++i)
670 {
671 input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(0, i));
672 }
673
674 execute_window_loop(win, [&](const Coordinates & id)
675 {
676 uint8x8_t rows_min = vld1_u8(input_ptrs[0] + input.offset());
677
678 // Get min of rows
679 for(unsigned int r = 1; r < mask_h; ++r)
680 {
681 const uint8x8_t data = vld1_u8(input_ptrs[r] + input.offset());
682 rows_min = vmin_u8(rows_min, data);
683 }
684
685 // Get min of middle row
686 const uint8x16_t data = vld1q_u8(mid_ptr + input.offset());
687 uint8x8_t out = min_row<mask_w>(data);
688
689 // Get final min
690 out = vmin_u8(out, rows_min);
691
692 // Store result as U8
693 vst1_u8(output.ptr(), out);
694 },
695 input, output);
696}
697
698template <int mask_w, int mask_h>
699void NENonLinearFilterKernel::max_filter_cross(const Window &win)
700{
701 static_assert(mask_w > 0, "Mask size must not be 0");
702 static_assert(mask_h > 0, "Mask size must not be 0");
703 ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
704
705 Iterator input(_input, win);
706 Iterator output(_output, win);
707
708 const int k_row_half = mask_h / 2;
709 const int k_col_half = mask_w / 2;
710
711 const unsigned char *mid_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, 0));
712
713 // Set row pointers
714 std::array<unsigned char *, mask_h> input_ptrs{ {} };
715 for(int i = -k_row_half; i <= k_row_half; ++i)
716 {
717 input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(0, i));
718 }
719
720 execute_window_loop(win, [&](const Coordinates & id)
721 {
722 uint8x8_t rows_max = vld1_u8(input_ptrs[0] + input.offset());
723
724 // Get max of rows
725 for(unsigned int r = 1; r < mask_h; ++r)
726 {
727 const uint8x8_t data = vld1_u8(input_ptrs[r] + input.offset());
728 rows_max = vmax_u8(rows_max, data);
729 }
730
731 // Get max of middle row
732 const uint8x16_t data = vld1q_u8(mid_ptr + input.offset());
733 uint8x8_t out = max_row<mask_w>(data);
734
735 // Get final max
736 out = vmax_u8(out, rows_max);
737
738 // Store result as U8
739 vst1_u8(output.ptr(), out);
740 },
741 input, output);
742}
743
744template <>
745void NENonLinearFilterKernel::median_filter_disk<5, 5>(const Window &win)
746{
747 Iterator input(_input, win);
748 Iterator output(_output, win);
749
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100750 static const uint8x16_t zero = vdupq_n_u8(0);
751 const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -2)));
752 const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -1)));
753 const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
754 const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 1)));
755 const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 2)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100756
757 execute_window_loop(win, [&](const Coordinates & id)
758 {
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100759 const uint8x16_t top2_data = vextq_u8(vld1q_u8(input_top2_ptr + input.offset()), zero, 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100760 const uint8x16_t top_data = vld1q_u8(input_top_ptr + input.offset());
761 const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
762 const uint8x16_t bot_data = vld1q_u8(input_bot_ptr + input.offset());
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100763 const uint8x16_t bot2_data = vextq_u8(vld1q_u8(input_bot2_ptr + input.offset()), zero, 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100764
765 uint8x8_t d[] =
766 {
767 vget_low_u8(top2_data),
768 vget_high_u8(top2_data),
769 vget_low_u8(top_data),
770 vget_high_u8(top_data),
771 vget_low_u8(mid_data),
772 vget_high_u8(mid_data),
773 vget_low_u8(bot_data),
774 vget_high_u8(bot_data),
775 vget_low_u8(bot2_data),
776 vget_high_u8(bot2_data)
777 };
778
779 uint8x8_t p[21];
780 p[0] = d[0];
781 p[1] = vext_u8(d[0], d[1], 1);
782 p[2] = vext_u8(d[0], d[1], 2);
783 p[18] = d[8];
784 p[19] = vext_u8(d[8], d[9], 1);
785 p[20] = vext_u8(d[8], d[9], 2);
786
787 for(unsigned int i = 0; i < 3; ++i)
788 {
789 const unsigned int idx_d = 2 + i * 2;
790 const unsigned int idx_p = 3 + i * 5;
791
792 p[idx_p] = d[idx_d];
793 p[idx_p + 1] = vext_u8(d[idx_d], d[idx_d + 1], 1);
794 p[idx_p + 2] = vext_u8(d[idx_d], d[idx_d + 1], 2);
795 p[idx_p + 3] = vext_u8(d[idx_d], d[idx_d + 1], 3);
796 p[idx_p + 4] = vext_u8(d[idx_d], d[idx_d + 1], 4);
797 }
798
799 sort21(p);
800
801 vst1_u8(output.ptr(), p[10]);
802 },
803 input, output);
804}
805
806template <>
807void NENonLinearFilterKernel::min_filter_disk<5, 5>(const Window &win)
808{
809 Iterator input(_input, win);
810 Iterator output(_output, win);
811
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100812 static const uint8x16_t zero = vdupq_n_u8(0);
813 const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -2)));
814 const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -1)));
815 const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
816 const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 1)));
817 const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 2)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100818
819 execute_window_loop(win, [&](const Coordinates & id)
820 {
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100821 const uint8x16_t top2_data = vextq_u8(vld1q_u8(input_top2_ptr + input.offset()), zero, 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100822 const uint8x16_t top_data = vld1q_u8(input_top_ptr + input.offset());
823 const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
824 const uint8x16_t bot_data = vld1q_u8(input_bot_ptr + input.offset());
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100825 const uint8x16_t bot2_data = vextq_u8(vld1q_u8(input_bot2_ptr + input.offset()), zero, 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100826
827 const uint8x16_t rows_min_3 = vminq_u8(top2_data, bot2_data);
828 uint8x16_t rows_min_5 = vminq_u8(top_data, bot_data);
829 rows_min_5 = vminq_u8(rows_min_5, mid_data);
830
831 const uint8x8_t out_3 = min_row<3>(rows_min_3);
832 const uint8x8_t out_5 = min_row<5>(rows_min_5);
833
834 vst1_u8(output.ptr(), vmin_u8(out_3, out_5));
835 },
836 input, output);
837}
838
839template <>
840void NENonLinearFilterKernel::max_filter_disk<5, 5>(const Window &win)
841{
842 Iterator input(_input, win);
843 Iterator output(_output, win);
844
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100845 static const uint8x16_t zero = vdupq_n_u8(0);
846 const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -2)));
847 const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -1)));
848 const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
849 const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 1)));
850 const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 2)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100851
852 execute_window_loop(win, [&](const Coordinates & id)
853 {
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100854 const uint8x16_t top2_data = vextq_u8(vld1q_u8(input_top2_ptr + input.offset()), zero, 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100855 const uint8x16_t top_data = vld1q_u8(input_top_ptr + input.offset());
856 const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
857 const uint8x16_t bot_data = vld1q_u8(input_bot_ptr + input.offset());
Georgios Pinitas0a7a8d12017-10-23 12:23:10 +0100858 const uint8x16_t bot2_data = vextq_u8(vld1q_u8(input_bot2_ptr + input.offset()), zero, 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100859
860 const uint8x16_t rows_max_3 = vmaxq_u8(top2_data, bot2_data);
861 uint8x16_t rows_max_5 = vmaxq_u8(top_data, bot_data);
862 rows_max_5 = vmaxq_u8(rows_max_5, mid_data);
863
864 const uint8x8_t out_3 = max_row<3>(rows_max_3);
865 const uint8x8_t out_5 = max_row<5>(rows_max_5);
866
867 vst1_u8(output.ptr(), vmax_u8(out_3, out_5));
868 },
869 input, output);
870}
871
872template <int mask_w, int mask_h>
873void NENonLinearFilterKernel::non_linear_filter_generic(const Window &win)
874{
875 Iterator input(_input, win);
876 Iterator output(_output, win);
877 ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
878
879 const int k_row_half = mask_h / 2;
880 const int k_col_half = mask_w / 2;
881 constexpr int mask_size = mask_w * mask_h;
882
883 // Set row pointers
884 std::array<unsigned char *, mask_h> input_ptrs{ {} };
885 for(int i = -k_row_half; i <= k_row_half; ++i)
886 {
887 input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, i));
888 }
889
890 execute_window_loop(win, [&](const Coordinates & id)
891 {
892 std::array<uint8_t, mask_size> vals{ {} };
893
894 size_t v = 0;
895 size_t m = 0;
896
897 for(unsigned int r = 0; r < mask_h; ++r)
898 {
899 const auto in_ptr = static_cast<const uint8_t *>(input_ptrs[r] + input.offset());
900
901 for(unsigned int c = 0; c < mask_w; ++c, ++m)
902 {
903 if(_mask[m] == 255)
904 {
905 vals[v] = in_ptr[c];
906 ++v;
907 }
908 }
909 }
910
911 // Only do something if there is at least one non-zero element in the
912 // mask
913 if(v > 0)
914 {
915 std::sort(vals.begin(), vals.begin() + v);
916
917 switch(_function)
918 {
919 case NonLinearFilterFunction::MIN:
920 *output.ptr() = vals[0];
921 break;
922 case NonLinearFilterFunction::MAX:
923 *output.ptr() = vals[v - 1];
924 break;
925 case NonLinearFilterFunction::MEDIAN:
926 *output.ptr() = vals[v / 2];
927 break;
928 default:
929 break;
930 }
931 }
932 },
933 input, output);
934}
935
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100936void NENonLinearFilterKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100937{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100938 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100939 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
940 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
941
942 using NonLinearFilterFunction = void (NENonLinearFilterKernel::*)(const Window & window);
943
944 // Function table for BOX pattern
945 static const std::array<NonLinearFilterFunction, 6> func_table_box =
946 {
947 {
948 &NENonLinearFilterKernel::median_filter_box<3, 3>,
949 &NENonLinearFilterKernel::min_filter_box<3, 3>,
950 &NENonLinearFilterKernel::max_filter_box<3, 3>,
951 &NENonLinearFilterKernel::median_filter_box<5, 5>,
952 &NENonLinearFilterKernel::min_filter_box<5, 5>,
953 &NENonLinearFilterKernel::max_filter_box<5, 5>,
954 }
955 };
956
957 // Function table for CROSS pattern
958 static const std::array<NonLinearFilterFunction, 6> func_table_cross =
959 {
960 {
961 &NENonLinearFilterKernel::median_filter_cross<3, 3>,
962 &NENonLinearFilterKernel::min_filter_cross<3, 3>,
963 &NENonLinearFilterKernel::max_filter_cross<3, 3>,
964 &NENonLinearFilterKernel::median_filter_cross<5, 5>,
965 &NENonLinearFilterKernel::min_filter_cross<5, 5>,
966 &NENonLinearFilterKernel::max_filter_cross<5, 5>,
967 }
968 };
969
970 // Function table for DISK pattern
971 static const std::array<NonLinearFilterFunction, 6> func_table_disk =
972 {
973 {
974 &NENonLinearFilterKernel::median_filter_box<3, 3>,
975 &NENonLinearFilterKernel::min_filter_box<3, 3>,
976 &NENonLinearFilterKernel::max_filter_box<3, 3>,
977 &NENonLinearFilterKernel::median_filter_disk<5, 5>,
978 &NENonLinearFilterKernel::min_filter_disk<5, 5>,
979 &NENonLinearFilterKernel::max_filter_disk<5, 5>,
980 }
981 };
982
983 // Function table for OTHER pattern
984 static const std::array<NonLinearFilterFunction, 2> func_table_generic =
985 {
986 {
987 &NENonLinearFilterKernel::non_linear_filter_generic<3, 3>,
988 &NENonLinearFilterKernel::non_linear_filter_generic<5, 5>,
989 }
990 };
991
992 switch(_pattern)
993 {
994 case MatrixPattern::BOX:
995 ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_box.size());
996 (this->*func_table_box[_func_idx])(window);
997 break;
998 case MatrixPattern::CROSS:
999 ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_cross.size());
1000 (this->*func_table_cross[_func_idx])(window);
1001 break;
1002 case MatrixPattern::DISK:
1003 ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_disk.size());
1004 (this->*func_table_disk[_func_idx])(window);
1005 break;
1006 case MatrixPattern::OTHER:
1007 default:
1008 ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_generic.size());
1009 (this->*func_table_generic[_func_idx])(window);
1010 break;
1011 }
1012}
1013} // namespace arm_compute