blob: b0e26328ed01ab6044e0697382724f673e8d87d9 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyrouf63885b2019-01-16 14:18:09 +00002 * Copyright (c) 2016-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_UTILS_H__
25#define __ARM_COMPUTE_UTILS_H__
26
27#include "arm_compute/core/Error.h"
Giuseppe Rossinid7647d42018-07-17 18:13:13 +010028#include "arm_compute/core/PixelValue.h"
Michel Iwaniec5dfeae62017-11-29 10:48:23 +000029#include "arm_compute/core/Rounding.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Types.h"
31
32#include <algorithm>
33#include <cstdint>
34#include <cstdlib>
Michalis Spyrou53860dd2019-07-01 14:20:56 +010035#include <iomanip>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include <numeric>
37#include <sstream>
38#include <string>
39#include <type_traits>
40#include <utility>
steniu017ce53c62017-09-29 14:55:00 +010041#include <vector>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010042
43namespace arm_compute
44{
/** Divide @p val by @p m, rounding the quotient up.
 *
 * Evaluated as (val + m - 1) / m, which gives the ceiling only for a non-negative
 * @p val and a strictly positive @p m.
 *
 * @param[in] val Dividend.
 * @param[in] m   Divisor.
 *
 * @return The quotient val / m rounded towards positive infinity.
 */
template <typename S, typename T>
constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
{
    return ((val + m - 1) / m);
}
57
/** Round @p value up to the nearest multiple of @p divisor.
 *
 * @param[in] value   Lower bound; must be non-negative.
 * @param[in] divisor Step to round to; must be strictly positive.
 *
 * @return The smallest multiple of @p divisor that is greater than or equal to @p value.
 */
template <typename S, typename T>
inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
    // Number of whole divisor-sized steps needed to reach or exceed value
    const auto num_steps = DIV_CEIL(value, divisor);
    return num_steps * divisor;
}
71
/** Round @p value down to the nearest multiple of @p divisor.
 *
 * @param[in] value   Upper bound; must be non-negative.
 * @param[in] divisor Step to round to; must be strictly positive.
 *
 * @return The largest multiple of @p divisor that is less than or equal to @p value.
 */
template <typename S, typename T>
inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
    // Truncating division gives the number of whole steps that fit below value
    const auto num_steps = value / divisor;
    return num_steps * divisor;
}
85
/** Get the build information of the arm_compute library.
 *
 * The returned text contains the library version number and the options used to build the library.
 *
 * @return The arm_compute library build information.
 */
std::string build_information();

/** Read the entire content of a file into memory.
 *
 * @param[in] filename Name of the file to read.
 * @param[in] binary   Whether the file should be treated as a binary file.
 *
 * @return The content of the file.
 */
