blob: 2640264f55a8a3c2ef23e2f954452daa1deb719a [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyrouf63885b2019-01-16 14:18:09 +00002 * Copyright (c) 2016-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_UTILS_H__
25#define __ARM_COMPUTE_UTILS_H__
26
27#include "arm_compute/core/Error.h"
Giuseppe Rossinid7647d42018-07-17 18:13:13 +010028#include "arm_compute/core/PixelValue.h"
Michel Iwaniec5dfeae62017-11-29 10:48:23 +000029#include "arm_compute/core/Rounding.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Types.h"
31
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010041
42namespace arm_compute
43{
/** Divide @p val by @p m and round the quotient up.
 *
 * The result is computed as (val + m - 1) / m; for integer operands this is the
 * ceiling of val / m when @p val is non-negative and @p m is strictly positive.
 *
 * @param[in] val Dividend, i.e. the value to divide and round up.
 * @param[in] m   Divisor.
 *
 * @return The rounded-up quotient.
 */
template <typename S, typename T>
constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
{
    return (val + m - 1) / m;
}
56
/** Round @p value up to the nearest multiple of @p divisor.
 *
 * @param[in] value   Lower bound of the result. Checked to be non-negative.
 * @param[in] divisor Value the result has to be a multiple of. Checked to be strictly positive.
 *
 * @return The smallest multiple of @p divisor that is greater than or equal to @p value.
 */
template <typename S, typename T>
inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);

    // Number of whole divisors needed to cover value
    const auto multiples_needed = DIV_CEIL(value, divisor);
    return multiples_needed * divisor;
}
70
/** Round @p value down to the nearest multiple of @p divisor.
 *
 * @param[in] value   Upper bound of the result. Checked to be non-negative.
 * @param[in] divisor Value the result has to be a multiple of. Checked to be strictly positive.
 *
 * @return The largest multiple of @p divisor that is smaller than or equal to @p value.
 */