std::string read_file(const std::string &filename, bool binary);
102
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100103/** The size in bytes of the data type
104 *
105 * @param[in] data_type Input data type
106 *
107 * @return The size in bytes of the data type
108 */
109inline size_t data_size_from_type(DataType data_type)
110{
111 switch(data_type)
112 {
113 case DataType::U8:
114 case DataType::S8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100115 case DataType::QSYMM8:
Michel Iwaniec00633802017-10-12 14:14:15 +0100116 case DataType::QASYMM8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100117 case DataType::QSYMM8_PER_CHANNEL:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100118 return 1;
119 case DataType::U16:
120 case DataType::S16:
Manuel Bottini3689fcd2019-06-14 17:18:12 +0100121 case DataType::QSYMM16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100122 case DataType::F16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100123 return 2;
124 case DataType::F32:
125 case DataType::U32:
126 case DataType::S32:
127 return 4;
128 case DataType::F64:
129 case DataType::U64:
130 case DataType::S64:
131 return 8;
132 case DataType::SIZET:
133 return sizeof(size_t);
134 default:
135 ARM_COMPUTE_ERROR("Invalid data type");
136 return 0;
137 }
138}
139
140/** The size in bytes of the pixel format
141 *
142 * @param[in] format Input format
143 *
144 * @return The size in bytes of the pixel format
145 */
146inline size_t pixel_size_from_format(Format format)
147{
148 switch(format)
149 {
150 case Format::U8:
151 return 1;
152 case Format::U16:
153 case Format::S16:
154 case Format::F16:
155 case Format::UV88:
156 case Format::YUYV422:
157 case Format::UYVY422:
158 return 2;
159 case Format::RGB888:
160 return 3;
161 case Format::RGBA8888:
162 return 4;
163 case Format::U32:
164 case Format::S32:
165 case Format::F32:
166 return 4;
167 //Doesn't make sense for planar formats:
168 case Format::NV12:
169 case Format::NV21:
170 case Format::IYUV:
171 case Format::YUV444:
172 default:
173 ARM_COMPUTE_ERROR("Undefined pixel size for given format");
174 return 0;
175 }
176}
177
178/** The size in bytes of the data type
179 *
180 * @param[in] dt Input data type
181 *
182 * @return The size in bytes of the data type
183 */
184inline size_t element_size_from_data_type(DataType dt)
185{
186 switch(dt)
187 {
188 case DataType::S8:
189 case DataType::U8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100190 case DataType::QSYMM8:
Michel Iwaniec00633802017-10-12 14:14:15 +0100191 case DataType::QASYMM8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100192 case DataType::QSYMM8_PER_CHANNEL:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100193 return 1;
194 case DataType::U16:
195 case DataType::S16:
Manuel Bottini3689fcd2019-06-14 17:18:12 +0100196 case DataType::QSYMM16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100197 case DataType::F16:
198 return 2;
199 case DataType::U32:
200 case DataType::S32:
201 case DataType::F32:
202 return 4;
203 default:
204 ARM_COMPUTE_ERROR("Undefined element size for given data type");
205 return 0;
206 }
207}
208
209/** Return the data type used by a given single-planar pixel format
210 *
211 * @param[in] format Input format
212 *
213 * @return The size in bytes of the pixel format
214 */
215inline DataType data_type_from_format(Format format)
216{
217 switch(format)
218 {
219 case Format::U8:
220 case Format::UV88:
221 case Format::RGB888:
222 case Format::RGBA8888:
223 case Format::YUYV422:
224 case Format::UYVY422:
225 return DataType::U8;
226 case Format::U16:
227 return DataType::U16;
228 case Format::S16:
229 return DataType::S16;
230 case Format::U32:
231 return DataType::U32;
232 case Format::S32:
233 return DataType::S32;
234 case Format::F16:
235 return DataType::F16;
236 case Format::F32:
237 return DataType::F32;
238 //Doesn't make sense for planar formats:
239 case Format::NV12:
240 case Format::NV21:
241 case Format::IYUV:
242 case Format::YUV444:
243 default:
244 ARM_COMPUTE_ERROR("Not supported data_type for given format");
245 return DataType::UNKNOWN;
246 }
247}
248
249/** Return the plane index of a given channel given an input format.
250 *
251 * @param[in] format Input format
252 * @param[in] channel Input channel
253 *
254 * @return The plane index of the specific channel of the specific format
255 */
256inline int plane_idx_from_channel(Format format, Channel channel)
257{
258 switch(format)
259 {
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100260 // Single planar formats have a single plane
261 case Format::U8:
262 case Format::U16:
263 case Format::S16:
264 case Format::U32:
265 case Format::S32:
266 case Format::F16:
267 case Format::F32:
268 case Format::UV88:
269 case Format::RGB888:
270 case Format::RGBA8888:
271 case Format::YUYV422:
272 case Format::UYVY422:
273 return 0;
274 // Multi planar formats
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100275 case Format::NV12:
276 case Format::NV21:
277 {
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100278 // Channel U and V share the same plane of format UV88
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100279 switch(channel)
280 {
281 case Channel::Y:
282 return 0;
283 case Channel::U:
284 case Channel::V:
285 return 1;
286 default:
287 ARM_COMPUTE_ERROR("Not supported channel");
288 return 0;
289 }
290 }
291 case Format::IYUV:
292 case Format::YUV444:
293 {
294 switch(channel)
295 {
296 case Channel::Y:
297 return 0;
298 case Channel::U:
299 return 1;
300 case Channel::V:
301 return 2;
302 default:
303 ARM_COMPUTE_ERROR("Not supported channel");
304 return 0;
305 }
306 }
307 default:
308 ARM_COMPUTE_ERROR("Not supported format");
309 return 0;
310 }
311}
312
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100313/** Return the channel index of a given channel given an input format.
314 *
315 * @param[in] format Input format
316 * @param[in] channel Input channel
317 *
318 * @return The channel index of the specific channel of the specific format
319 */
320inline int channel_idx_from_format(Format format, Channel channel)
321{
322 switch(format)
323 {
324 case Format::RGB888:
325 {
326 switch(channel)
327 {
328 case Channel::R:
329 return 0;
330 case Channel::G:
331 return 1;
332 case Channel::B:
333 return 2;
334 default:
335 ARM_COMPUTE_ERROR("Not supported channel");
336 return 0;
337 }
338 }
339 case Format::RGBA8888:
340 {
341 switch(channel)
342 {
343 case Channel::R:
344 return 0;
345 case Channel::G:
346 return 1;
347 case Channel::B:
348 return 2;
349 case Channel::A:
350 return 3;
351 default:
352 ARM_COMPUTE_ERROR("Not supported channel");
353 return 0;
354 }
355 }
356 case Format::YUYV422:
357 {
358 switch(channel)
359 {
360 case Channel::Y:
361 return 0;
362 case Channel::U:
363 return 1;
364 case Channel::V:
365 return 3;
366 default:
367 ARM_COMPUTE_ERROR("Not supported channel");
368 return 0;
369 }
370 }
371 case Format::UYVY422:
372 {
373 switch(channel)
374 {
375 case Channel::Y:
376 return 1;
377 case Channel::U:
378 return 0;
379 case Channel::V:
380 return 2;
381 default:
382 ARM_COMPUTE_ERROR("Not supported channel");
383 return 0;
384 }
385 }
386 case Format::NV12:
387 {
388 switch(channel)
389 {
390 case Channel::Y:
391 return 0;
392 case Channel::U:
393 return 0;
394 case Channel::V:
395 return 1;
396 default:
397 ARM_COMPUTE_ERROR("Not supported channel");
398 return 0;
399 }
400 }
401 case Format::NV21:
402 {
403 switch(channel)
404 {
405 case Channel::Y:
406 return 0;
407 case Channel::U:
408 return 1;
409 case Channel::V:
410 return 0;
411 default:
412 ARM_COMPUTE_ERROR("Not supported channel");
413 return 0;
414 }
415 }
416 case Format::YUV444:
417 case Format::IYUV:
418 {
419 switch(channel)
420 {
421 case Channel::Y:
422 return 0;
423 case Channel::U:
424 return 0;
425 case Channel::V:
426 return 0;
427 default:
428 ARM_COMPUTE_ERROR("Not supported channel");
429 return 0;
430 }
431 }
432 default:
433 ARM_COMPUTE_ERROR("Not supported format");
434 return 0;
435 }
436}
437
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100438/** Return the number of planes for a given format
439 *
440 * @param[in] format Input format
441 *
442 * @return The number of planes for a given image format.
443 */
444inline size_t num_planes_from_format(Format format)
445{
446 switch(format)
447 {
448 case Format::U8:
449 case Format::S16:
450 case Format::U16:
451 case Format::S32:
452 case Format::U32:
453 case Format::F16:
454 case Format::F32:
455 case Format::RGB888:
456 case Format::RGBA8888:
457 case Format::YUYV422:
458 case Format::UYVY422:
459 return 1;
460 case Format::NV12:
461 case Format::NV21:
462 return 2;
463 case Format::IYUV:
464 case Format::YUV444:
465 return 3;
466 default:
467 ARM_COMPUTE_ERROR("Not supported format");
468 return 0;
469 }
470}
471
472/** Return the number of channels for a given single-planar pixel format
473 *
474 * @param[in] format Input format
475 *
476 * @return The number of channels for a given image format.
477 */
478inline size_t num_channels_from_format(Format format)
479{
480 switch(format)
481 {
482 case Format::U8:
483 case Format::U16:
484 case Format::S16:
485 case Format::U32:
486 case Format::S32:
487 case Format::F16:
488 case Format::F32:
489 return 1;
490 // Because the U and V channels are subsampled
491 // these formats appear like having only 2 channels:
492 case Format::YUYV422:
493 case Format::UYVY422:
494 return 2;
495 case Format::UV88:
496 return 2;
497 case Format::RGB888:
498 return 3;
499 case Format::RGBA8888:
500 return 4;
501 //Doesn't make sense for planar formats:
502 case Format::NV12:
503 case Format::NV21:
504 case Format::IYUV:
505 case Format::YUV444:
506 default:
507 return 0;
508 }
509}
510
Chunosovd621bca2017-11-03 17:33:15 +0700511/** Return the promoted data type of a given data type.
512 *
513 * @note If promoted data type is not supported an error will be thrown
514 *
515 * @param[in] dt Data type to get the promoted type of.
516 *
517 * @return Promoted data type
518 */
519inline DataType get_promoted_data_type(DataType dt)
520{
521 switch(dt)
522 {
523 case DataType::U8:
524 return DataType::U16;
525 case DataType::S8:
526 return DataType::S16;
Chunosovd621bca2017-11-03 17:33:15 +0700527 case DataType::U16:
528 return DataType::U32;
529 case DataType::S16:
530 return DataType::S32;
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100531 case DataType::QSYMM8:
Chunosovd621bca2017-11-03 17:33:15 +0700532 case DataType::QASYMM8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100533 case DataType::QSYMM8_PER_CHANNEL:
Manuel Bottini3689fcd2019-06-14 17:18:12 +0100534 case DataType::QSYMM16:
Chunosovd621bca2017-11-03 17:33:15 +0700535 case DataType::F16:
536 case DataType::U32:
537 case DataType::S32:
538 case DataType::F32:
Chunosovd621bca2017-11-03 17:33:15 +0700539 ARM_COMPUTE_ERROR("Unsupported data type promotions!");
540 default:
541 ARM_COMPUTE_ERROR("Undefined data type!");
542 }
543 return DataType::UNKNOWN;
544}
545
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100546/** Return true if the given format has horizontal subsampling.
547 *
548 * @param[in] format Format to determine subsampling.
549 *
550 * @return True if the format can be subsampled horizontaly.
551 */
552inline bool has_format_horizontal_subsampling(Format format)
553{
554 return (format == Format::YUYV422 || format == Format::UYVY422 || format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
555}
556
557/** Return true if the given format has vertical subsampling.
558 *
559 * @param[in] format Format to determine subsampling.
560 *
561 * @return True if the format can be subsampled verticaly.
562 */
563inline bool has_format_vertical_subsampling(Format format)
564{
565 return (format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
566}
567
/** Separate a 2D convolution into two 1D convolutions.
 *
 * @p conv is read as a row-major size x size matrix (element (row j, column i) is conv[i + j * size]).
 * The non-zero first-row coefficient of smallest magnitude is the pivot. Its column becomes @p conv_col,
 * and every column must be an exact integer multiple of it; the multipliers form @p conv_row.
 *
 * @note On failure @p conv_col (and possibly part of @p conv_row) may already have been written.
 *
 * @param[in]  conv     2D convolution
 * @param[out] conv_col 1D vertical convolution
 * @param[out] conv_row 1D horizontal convolution
 * @param[in]  size     Number of elements per side of the 2D convolution
 *
 * @return true if the separation was successful, i.e. conv == conv_col x conv_row exactly
 */
inline bool separate_matrix(const int16_t *conv, int16_t *conv_col, int16_t *conv_row, uint8_t size)
{
    int32_t min_col     = -1;
    int16_t min_col_val = -1;

    // Pick the smallest-magnitude non-zero coefficient of the first row as pivot
    for(int32_t i = 0; i < size; ++i)
    {
        if(conv[i] != 0 && (min_col < 0 || std::abs(min_col_val) > std::abs(conv[i])))
        {
            min_col     = i;
            min_col_val = conv[i];
        }
    }

    // An all-zero first row gives no pivot to separate with
    if(min_col < 0)
    {
        return false;
    }

    // The vertical kernel is the pivot column
    for(uint32_t j = 0; j < size; ++j)
    {
        conv_col[j] = conv[min_col + j * size];
    }

    for(uint32_t i = 0; i < size; i++)
    {
        if(static_cast<int>(i) == min_col)
        {
            conv_row[i] = 1;
        }
        else
        {
            const int16_t coeff = conv[i] / conv[min_col];

            // Verify the whole column, including row 0: the integer division above truncates,
            // so conv[i] itself is not guaranteed to equal coeff * conv[min_col].
            for(uint32_t j = 0; j < size; ++j)
            {
                if(conv[i + j * size] != (conv_col[j] * coeff))
                {
                    return false;
                }
            }

            conv_row[i] = coeff;
        }
    }

    return true;
}
625
/** Compute the normalisation scale of a square matrix.
 *
 * The scale is |sum of all coefficients|, except that a zero sum yields 1 so the scale can always be divided by.
 *
 * @param[in] matrix      Matrix coefficients, matrix_size * matrix_size of them.
 * @param[in] matrix_size Number of elements per side of the square matrix.
 *
 * @return The absolute value of the coefficient sum if it is non-zero, otherwise 1.
 */
inline uint32_t calculate_matrix_scale(const int16_t *matrix, unsigned int matrix_size)
{
    const size_t num_coeffs = matrix_size * matrix_size;

    int sum = 0;
    for(size_t idx = 0; idx < num_coeffs; ++idx)
    {
        sum += matrix[idx];
    }

    const int magnitude = std::abs(sum);
    return static_cast<uint32_t>(magnitude > 0 ? magnitude : 1);
}
643
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100644/** Adjust tensor shape size if width or height are odd for a given multi-planar format. No modification is done for other formats.
645 *
646 * @note Adding here a few links discussing the issue of odd size and sharing the same solution:
Manuel Bottini581c8982019-02-07 10:31:57 +0000647 * <a href="https://android.googlesource.com/platform/frameworks/base/+/refs/heads/master/graphics/java/android/graphics/YuvImage.java">Android Source</a>
648 * <a href="https://groups.google.com/a/webmproject.org/forum/#!topic/webm-discuss/LaCKpqiDTXM">WebM</a>
649 * <a href="https://bugs.chromium.org/p/libyuv/issues/detail?id=198&amp;can=1&amp;q=odd%20width">libYUV</a>
650 * <a href="https://sourceforge.net/p/raw-yuvplayer/bugs/1/">YUVPlayer</a> *
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100651 *
652 * @param[in, out] shape Tensor shape of 2D size
653 * @param[in] format Format of the tensor
654 *
Alex Gildayc357c472018-03-21 13:54:09 +0000655 * @return The adjusted tensor shape.
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100656 */
657inline TensorShape adjust_odd_shape(const TensorShape &shape, Format format)
658{
659 TensorShape output{ shape };
660
661 // Force width to be even for formats which require subsampling of the U and V channels
662 if(has_format_horizontal_subsampling(format))
663 {
664 output.set(0, output.x() & ~1U);
665 }
666
667 // Force height to be even for formats which require subsampling of the U and V channels
668 if(has_format_vertical_subsampling(format))
669 {
670 output.set(1, output.y() & ~1U);
671 }
672
673 return output;
674}
675
676/** Calculate subsampled shape for a given format and channel
677 *
678 * @param[in] shape Shape of the tensor to calculate the extracted channel.
679 * @param[in] format Format of the tensor.
680 * @param[in] channel Channel to create tensor shape to be extracted.
681 *
682 * @return The subsampled tensor shape.
683 */
684inline TensorShape calculate_subsampled_shape(const TensorShape &shape, Format format, Channel channel = Channel::UNKNOWN)
685{
686 TensorShape output{ shape };
687
688 // Subsample shape only for U or V channel
689 if(Channel::U == channel || Channel::V == channel || Channel::UNKNOWN == channel)
690 {
691 // Subsample width for the tensor shape when channel is U or V
692 if(has_format_horizontal_subsampling(format))
693 {
694 output.set(0, output.x() / 2U);
695 }
696
697 // Subsample height for the tensor shape when channel is U or V
698 if(has_format_vertical_subsampling(format))
699 {
700 output.set(1, output.y() / 2U);
701 }
702 }
703
704 return output;
705}
706
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100707/** Calculate accurary required by the horizontal and vertical convolution computations
708 *
709 * @param[in] conv_col Pointer to the vertical vector of the separated convolution filter
710 * @param[in] conv_row Pointer to the horizontal vector of the convolution filter
711 * @param[in] size Number of elements per vector of the separated matrix
712 *
713 * @return The return type is a pair. The first element of the pair is the biggest data type needed for the first stage. The second
714 * element of the pair is the biggest data type needed for the second stage.
715 */
716inline std::pair<DataType, DataType> data_type_for_convolution(const int16_t *conv_col, const int16_t *conv_row, size_t size)
717{
718 DataType first_stage = DataType::UNKNOWN;
719 DataType second_stage = DataType::UNKNOWN;
720
721 auto gez = [](const int16_t &v)
722 {
723 return v >= 0;
724 };
725
726 auto accu_neg = [](const int &first, const int &second)
727 {
728 return first + (second < 0 ? second : 0);
729 };
730
731 auto accu_pos = [](const int &first, const int &second)
732 {
733 return first + (second > 0 ? second : 0);
734 };
735
736 const bool only_positive_coefficients = std::all_of(conv_row, conv_row + size, gez) && std::all_of(conv_col, conv_col + size, gez);
737
738 if(only_positive_coefficients)
739 {
740 const int max_row_value = std::accumulate(conv_row, conv_row + size, 0) * UINT8_MAX;
741 const int max_value = std::accumulate(conv_col, conv_col + size, 0) * max_row_value;
742
743 first_stage = (max_row_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
744
745 second_stage = (max_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
746 }
747 else
748 {
749 const int min_row_value = std::accumulate(conv_row, conv_row + size, 0, accu_neg) * UINT8_MAX;
750 const int max_row_value = std::accumulate(conv_row, conv_row + size, 0, accu_pos) * UINT8_MAX;
751 const int neg_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_neg);
752 const int pos_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_pos);
753 const int min_value = neg_coeffs_sum * max_row_value + pos_coeffs_sum * min_row_value;
754 const int max_value = neg_coeffs_sum * min_row_value + pos_coeffs_sum * max_row_value;
755
756 first_stage = ((INT16_MIN <= min_row_value) && (max_row_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
757
758 second_stage = ((INT16_MIN <= min_value) && (max_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
759 }
760
761 return std::make_pair(first_stage, second_stage);
762}
763
764/** Calculate the accuracy required by the squared convolution calculation.
765 *
766 *
767 * @param[in] conv Pointer to the squared convolution matrix
768 * @param[in] size The total size of the convolution matrix
769 *
770 * @return The return is the biggest data type needed to do the convolution
771 */
772inline DataType data_type_for_convolution_matrix(const int16_t *conv, size_t size)
773{
774 auto gez = [](const int16_t v)
775 {
776 return v >= 0;
777 };
778
779 const bool only_positive_coefficients = std::all_of(conv, conv + size, gez);
780
781 if(only_positive_coefficients)
782 {
783 const int max_conv_value = std::accumulate(conv, conv + size, 0) * UINT8_MAX;
784 if(max_conv_value <= UINT16_MAX)
785 {
786 return DataType::U16;
787 }
788 else
789 {
790 return DataType::S32;
791 }
792 }
793 else
794 {
795 const int min_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
796 {
797 return b < 0 ? a + b : a;
798 })
799 * UINT8_MAX;
800
801 const int max_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
802 {
803 return b > 0 ? a + b : a;
804 })
805 * UINT8_MAX;
806
807 if((INT16_MIN <= min_value) && (INT16_MAX >= max_value))
808 {
809 return DataType::S16;
810 }
811 else
812 {
813 return DataType::S32;
814 }
815 }
816}
817
Pablo Tello35767bc2018-12-05 17:36:30 +0000818/** Permutes the given dimensions according the permutation vector
819 *
820 * @param[in,out] dimensions Dimensions to be permuted.
821 * @param[in] perm Vector describing the permutation.
822 *
823 */
824template <typename T>
825inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
826{
827 const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
828 for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
829 {
830 T dimension_val = old_dim[i];
831 dimensions.set(perm[i], dimension_val);
832 }
833}
834
Georgios Pinitas4074c992018-01-30 18:13:46 +0000835/** Calculate padding requirements in case of SAME padding
836 *
837 * @param[in] input_shape Input shape
838 * @param[in] weights_shape Weights shape
839 * @param[in] conv_info Convolution information (containing strides)
Isabella Gottardi6a914402019-01-30 15:45:42 +0000840 * @param[in] data_layout (Optional) Data layout of the input and weights tensor
Pablo Tello01bbacb2019-04-30 10:32:42 +0100841 * @param[in] dilation (Optional) Dilation factor used in the convolution.
Georgios Pinitas4074c992018-01-30 18:13:46 +0000842 *
843 * @return PadStrideInfo for SAME padding
844 */
Pablo Tello01bbacb2019-04-30 10:32:42 +0100845PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout = DataLayout::NCHW, const Size2D &dilation = Size2D(1u, 1u));
Georgios Pinitas4074c992018-01-30 18:13:46 +0000846
Pablo Tellof5f34bb2017-08-22 13:34:13 +0100847/** Returns expected width and height of the deconvolution's output tensor.
848 *
Michalis Spyrouafbc5ff2018-10-03 14:18:19 +0100849 * @param[in] in_width Width of input tensor (Number of columns)
850 * @param[in] in_height Height of input tensor (Number of rows)
851 * @param[in] kernel_width Kernel width.
852 * @param[in] kernel_height Kernel height.
853 * @param[in] padx X axis padding.
854 * @param[in] pady Y axis padding.
855 * @param[in] stride_x X axis input stride.
856 * @param[in] stride_y Y axis input stride.
Pablo Tellof5f34bb2017-08-22 13:34:13 +0100857 *
858 * @return A pair with the new width in the first position and the new height in the second.
859 */
Pablo Tello01bbacb2019-04-30 10:32:42 +0100860std::pair<unsigned int, unsigned int> deconvolution_output_dimensions(unsigned int in_width, unsigned int in_height,
861 unsigned int kernel_width, unsigned int kernel_height,
862 unsigned int padx, unsigned int pady,
863 unsigned int stride_x, unsigned int stride_y);
Pablo Tellof5f34bb2017-08-22 13:34:13 +0100864
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100865/** Returns expected width and height of output scaled tensor depending on dimensions rounding mode.
866 *
Gian Marco Iodice4e288692017-06-27 11:41:59 +0100867 * @param[in] width Width of input tensor (Number of columns)
868 * @param[in] height Height of input tensor (Number of rows)
869 * @param[in] kernel_width Kernel width.
870 * @param[in] kernel_height Kernel height.
871 * @param[in] pad_stride_info Pad and stride information.
Alex Gilday7da29b62018-03-23 14:16:00 +0000872 * @param[in] dilation (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100873 *
874 * @return A pair with the new width in the first position and the new height in the second.
875 */
Pablo Tello01bbacb2019-04-30 10:32:42 +0100876std::pair<unsigned int, unsigned int> scaled_dimensions(unsigned int width, unsigned int height,
877 unsigned int kernel_width, unsigned int kernel_height,
878 const PadStrideInfo &pad_stride_info,
879 const Size2D &dilation = Size2D(1U, 1U));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100880
881/** Convert a tensor format into a string.
882 *
883 * @param[in] format @ref Format to be translated to string.
884 *
885 * @return The string describing the format.
886 */
887const std::string &string_from_format(Format format);
888
889/** Convert a channel identity into a string.
890 *
891 * @param[in] channel @ref Channel to be translated to string.
892 *
893 * @return The string describing the channel.
894 */
895const std::string &string_from_channel(Channel channel);
Michele Di Giorgiobf3c6622018-03-08 11:52:27 +0000896/** Convert a data layout identity into a string.
897 *
898 * @param[in] dl @ref DataLayout to be translated to string.
899 *
900 * @return The string describing the data layout.
901 */
902const std::string &string_from_data_layout(DataLayout dl);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100903/** Convert a data type identity into a string.
904 *
905 * @param[in] dt @ref DataType to be translated to string.
906 *
907 * @return The string describing the data type.
908 */
909const std::string &string_from_data_type(DataType dt);
910/** Convert a matrix pattern into a string.
911 *
912 * @param[in] pattern @ref MatrixPattern to be translated to string.
913 *
914 * @return The string describing the matrix pattern.
915 */
916const std::string &string_from_matrix_pattern(MatrixPattern pattern);
917/** Translates a given activation function to a string.
918 *
919 * @param[in] act @ref ActivationLayerInfo::ActivationFunction to be translated to string.
920 *
921 * @return The string describing the activation function.
922 */
923const std::string &string_from_activation_func(ActivationLayerInfo::ActivationFunction act);
924/** Translates a given non linear function to a string.
925 *
926 * @param[in] function @ref NonLinearFilterFunction to be translated to string.
927 *
928 * @return The string describing the non linear function.
929 */
930const std::string &string_from_non_linear_filter_function(NonLinearFilterFunction function);
931/** Translates a given interpolation policy to a string.
932 *
933 * @param[in] policy @ref InterpolationPolicy to be translated to string.
934 *
935 * @return The string describing the interpolation policy.
936 */
937const std::string &string_from_interpolation_policy(InterpolationPolicy policy);
938/** Translates a given border mode policy to a string.
939 *
940 * @param[in] border_mode @ref BorderMode to be translated to string.
941 *
942 * @return The string describing the border mode.
943 */
944const std::string &string_from_border_mode(BorderMode border_mode);
945/** Translates a given normalization type to a string.
946 *
947 * @param[in] type @ref NormType to be translated to string.
948 *
949 * @return The string describing the normalization type.
950 */
951const std::string &string_from_norm_type(NormType type);
Georgios Pinitascdf51452017-08-31 14:21:36 +0100952/** Translates a given pooling type to a string.
953 *
954 * @param[in] type @ref PoolingType to be translated to string.
955 *
956 * @return The string describing the pooling type.
957 */
958const std::string &string_from_pooling_type(PoolingType type);
Gian Marco Iodice4b908652018-10-18 10:21:02 +0100959/** Translates a given GEMMLowp output stage to a string.
960 *
961 * @param[in] output_stage @ref GEMMLowpOutputStageInfo to be translated to string.
962 *
963 * @return The string describing the GEMMLowp output stage
964 */
965const std::string &string_from_gemmlowp_output_stage(GEMMLowpOutputStageType output_stage);
Giuseppe Rossinid7647d42018-07-17 18:13:13 +0100966/** Convert a PixelValue to a string, represented through the specific data type
967 *
968 * @param[in] value The PixelValue to convert
969 * @param[in] data_type The type to be used to convert the @p value
970 *
971 * @return String representation of the PixelValue through the given data type.
972 */
973std::string string_from_pixel_value(const PixelValue &value, const DataType data_type);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100974/** Lower a given string.
975 *
976 * @param[in] val Given string to lower.
977 *
978 * @return The lowered string
979 */
980std::string lower_string(const std::string &val);
981
982/** Check if a given data type is of floating point type
983 *
984 * @param[in] dt Input data type.
985 *
986 * @return True if data type is of floating point type, else false.
987 */
988inline bool is_data_type_float(DataType dt)
989{
990 switch(dt)
991 {
992 case DataType::F16:
993 case DataType::F32:
994 return true;
995 default:
996 return false;
997 }
998}
999
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001000/** Check if a given data type is of quantized type
1001 *
1002 * @note Quantized is considered a super-set of fixed-point and asymmetric data types.
1003 *
1004 * @param[in] dt Input data type.
1005 *
1006 * @return True if data type is of quantized type, else false.
1007 */
1008inline bool is_data_type_quantized(DataType dt)
1009{
1010 switch(dt)
1011 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001012 case DataType::QSYMM8:
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001013 case DataType::QASYMM8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001014 case DataType::QSYMM8_PER_CHANNEL:
Manuel Bottini3689fcd2019-06-14 17:18:12 +01001015 case DataType::QSYMM16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001016 return true;
1017 default:
1018 return false;
1019 }
1020}
1021
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001022/** Check if a given data type is of asymmetric quantized type
1023 *
1024 * @param[in] dt Input data type.
1025 *
Michele Di Giorgio6997fc92019-06-18 10:23:22 +01001026 * @return True if data type is of asymmetric quantized type, else false.
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001027 */
Anton Lokhmotovaf6204c2017-11-08 09:34:19 +00001028inline bool is_data_type_quantized_asymmetric(DataType dt)
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001029{
1030 switch(dt)
1031 {
1032 case DataType::QASYMM8:
1033 return true;
1034 default:
1035 return false;
1036 }
1037}
1038
Michele Di Giorgio6997fc92019-06-18 10:23:22 +01001039/** Check if a given data type is of symmetric quantized type
1040 *
1041 * @param[in] dt Input data type.
1042 *
1043 * @return True if data type is of symmetric quantized type, else false.
1044 */
1045inline bool is_data_type_quantized_symmetric(DataType dt)
1046{
1047 switch(dt)
1048 {
1049 case DataType::QSYMM8:
1050 case DataType::QSYMM8_PER_CHANNEL:
1051 case DataType::QSYMM16:
1052 return true;
1053 default:
1054 return false;
1055 }
1056}
1057
/** Create a string with the float in full precision.
 *
 * The value is printed with std::numeric_limits<float>::max_digits10 significant digits,
 * which is enough to round-trip any float. An "f" suffix is appended unless the value is
 * exactly representable as an int (e.g. 3.0f -> "3", 0.5f -> "0.5f", 1e10f -> "1e+10f").
 *
 * @param val Floating point value
 *
 * @return String with the floating point value.
 */
inline std::string float_to_string_with_full_precision(float val)
{
    std::stringstream ss;
    ss.precision(std::numeric_limits<float>::max_digits10);
    ss << val;

    // The range test must come before the conversion: converting a float that lies outside
    // the range of int (or a NaN) to int is undefined behaviour. Both bounds are powers of
    // two and therefore exact in float; the upper one is exclusive.
    const float int_lower_bound = static_cast<float>(std::numeric_limits<int>::lowest());
    const float int_upper_bound = -int_lower_bound;
    const bool  is_int_value    = (val >= int_lower_bound) && (val < int_upper_bound)
                               && (static_cast<float>(static_cast<int>(val)) == val);

    if(!is_int_value)
    {
        ss << "f";
    }

    return ss.str();
}
1077
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001078/** Returns the number of elements required to go from start to end with the wanted step
1079 *
1080 * @param[in] start start value
1081 * @param[in] end end value
1082 * @param[in] step step value between each number in the wanted sequence
1083 *
1084 * @return number of elements to go from start value to end value using the wanted step
1085 */
1086inline size_t num_of_elements_in_range(const float start, const float end, const float step)
1087{
1088 ARM_COMPUTE_ERROR_ON_MSG(step == 0, "Range Step cannot be 0");
1089 return size_t(std::ceil((end - start) / step));
1090}
1091
1092/** Returns true if the value can be represented by the given data type
1093 *
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001094 * @param[in] val value to be checked
1095 * @param[in] dt data type that is checked
1096 * @param[in] qinfo (Optional) quantization info if the data type is QASYMM8
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001097 *
1098 * @return true if the data type can hold the value.
1099 */
1100template <typename T>
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001101bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = QuantizationInfo())
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001102{
1103 switch(dt)
1104 {
1105 case DataType::U8:
1106 return ((static_cast<uint8_t>(val) == val) && val >= std::numeric_limits<uint8_t>::lowest() && val <= std::numeric_limits<uint8_t>::max());
1107 case DataType::QASYMM8:
1108 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001109 double min = static_cast<double>(dequantize_qasymm8(0, qinfo));
1110 double max = static_cast<double>(dequantize_qasymm8(std::numeric_limits<uint8_t>::max(), qinfo));
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001111 return ((double)val >= min && (double)val <= max);
1112 }
1113 case DataType::S8:
1114 return ((static_cast<int8_t>(val) == val) && val >= std::numeric_limits<int8_t>::lowest() && val <= std::numeric_limits<int8_t>::max());
1115 case DataType::U16:
1116 return ((static_cast<uint16_t>(val) == val) && val >= std::numeric_limits<uint16_t>::lowest() && val <= std::numeric_limits<uint16_t>::max());
1117 case DataType::S16:
1118 return ((static_cast<int16_t>(val) == val) && val >= std::numeric_limits<int16_t>::lowest() && val <= std::numeric_limits<int16_t>::max());
1119 case DataType::U32:
1120 return ((static_cast<uint32_t>(val) == val) && val >= std::numeric_limits<uint32_t>::lowest() && val <= std::numeric_limits<uint32_t>::max());
1121 case DataType::S32:
1122 return ((static_cast<int32_t>(val) == val) && val >= std::numeric_limits<int32_t>::lowest() && val <= std::numeric_limits<int32_t>::max());
1123 case DataType::U64:
1124 return (val >= std::numeric_limits<uint64_t>::lowest() && val <= std::numeric_limits<uint64_t>::max());
1125 case DataType::S64:
1126 return (val >= std::numeric_limits<int64_t>::lowest() && val <= std::numeric_limits<int64_t>::max());
1127 case DataType::F16:
1128 return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
1129 case DataType::F32:
1130 return (val >= std::numeric_limits<float>::lowest() && val <= std::numeric_limits<float>::max());
1131 case DataType::F64:
1132 return (val >= std::numeric_limits<double>::lowest() && val <= std::numeric_limits<double>::max());
1133 case DataType::SIZET:
1134 return ((static_cast<size_t>(val) == val) && val >= std::numeric_limits<size_t>::lowest() && val <= std::numeric_limits<size_t>::max());
1135 default:
1136 ARM_COMPUTE_ERROR("Data type not supported");
1137 return false;
1138 }
1139}
1140
giuros01edc21e42018-11-16 14:45:31 +00001141#ifdef ARM_COMPUTE_ASSERTS_ENABLED
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001142/** Print consecutive elements to an output stream.
1143 *
1144 * @param[out] s Output stream to print the elements to.
1145 * @param[in] ptr Pointer to print the elements from.
1146 * @param[in] n Number of elements to print.
1147 * @param[in] stream_width (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1148 * @param[in] element_delim (Optional) Delimeter among the consecutive elements. Defaults to space delimeter
1149 */
1150template <typename T>
1151void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int n, int stream_width = 0, const std::string &element_delim = " ")
1152{
1153 using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
Michalis Spyrou53860dd2019-07-01 14:20:56 +01001154 std::ios stream_status(nullptr);
1155 stream_status.copyfmt(s);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001156
1157 for(unsigned int i = 0; i < n; ++i)
1158 {
1159 // Set stream width as it is not a "sticky" stream manipulator
1160 if(stream_width != 0)
1161 {
1162 s.width(stream_width);
1163 }
Anthony Barbier7068f992017-10-26 15:23:08 +01001164
1165 if(std::is_same<typename std::decay<T>::type, half>::value)
1166 {
1167 // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
1168 s << std::right << static_cast<T>(ptr[i]) << element_delim;
1169 }
1170 else
1171 {
1172 s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
1173 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001174 }
Michalis Spyrou53860dd2019-07-01 14:20:56 +01001175
1176 // Restore output stream flags
1177 s.copyfmt(stream_status);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001178}
1179
1180/** Identify the maximum width of n consecutive elements.
1181 *
1182 * @param[in] s The output stream which will be used to print the elements. Used to extract the stream format.
1183 * @param[in] ptr Pointer to the elements.
1184 * @param[in] n Number of elements.
1185 *
1186 * @return The maximum width of the elements.
1187 */
1188template <typename T>
1189int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, unsigned int n)
1190{
1191 using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1192
1193 int max_width = -1;
1194 for(unsigned int i = 0; i < n; ++i)
1195 {
1196 std::stringstream ss;
1197 ss.copyfmt(s);
Anthony Barbier7068f992017-10-26 15:23:08 +01001198
1199 if(std::is_same<typename std::decay<T>::type, half>::value)
1200 {
1201 // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
1202 ss << static_cast<T>(ptr[i]);
1203 }
1204 else
1205 {
1206 ss << static_cast<print_type>(ptr[i]);
1207 }
1208
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001209 max_width = std::max<int>(max_width, ss.str().size());
1210 }
1211 return max_width;
1212}
1213
1214/** Print consecutive elements to an output stream.
1215 *
1216 * @param[out] s Output stream to print the elements to.
1217 * @param[in] dt Data type of the elements
1218 * @param[in] ptr Pointer to print the elements from.
1219 * @param[in] n Number of elements to print.
1220 * @param[in] stream_width (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1221 * @param[in] element_delim (Optional) Delimeter among the consecutive elements. Defaults to space delimeter
1222 */
1223void print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim = " ");
1224
1225/** Identify the maximum width of n consecutive elements.
1226 *
1227 * @param[in] s Output stream to print the elements to.
1228 * @param[in] dt Data type of the elements
1229 * @param[in] ptr Pointer to print the elements from.
1230 * @param[in] n Number of elements to print.
1231 *
1232 * @return The maximum width of the elements.
1233 */
1234int max_consecutive_elements_display_width(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n);
giuros01edc21e42018-11-16 14:45:31 +00001235#endif /* ARM_COMPUTE_ASSERTS_ENABLED */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001236}
1237#endif /*__ARM_COMPUTE_UTILS_H__ */