template <typename S, typename T>
inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);

    // Number of whole divisors that fit into value
    const auto whole_multiples = value / divisor;
    return whole_multiples * divisor;
}
84
Anthony Barbier6ff3b192017-09-04 18:44:23 +010085/** Returns the arm_compute library build information
86 *
87 * Contains the version number and the build options used to build the library
88 *
89 * @return The arm_compute library build information
90 */
91std::string build_information();
92
93/** Load an entire file in memory
94 *
95 * @param[in] filename Name of the file to read.
96 * @param[in] binary Is it a binary file ?
97 *
98 * @return The content of the file.
99 */
100std::string read_file(const std::string &filename, bool binary);
101
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100102/** The size in bytes of the data type
103 *
104 * @param[in] data_type Input data type
105 *
106 * @return The size in bytes of the data type
107 */
108inline size_t data_size_from_type(DataType data_type)
109{
110 switch(data_type)
111 {
112 case DataType::U8:
113 case DataType::S8:
Michel Iwaniec00633802017-10-12 14:14:15 +0100114 case DataType::QASYMM8:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100115 return 1;
116 case DataType::U16:
117 case DataType::S16:
118 case DataType::F16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100119 return 2;
120 case DataType::F32:
121 case DataType::U32:
122 case DataType::S32:
123 return 4;
124 case DataType::F64:
125 case DataType::U64:
126 case DataType::S64:
127 return 8;
128 case DataType::SIZET:
129 return sizeof(size_t);
130 default:
131 ARM_COMPUTE_ERROR("Invalid data type");
132 return 0;
133 }
134}
135
136/** The size in bytes of the pixel format
137 *
138 * @param[in] format Input format
139 *
140 * @return The size in bytes of the pixel format
141 */
142inline size_t pixel_size_from_format(Format format)
143{
144 switch(format)
145 {
146 case Format::U8:
147 return 1;
148 case Format::U16:
149 case Format::S16:
150 case Format::F16:
151 case Format::UV88:
152 case Format::YUYV422:
153 case Format::UYVY422:
154 return 2;
155 case Format::RGB888:
156 return 3;
157 case Format::RGBA8888:
158 return 4;
159 case Format::U32:
160 case Format::S32:
161 case Format::F32:
162 return 4;
163 //Doesn't make sense for planar formats:
164 case Format::NV12:
165 case Format::NV21:
166 case Format::IYUV:
167 case Format::YUV444:
168 default:
169 ARM_COMPUTE_ERROR("Undefined pixel size for given format");
170 return 0;
171 }
172}
173
174/** The size in bytes of the data type
175 *
176 * @param[in] dt Input data type
177 *
178 * @return The size in bytes of the data type
179 */
180inline size_t element_size_from_data_type(DataType dt)
181{
182 switch(dt)
183 {
184 case DataType::S8:
185 case DataType::U8:
Michel Iwaniec00633802017-10-12 14:14:15 +0100186 case DataType::QASYMM8:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100187 return 1;
188 case DataType::U16:
189 case DataType::S16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100190 case DataType::F16:
191 return 2;
192 case DataType::U32:
193 case DataType::S32:
194 case DataType::F32:
195 return 4;
196 default:
197 ARM_COMPUTE_ERROR("Undefined element size for given data type");
198 return 0;
199 }
200}
201
202/** Return the data type used by a given single-planar pixel format
203 *
204 * @param[in] format Input format
205 *
206 * @return The size in bytes of the pixel format
207 */
208inline DataType data_type_from_format(Format format)
209{
210 switch(format)
211 {
212 case Format::U8:
213 case Format::UV88:
214 case Format::RGB888:
215 case Format::RGBA8888:
216 case Format::YUYV422:
217 case Format::UYVY422:
218 return DataType::U8;
219 case Format::U16:
220 return DataType::U16;
221 case Format::S16:
222 return DataType::S16;
223 case Format::U32:
224 return DataType::U32;
225 case Format::S32:
226 return DataType::S32;
227 case Format::F16:
228 return DataType::F16;
229 case Format::F32:
230 return DataType::F32;
231 //Doesn't make sense for planar formats:
232 case Format::NV12:
233 case Format::NV21:
234 case Format::IYUV:
235 case Format::YUV444:
236 default:
237 ARM_COMPUTE_ERROR("Not supported data_type for given format");
238 return DataType::UNKNOWN;
239 }
240}
241
242/** Return the plane index of a given channel given an input format.
243 *
244 * @param[in] format Input format
245 * @param[in] channel Input channel
246 *
247 * @return The plane index of the specific channel of the specific format
248 */
249inline int plane_idx_from_channel(Format format, Channel channel)
250{
251 switch(format)
252 {
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100253 // Single planar formats have a single plane
254 case Format::U8:
255 case Format::U16:
256 case Format::S16:
257 case Format::U32:
258 case Format::S32:
259 case Format::F16:
260 case Format::F32:
261 case Format::UV88:
262 case Format::RGB888:
263 case Format::RGBA8888:
264 case Format::YUYV422:
265 case Format::UYVY422:
266 return 0;
267 // Multi planar formats
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100268 case Format::NV12:
269 case Format::NV21:
270 {
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100271 // Channel U and V share the same plane of format UV88
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100272 switch(channel)
273 {
274 case Channel::Y:
275 return 0;
276 case Channel::U:
277 case Channel::V:
278 return 1;
279 default:
280 ARM_COMPUTE_ERROR("Not supported channel");
281 return 0;
282 }
283 }
284 case Format::IYUV:
285 case Format::YUV444:
286 {
287 switch(channel)
288 {
289 case Channel::Y:
290 return 0;
291 case Channel::U:
292 return 1;
293 case Channel::V:
294 return 2;
295 default:
296 ARM_COMPUTE_ERROR("Not supported channel");
297 return 0;
298 }
299 }
300 default:
301 ARM_COMPUTE_ERROR("Not supported format");
302 return 0;
303 }
304}
305
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100306/** Return the channel index of a given channel given an input format.
307 *
308 * @param[in] format Input format
309 * @param[in] channel Input channel
310 *
311 * @return The channel index of the specific channel of the specific format
312 */
313inline int channel_idx_from_format(Format format, Channel channel)
314{
315 switch(format)
316 {
317 case Format::RGB888:
318 {
319 switch(channel)
320 {
321 case Channel::R:
322 return 0;
323 case Channel::G:
324 return 1;
325 case Channel::B:
326 return 2;
327 default:
328 ARM_COMPUTE_ERROR("Not supported channel");
329 return 0;
330 }
331 }
332 case Format::RGBA8888:
333 {
334 switch(channel)
335 {
336 case Channel::R:
337 return 0;
338 case Channel::G:
339 return 1;
340 case Channel::B:
341 return 2;
342 case Channel::A:
343 return 3;
344 default:
345 ARM_COMPUTE_ERROR("Not supported channel");
346 return 0;
347 }
348 }
349 case Format::YUYV422:
350 {
351 switch(channel)
352 {
353 case Channel::Y:
354 return 0;
355 case Channel::U:
356 return 1;
357 case Channel::V:
358 return 3;
359 default:
360 ARM_COMPUTE_ERROR("Not supported channel");
361 return 0;
362 }
363 }
364 case Format::UYVY422:
365 {
366 switch(channel)
367 {
368 case Channel::Y:
369 return 1;
370 case Channel::U:
371 return 0;
372 case Channel::V:
373 return 2;
374 default:
375 ARM_COMPUTE_ERROR("Not supported channel");
376 return 0;
377 }
378 }
379 case Format::NV12:
380 {
381 switch(channel)
382 {
383 case Channel::Y:
384 return 0;
385 case Channel::U:
386 return 0;
387 case Channel::V:
388 return 1;
389 default:
390 ARM_COMPUTE_ERROR("Not supported channel");
391 return 0;
392 }
393 }
394 case Format::NV21:
395 {
396 switch(channel)
397 {
398 case Channel::Y:
399 return 0;
400 case Channel::U:
401 return 1;
402 case Channel::V:
403 return 0;
404 default:
405 ARM_COMPUTE_ERROR("Not supported channel");
406 return 0;
407 }
408 }
409 case Format::YUV444:
410 case Format::IYUV:
411 {
412 switch(channel)
413 {
414 case Channel::Y:
415 return 0;
416 case Channel::U:
417 return 0;
418 case Channel::V:
419 return 0;
420 default:
421 ARM_COMPUTE_ERROR("Not supported channel");
422 return 0;
423 }
424 }
425 default:
426 ARM_COMPUTE_ERROR("Not supported format");
427 return 0;
428 }
429}
430
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100431/** Return the number of planes for a given format
432 *
433 * @param[in] format Input format
434 *
435 * @return The number of planes for a given image format.
436 */
437inline size_t num_planes_from_format(Format format)
438{
439 switch(format)
440 {
441 case Format::U8:
442 case Format::S16:
443 case Format::U16:
444 case Format::S32:
445 case Format::U32:
446 case Format::F16:
447 case Format::F32:
448 case Format::RGB888:
449 case Format::RGBA8888:
450 case Format::YUYV422:
451 case Format::UYVY422:
452 return 1;
453 case Format::NV12:
454 case Format::NV21:
455 return 2;
456 case Format::IYUV:
457 case Format::YUV444:
458 return 3;
459 default:
460 ARM_COMPUTE_ERROR("Not supported format");
461 return 0;
462 }
463}
464
465/** Return the number of channels for a given single-planar pixel format
466 *
467 * @param[in] format Input format
468 *
469 * @return The number of channels for a given image format.
470 */
471inline size_t num_channels_from_format(Format format)
472{
473 switch(format)
474 {
475 case Format::U8:
476 case Format::U16:
477 case Format::S16:
478 case Format::U32:
479 case Format::S32:
480 case Format::F16:
481 case Format::F32:
482 return 1;
483 // Because the U and V channels are subsampled
484 // these formats appear like having only 2 channels:
485 case Format::YUYV422:
486 case Format::UYVY422:
487 return 2;
488 case Format::UV88:
489 return 2;
490 case Format::RGB888:
491 return 3;
492 case Format::RGBA8888:
493 return 4;
494 //Doesn't make sense for planar formats:
495 case Format::NV12:
496 case Format::NV21:
497 case Format::IYUV:
498 case Format::YUV444:
499 default:
500 return 0;
501 }
502}
503
Chunosovd621bca2017-11-03 17:33:15 +0700504/** Return the promoted data type of a given data type.
505 *
506 * @note If promoted data type is not supported an error will be thrown
507 *
508 * @param[in] dt Data type to get the promoted type of.
509 *
510 * @return Promoted data type
511 */
512inline DataType get_promoted_data_type(DataType dt)
513{
514 switch(dt)
515 {
516 case DataType::U8:
517 return DataType::U16;
518 case DataType::S8:
519 return DataType::S16;
Chunosovd621bca2017-11-03 17:33:15 +0700520 case DataType::U16:
521 return DataType::U32;
522 case DataType::S16:
523 return DataType::S32;
Chunosovd621bca2017-11-03 17:33:15 +0700524 case DataType::QASYMM8:
525 case DataType::F16:
526 case DataType::U32:
527 case DataType::S32:
528 case DataType::F32:
Chunosovd621bca2017-11-03 17:33:15 +0700529 ARM_COMPUTE_ERROR("Unsupported data type promotions!");
530 default:
531 ARM_COMPUTE_ERROR("Undefined data type!");
532 }
533 return DataType::UNKNOWN;
534}
535
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100536/** Return true if the given format has horizontal subsampling.
537 *
538 * @param[in] format Format to determine subsampling.
539 *
540 * @return True if the format can be subsampled horizontaly.
541 */
542inline bool has_format_horizontal_subsampling(Format format)
543{
544 return (format == Format::YUYV422 || format == Format::UYVY422 || format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
545}
546
547/** Return true if the given format has vertical subsampling.
548 *
549 * @param[in] format Format to determine subsampling.
550 *
551 * @return True if the format can be subsampled verticaly.
552 */
553inline bool has_format_vertical_subsampling(Format format)
554{
555 return (format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
556}
557
/** Separate a 2D convolution into two 1D convolutions
 *
 * The column of the first-row coefficient with the smallest non-zero magnitude is used as
 * the vertical vector; the horizontal vector holds the integer ratio of every column to it.
 * The separation is accepted only if the outer product of the two vectors reproduces every
 * coefficient of @p conv, first row included.
 *
 * @param[in]  conv     2D convolution, @p size x @p size coefficients stored row by row
 * @param[out] conv_col 1D vertical convolution (@p size elements)
 * @param[out] conv_row 1D horizontal convolution (@p size elements)
 * @param[in]  size     Size of the 2D convolution (elements per side)
 *
 * @return true if the separation was successful. On failure the output vectors may have been
 *         partially written.
 */
inline bool separate_matrix(const int16_t *conv, int16_t *conv_col, int16_t *conv_row, uint8_t size)
{
    int32_t min_col     = -1;
    int16_t min_col_val = -1;

    // Look for the non-zero coefficient of the first row with the smallest magnitude
    for(int32_t i = 0; i < size; ++i)
    {
        if(conv[i] != 0 && (min_col < 0 || std::abs(min_col_val) > std::abs(conv[i])))
        {
            min_col     = i;
            min_col_val = conv[i];
        }
    }

    // The first row has only zeros (or the matrix is empty): nothing to use as reference
    if(min_col < 0)
    {
        return false;
    }

    // The reference column is the vertical vector
    for(uint32_t j = 0; j < size; ++j)
    {
        conv_col[j] = conv[min_col + j * size];
    }

    for(uint32_t i = 0; i < size; i++)
    {
        if(static_cast<int>(i) == min_col)
        {
            conv_row[i] = 1;
            continue;
        }

        // Integer ratio between column i and the reference column (may be truncated)
        const int16_t coeff = conv[i] / conv[min_col];

        // Column i must be exactly coeff times the reference column.
        // Row 0 is checked as well: it is the row where a truncated ratio shows up.
        for(uint32_t j = 0; j < size; ++j)
        {
            if(conv[i + j * size] != (conv_col[j] * coeff))
            {
                return false;
            }
        }

        conv_row[i] = coeff;
    }

    return true;
}
615
/** Compute the scale of a square matrix of coefficients.
 *
 * The scale is the magnitude (absolute value) of the sum of all the coefficients.
 *
 * @note A sum of 0 yields a scale of 1.
 *
 * @param[in] matrix      Matrix coefficients
 * @param[in] matrix_size Number of elements per side of the square matrix. (Number of coefficients = matrix_size * matrix_size).
 *
 * @return The absolute value of the sum of the coefficients, or 1 if the coefficients add up to 0.
 */
inline uint32_t calculate_matrix_scale(const int16_t *matrix, unsigned int matrix_size)
{
    const size_t         num_coeffs = matrix_size * matrix_size;
    const int16_t *const coeffs_end = matrix + num_coeffs;

    // Coefficients are summed up as int
    const int coeffs_sum = std::accumulate(matrix, coeffs_end, 0);
    const int magnitude  = std::abs(coeffs_sum);

    return (magnitude < 1) ? 1 : magnitude;
}
633
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100634/** Adjust tensor shape size if width or height are odd for a given multi-planar format. No modification is done for other formats.
635 *
636 * @note Adding here a few links discussing the issue of odd size and sharing the same solution:
Manuel Bottini581c8982019-02-07 10:31:57 +0000637 * <a href="https://android.googlesource.com/platform/frameworks/base/+/refs/heads/master/graphics/java/android/graphics/YuvImage.java">Android Source</a>
638 * <a href="https://groups.google.com/a/webmproject.org/forum/#!topic/webm-discuss/LaCKpqiDTXM">WebM</a>
639 * <a href="https://bugs.chromium.org/p/libyuv/issues/detail?id=198&amp;can=1&amp;q=odd%20width">libYUV</a>
640 * <a href="https://sourceforge.net/p/raw-yuvplayer/bugs/1/">YUVPlayer</a> *
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100641 *
642 * @param[in, out] shape Tensor shape of 2D size
643 * @param[in] format Format of the tensor
644 *
Alex Gildayc357c472018-03-21 13:54:09 +0000645 * @return The adjusted tensor shape.
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100646 */
647inline TensorShape adjust_odd_shape(const TensorShape &shape, Format format)
648{
649 TensorShape output{ shape };
650
651 // Force width to be even for formats which require subsampling of the U and V channels
652 if(has_format_horizontal_subsampling(format))
653 {
654 output.set(0, output.x() & ~1U);
655 }
656
657 // Force height to be even for formats which require subsampling of the U and V channels
658 if(has_format_vertical_subsampling(format))
659 {
660 output.set(1, output.y() & ~1U);
661 }
662
663 return output;
664}
665
666/** Calculate subsampled shape for a given format and channel
667 *
668 * @param[in] shape Shape of the tensor to calculate the extracted channel.
669 * @param[in] format Format of the tensor.
670 * @param[in] channel Channel to create tensor shape to be extracted.
671 *
672 * @return The subsampled tensor shape.
673 */
674inline TensorShape calculate_subsampled_shape(const TensorShape &shape, Format format, Channel channel = Channel::UNKNOWN)
675{
676 TensorShape output{ shape };
677
678 // Subsample shape only for U or V channel
679 if(Channel::U == channel || Channel::V == channel || Channel::UNKNOWN == channel)
680 {
681 // Subsample width for the tensor shape when channel is U or V
682 if(has_format_horizontal_subsampling(format))
683 {
684 output.set(0, output.x() / 2U);
685 }
686
687 // Subsample height for the tensor shape when channel is U or V
688 if(has_format_vertical_subsampling(format))
689 {
690 output.set(1, output.y() / 2U);
691 }
692 }
693
694 return output;
695}
696
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100697/** Calculate accurary required by the horizontal and vertical convolution computations
698 *
699 * @param[in] conv_col Pointer to the vertical vector of the separated convolution filter
700 * @param[in] conv_row Pointer to the horizontal vector of the convolution filter
701 * @param[in] size Number of elements per vector of the separated matrix
702 *
703 * @return The return type is a pair. The first element of the pair is the biggest data type needed for the first stage. The second
704 * element of the pair is the biggest data type needed for the second stage.
705 */
706inline std::pair<DataType, DataType> data_type_for_convolution(const int16_t *conv_col, const int16_t *conv_row, size_t size)
707{
708 DataType first_stage = DataType::UNKNOWN;
709 DataType second_stage = DataType::UNKNOWN;
710
711 auto gez = [](const int16_t &v)
712 {
713 return v >= 0;
714 };
715
716 auto accu_neg = [](const int &first, const int &second)
717 {
718 return first + (second < 0 ? second : 0);
719 };
720
721 auto accu_pos = [](const int &first, const int &second)
722 {
723 return first + (second > 0 ? second : 0);
724 };
725
726 const bool only_positive_coefficients = std::all_of(conv_row, conv_row + size, gez) && std::all_of(conv_col, conv_col + size, gez);
727
728 if(only_positive_coefficients)
729 {
730 const int max_row_value = std::accumulate(conv_row, conv_row + size, 0) * UINT8_MAX;
731 const int max_value = std::accumulate(conv_col, conv_col + size, 0) * max_row_value;
732
733 first_stage = (max_row_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
734
735 second_stage = (max_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
736 }
737 else
738 {
739 const int min_row_value = std::accumulate(conv_row, conv_row + size, 0, accu_neg) * UINT8_MAX;
740 const int max_row_value = std::accumulate(conv_row, conv_row + size, 0, accu_pos) * UINT8_MAX;
741 const int neg_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_neg);
742 const int pos_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_pos);
743 const int min_value = neg_coeffs_sum * max_row_value + pos_coeffs_sum * min_row_value;
744 const int max_value = neg_coeffs_sum * min_row_value + pos_coeffs_sum * max_row_value;
745
746 first_stage = ((INT16_MIN <= min_row_value) && (max_row_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
747
748 second_stage = ((INT16_MIN <= min_value) && (max_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
749 }
750
751 return std::make_pair(first_stage, second_stage);
752}
753
754/** Calculate the accuracy required by the squared convolution calculation.
755 *
756 *
757 * @param[in] conv Pointer to the squared convolution matrix
758 * @param[in] size The total size of the convolution matrix
759 *
760 * @return The return is the biggest data type needed to do the convolution
761 */
762inline DataType data_type_for_convolution_matrix(const int16_t *conv, size_t size)
763{
764 auto gez = [](const int16_t v)
765 {
766 return v >= 0;
767 };
768
769 const bool only_positive_coefficients = std::all_of(conv, conv + size, gez);
770
771 if(only_positive_coefficients)
772 {
773 const int max_conv_value = std::accumulate(conv, conv + size, 0) * UINT8_MAX;
774 if(max_conv_value <= UINT16_MAX)
775 {
776 return DataType::U16;
777 }
778 else
779 {
780 return DataType::S32;
781 }
782 }
783 else
784 {
785 const int min_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
786 {
787 return b < 0 ? a + b : a;
788 })
789 * UINT8_MAX;
790
791 const int max_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
792 {
793 return b > 0 ? a + b : a;
794 })
795 * UINT8_MAX;
796
797 if((INT16_MIN <= min_value) && (INT16_MAX >= max_value))
798 {
799 return DataType::S16;
800 }
801 else
802 {
803 return DataType::S32;
804 }
805 }
806}
807
Pablo Tello35767bc2018-12-05 17:36:30 +0000808/** Permutes the given dimensions according the permutation vector
809 *
810 * @param[in,out] dimensions Dimensions to be permuted.
811 * @param[in] perm Vector describing the permutation.
812 *
813 */
814template <typename T>
815inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
816{
817 const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
818 for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
819 {
820 T dimension_val = old_dim[i];
821 dimensions.set(perm[i], dimension_val);
822 }
823}
824
Georgios Pinitas4074c992018-01-30 18:13:46 +0000825/** Calculate padding requirements in case of SAME padding
826 *
827 * @param[in] input_shape Input shape
828 * @param[in] weights_shape Weights shape
829 * @param[in] conv_info Convolution information (containing strides)
Isabella Gottardi6a914402019-01-30 15:45:42 +0000830 * @param[in] data_layout (Optional) Data layout of the input and weights tensor
Georgios Pinitas4074c992018-01-30 18:13:46 +0000831 *
832 * @return PadStrideInfo for SAME padding
833 */
Isabella Gottardi6a914402019-01-30 15:45:42 +0000834PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout = DataLayout::NCHW);
Georgios Pinitas4074c992018-01-30 18:13:46 +0000835
Pablo Tellof5f34bb2017-08-22 13:34:13 +0100836/** Returns expected width and height of the deconvolution's output tensor.
837 *
Michalis Spyrouafbc5ff2018-10-03 14:18:19 +0100838 * @param[in] in_width Width of input tensor (Number of columns)
839 * @param[in] in_height Height of input tensor (Number of rows)
840 * @param[in] kernel_width Kernel width.
841 * @param[in] kernel_height Kernel height.
842 * @param[in] padx X axis padding.
843 * @param[in] pady Y axis padding.
844 * @param[in] stride_x X axis input stride.
845 * @param[in] stride_y Y axis input stride.
Pablo Tellof5f34bb2017-08-22 13:34:13 +0100846 *
847 * @return A pair with the new width in the first position and the new height in the second.
848 */
Pablo Tellof5f34bb2017-08-22 13:34:13 +0100849const std::pair<unsigned int, unsigned int> deconvolution_output_dimensions(unsigned int in_width, unsigned int in_height,
850 unsigned int kernel_width, unsigned int kernel_height,
Michalis Spyrouafbc5ff2018-10-03 14:18:19 +0100851 unsigned int padx, unsigned int pady,
Michalis Spyrou780db4e2017-11-23 09:49:51 +0000852 unsigned int stride_x, unsigned int stride_y);
Pablo Tellof5f34bb2017-08-22 13:34:13 +0100853
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100854/** Returns expected width and height of output scaled tensor depending on dimensions rounding mode.
855 *
Gian Marco Iodice4e288692017-06-27 11:41:59 +0100856 * @param[in] width Width of input tensor (Number of columns)
857 * @param[in] height Height of input tensor (Number of rows)
858 * @param[in] kernel_width Kernel width.
859 * @param[in] kernel_height Kernel height.
860 * @param[in] pad_stride_info Pad and stride information.
Alex Gilday7da29b62018-03-23 14:16:00 +0000861 * @param[in] dilation (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100862 *
863 * @return A pair with the new width in the first position and the new height in the second.
864 */
Gian Marco Iodice4e288692017-06-27 11:41:59 +0100865const std::pair<unsigned int, unsigned int> scaled_dimensions(unsigned int width, unsigned int height,
866 unsigned int kernel_width, unsigned int kernel_height,
Alex Gilday7da29b62018-03-23 14:16:00 +0000867 const PadStrideInfo &pad_stride_info,
868 const Size2D &dilation = Size2D(1U, 1U));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100869
870/** Convert a tensor format into a string.
871 *
872 * @param[in] format @ref Format to be translated to string.
873 *
874 * @return The string describing the format.
875 */
876const std::string &string_from_format(Format format);
877
878/** Convert a channel identity into a string.
879 *
880 * @param[in] channel @ref Channel to be translated to string.
881 *
882 * @return The string describing the channel.
883 */
884const std::string &string_from_channel(Channel channel);
Michele Di Giorgiobf3c6622018-03-08 11:52:27 +0000885/** Convert a data layout identity into a string.
886 *
887 * @param[in] dl @ref DataLayout to be translated to string.
888 *
889 * @return The string describing the data layout.
890 */
891const std::string &string_from_data_layout(DataLayout dl);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100892/** Convert a data type identity into a string.
893 *
894 * @param[in] dt @ref DataType to be translated to string.
895 *
896 * @return The string describing the data type.
897 */
898const std::string &string_from_data_type(DataType dt);
899/** Convert a matrix pattern into a string.
900 *
901 * @param[in] pattern @ref MatrixPattern to be translated to string.
902 *
903 * @return The string describing the matrix pattern.
904 */
905const std::string &string_from_matrix_pattern(MatrixPattern pattern);
906/** Translates a given activation function to a string.
907 *
908 * @param[in] act @ref ActivationLayerInfo::ActivationFunction to be translated to string.
909 *
910 * @return The string describing the activation function.
911 */
912const std::string &string_from_activation_func(ActivationLayerInfo::ActivationFunction act);
913/** Translates a given non linear function to a string.
914 *
915 * @param[in] function @ref NonLinearFilterFunction to be translated to string.
916 *
917 * @return The string describing the non linear function.
918 */
919const std::string &string_from_non_linear_filter_function(NonLinearFilterFunction function);
920/** Translates a given interpolation policy to a string.
921 *
922 * @param[in] policy @ref InterpolationPolicy to be translated to string.
923 *
924 * @return The string describing the interpolation policy.
925 */
926const std::string &string_from_interpolation_policy(InterpolationPolicy policy);
927/** Translates a given border mode policy to a string.
928 *
929 * @param[in] border_mode @ref BorderMode to be translated to string.
930 *
931 * @return The string describing the border mode.
932 */
933const std::string &string_from_border_mode(BorderMode border_mode);
934/** Translates a given normalization type to a string.
935 *
936 * @param[in] type @ref NormType to be translated to string.
937 *
938 * @return The string describing the normalization type.
939 */
940const std::string &string_from_norm_type(NormType type);
Georgios Pinitascdf51452017-08-31 14:21:36 +0100941/** Translates a given pooling type to a string.
942 *
943 * @param[in] type @ref PoolingType to be translated to string.
944 *
945 * @return The string describing the pooling type.
946 */
947const std::string &string_from_pooling_type(PoolingType type);
Gian Marco Iodice4b908652018-10-18 10:21:02 +0100948/** Translates a given GEMMLowp output stage to a string.
949 *
950 * @param[in] output_stage @ref GEMMLowpOutputStageInfo to be translated to string.
951 *
952 * @return The string describing the GEMMLowp output stage
953 */
954const std::string &string_from_gemmlowp_output_stage(GEMMLowpOutputStageType output_stage);
Giuseppe Rossinid7647d42018-07-17 18:13:13 +0100955/** Convert a PixelValue to a string, represented through the specific data type
956 *
957 * @param[in] value The PixelValue to convert
958 * @param[in] data_type The type to be used to convert the @p value
959 *
960 * @return String representation of the PixelValue through the given data type.
961 */
962std::string string_from_pixel_value(const PixelValue &value, const DataType data_type);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100963/** Lower a given string.
964 *
965 * @param[in] val Given string to lower.
966 *
967 * @return The lowered string
968 */
969std::string lower_string(const std::string &val);
970
971/** Check if a given data type is of floating point type
972 *
973 * @param[in] dt Input data type.
974 *
975 * @return True if data type is of floating point type, else false.
976 */
977inline bool is_data_type_float(DataType dt)
978{
979 switch(dt)
980 {
981 case DataType::F16:
982 case DataType::F32:
983 return true;
984 default:
985 return false;
986 }
987}
988
Georgios Pinitas05078ec2017-11-02 13:06:59 +0000989/** Check if a given data type is of quantized type
990 *
991 * @note Quantized is considered a super-set of fixed-point and asymmetric data types.
992 *
993 * @param[in] dt Input data type.
994 *
995 * @return True if data type is of quantized type, else false.
996 */
997inline bool is_data_type_quantized(DataType dt)
998{
999 switch(dt)
1000 {
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001001 case DataType::QASYMM8:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001002 return true;
1003 default:
1004 return false;
1005 }
1006}
1007
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001008/** Check if a given data type is of asymmetric quantized type
1009 *
1010 * @param[in] dt Input data type.
1011 *
1012 * @return True if data type is of symmetric quantized type, else false.
1013 */
Anton Lokhmotovaf6204c2017-11-08 09:34:19 +00001014inline bool is_data_type_quantized_asymmetric(DataType dt)
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001015{
1016 switch(dt)
1017 {
1018 case DataType::QASYMM8:
1019 return true;
1020 default:
1021 return false;
1022 }
1023}
1024
/** Create a string with the float in full precision.
 *
 * The value is streamed with std::numeric_limits<float>::max_digits10 significant digits.
 * An "f" suffix is appended unless the value is a whole number that fits in an int,
 * e.g. 2.0f gives "2" while 0.5f gives "0.5f".
 *
 * @note NaN, infinities and values outside the range of int always receive the "f" suffix.
 *       The range is checked before the value is converted to int, because converting an
 *       out-of-range floating point value to an integer is undefined behaviour.
 *
 * @param val Floating point value
 *
 * @return String with the floating point value.
 */
inline std::string float_to_string_with_full_precision(float val)
{
    std::stringstream ss;
    ss.precision(std::numeric_limits<float>::max_digits10);
    ss << val;

    // The lowest int is a negative power of two, so both bounds are exactly representable as float.
    const float int_lower = static_cast<float>(std::numeric_limits<int>::min()); // inclusive
    const float int_upper = -int_lower;                                          // exclusive

    // Both comparisons are false for NaN, so NaN is treated as "not a whole int" as well.
    const bool fits_in_int = (val >= int_lower) && (val < int_upper);
    const bool is_whole    = fits_in_int && (val == static_cast<float>(static_cast<int>(val)));

    if(!is_whole)
    {
        ss << "f";
    }

    return ss.str();
}
1044
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001045/** Returns the number of elements required to go from start to end with the wanted step
1046 *
1047 * @param[in] start start value
1048 * @param[in] end end value
1049 * @param[in] step step value between each number in the wanted sequence
1050 *
1051 * @return number of elements to go from start value to end value using the wanted step
1052 */
1053inline size_t num_of_elements_in_range(const float start, const float end, const float step)
1054{
1055 ARM_COMPUTE_ERROR_ON_MSG(step == 0, "Range Step cannot be 0");
1056 return size_t(std::ceil((end - start) / step));
1057}
1058
1059/** Returns true if the value can be represented by the given data type
1060 *
1061 * @param[in] val value to be checked
1062 * @param[in] dt data type that is checked
1063 * @param[in] quant_info quantization info if the data type is QASYMM8
1064 *
1065 * @return true if the data type can hold the value.
1066 */
1067template <typename T>
1068bool check_value_range(T val, DataType dt, QuantizationInfo quant_info = QuantizationInfo())
1069{
1070 switch(dt)
1071 {
1072 case DataType::U8:
1073 return ((static_cast<uint8_t>(val) == val) && val >= std::numeric_limits<uint8_t>::lowest() && val <= std::numeric_limits<uint8_t>::max());
1074 case DataType::QASYMM8:
1075 {
1076 double min = static_cast<double>(quant_info.dequantize(0));
1077 double max = static_cast<double>(quant_info.dequantize(std::numeric_limits<uint8_t>::max()));
1078 return ((double)val >= min && (double)val <= max);
1079 }
1080 case DataType::S8:
1081 return ((static_cast<int8_t>(val) == val) && val >= std::numeric_limits<int8_t>::lowest() && val <= std::numeric_limits<int8_t>::max());
1082 case DataType::U16:
1083 return ((static_cast<uint16_t>(val) == val) && val >= std::numeric_limits<uint16_t>::lowest() && val <= std::numeric_limits<uint16_t>::max());
1084 case DataType::S16:
1085 return ((static_cast<int16_t>(val) == val) && val >= std::numeric_limits<int16_t>::lowest() && val <= std::numeric_limits<int16_t>::max());
1086 case DataType::U32:
1087 return ((static_cast<uint32_t>(val) == val) && val >= std::numeric_limits<uint32_t>::lowest() && val <= std::numeric_limits<uint32_t>::max());
1088 case DataType::S32:
1089 return ((static_cast<int32_t>(val) == val) && val >= std::numeric_limits<int32_t>::lowest() && val <= std::numeric_limits<int32_t>::max());
1090 case DataType::U64:
1091 return (val >= std::numeric_limits<uint64_t>::lowest() && val <= std::numeric_limits<uint64_t>::max());
1092 case DataType::S64:
1093 return (val >= std::numeric_limits<int64_t>::lowest() && val <= std::numeric_limits<int64_t>::max());
1094 case DataType::F16:
1095 return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
1096 case DataType::F32:
1097 return (val >= std::numeric_limits<float>::lowest() && val <= std::numeric_limits<float>::max());
1098 case DataType::F64:
1099 return (val >= std::numeric_limits<double>::lowest() && val <= std::numeric_limits<double>::max());
1100 case DataType::SIZET:
1101 return ((static_cast<size_t>(val) == val) && val >= std::numeric_limits<size_t>::lowest() && val <= std::numeric_limits<size_t>::max());
1102 default:
1103 ARM_COMPUTE_ERROR("Data type not supported");
1104 return false;
1105 }
1106}
1107
giuros01edc21e42018-11-16 14:45:31 +00001108#ifdef ARM_COMPUTE_ASSERTS_ENABLED
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001109/** Print consecutive elements to an output stream.
1110 *
1111 * @param[out] s Output stream to print the elements to.
1112 * @param[in] ptr Pointer to print the elements from.
1113 * @param[in] n Number of elements to print.
1114 * @param[in] stream_width (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1115 * @param[in] element_delim (Optional) Delimeter among the consecutive elements. Defaults to space delimeter
1116 */
1117template <typename T>
1118void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int n, int stream_width = 0, const std::string &element_delim = " ")
1119{
1120 using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1121
1122 for(unsigned int i = 0; i < n; ++i)
1123 {
1124 // Set stream width as it is not a "sticky" stream manipulator
1125 if(stream_width != 0)
1126 {
1127 s.width(stream_width);
1128 }
Anthony Barbier7068f992017-10-26 15:23:08 +01001129
1130 if(std::is_same<typename std::decay<T>::type, half>::value)
1131 {
1132 // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
1133 s << std::right << static_cast<T>(ptr[i]) << element_delim;
1134 }
1135 else
1136 {
1137 s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
1138 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001139 }
1140}
1141
1142/** Identify the maximum width of n consecutive elements.
1143 *
1144 * @param[in] s The output stream which will be used to print the elements. Used to extract the stream format.
1145 * @param[in] ptr Pointer to the elements.
1146 * @param[in] n Number of elements.
1147 *
1148 * @return The maximum width of the elements.
1149 */
1150template <typename T>
1151int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, unsigned int n)
1152{
1153 using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1154
1155 int max_width = -1;
1156 for(unsigned int i = 0; i < n; ++i)
1157 {
1158 std::stringstream ss;
1159 ss.copyfmt(s);
Anthony Barbier7068f992017-10-26 15:23:08 +01001160
1161 if(std::is_same<typename std::decay<T>::type, half>::value)
1162 {
1163 // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
1164 ss << static_cast<T>(ptr[i]);
1165 }
1166 else
1167 {
1168 ss << static_cast<print_type>(ptr[i]);
1169 }
1170
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001171 max_width = std::max<int>(max_width, ss.str().size());
1172 }
1173 return max_width;
1174}
1175
1176/** Print consecutive elements to an output stream.
1177 *
1178 * @param[out] s Output stream to print the elements to.
1179 * @param[in] dt Data type of the elements
1180 * @param[in] ptr Pointer to print the elements from.
1181 * @param[in] n Number of elements to print.
1182 * @param[in] stream_width (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1183 * @param[in] element_delim (Optional) Delimiter among the consecutive elements. Defaults to space delimiter
1184 */
1185void print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim = " ");
1186
1187/** Identify the maximum width of n consecutive elements.
1188 *
1189 * @param[in] s The output stream which will be used to print the elements. Used to extract the stream format.
1190 * @param[in] dt Data type of the elements
1191 * @param[in] ptr Pointer to the elements.
1192 * @param[in] n Number of elements.
1193 *
1194 * @return The maximum width of the elements.
1195 */
1196int max_consecutive_elements_display_width(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n);
giuros01edc21e42018-11-16 14:45:31 +00001197#endif /* ARM_COMPUTE_ASSERTS_ENABLED */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001198}
1199#endif /*__ARM_COMPUTE_UTILS_H__ */