blob: b711451453574a10de007efd7c8240fe472465c7 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyrouf63885b2019-01-16 14:18:09 +00002 * Copyright (c) 2016-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_UTILS_H__
25#define __ARM_COMPUTE_UTILS_H__
26
27#include "arm_compute/core/Error.h"
Giuseppe Rossinid7647d42018-07-17 18:13:13 +010028#include "arm_compute/core/PixelValue.h"
Michel Iwaniec5dfeae62017-11-29 10:48:23 +000029#include "arm_compute/core/Rounding.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Types.h"
31
32#include <algorithm>
33#include <cstdint>
34#include <cstdlib>
35#include <numeric>
36#include <sstream>
37#include <string>
38#include <type_traits>
39#include <utility>
steniu017ce53c62017-09-29 14:55:00 +010040#include <vector>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010041
42namespace arm_compute
43{
/** Divide @p val by @p m and round the quotient up.
 *
 * Evaluated as (val + m - 1) / m, which yields the ceiling of val / m for a non-negative @p val
 * and a positive @p m. Being constexpr, it can also be used in constant expressions.
 *
 * @param[in] val Dividend.
 * @param[in] m   Divisor.
 *
 * @return val / m rounded towards positive infinity.
 */
template <typename S, typename T>
constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
{
    return (val + m - 1) / m;
}
56
/** Round @p value up to the nearest multiple of @p divisor.
 *
 * @param[in] value   Value to round. Negative values are rejected by ARM_COMPUTE_ERROR_ON.
 * @param[in] divisor Step between multiples. Non-positive values are rejected by ARM_COMPUTE_ERROR_ON.
 *
 * @return The smallest multiple of @p divisor that is greater than or equal to @p value.
 */
template <typename S, typename T>
inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
    const auto num_multiples = DIV_CEIL(value, divisor);
    return num_multiples * divisor;
}
70
/** Round @p value down to the nearest multiple of @p divisor.
 *
 * @param[in] value   Value to round. Negative values are rejected by ARM_COMPUTE_ERROR_ON.
 * @param[in] divisor Step between multiples. Non-positive values are rejected by ARM_COMPUTE_ERROR_ON.
 *
 * @return The largest multiple of @p divisor that is less than or equal to @p value.
 */
template <typename S, typename T>
inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
    const auto whole_multiples = value / divisor;
    return whole_multiples * divisor;
}
84
/** Describe how the arm_compute library was built.
 *
 * The description includes the library version number and the build options that were used.
 *
 * @return Human-readable build information of the arm_compute library.
 */
std::string build_information();

/** Read the whole content of a file into a string.
 *
 * @param[in] filename Path of the file to read.
 * @param[in] binary   Whether the file should be treated as a binary file.
 *
 * @return The content of the file.
 */
std::string read_file(const std::string &filename, bool binary);
101
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100102/** The size in bytes of the data type
103 *
104 * @param[in] data_type Input data type
105 *
106 * @return The size in bytes of the data type
107 */
108inline size_t data_size_from_type(DataType data_type)
109{
110 switch(data_type)
111 {
112 case DataType::U8:
113 case DataType::S8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100114 case DataType::QSYMM8:
Michel Iwaniec00633802017-10-12 14:14:15 +0100115 case DataType::QASYMM8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100116 case DataType::QSYMM8_PER_CHANNEL:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100117 return 1;
118 case DataType::U16:
119 case DataType::S16:
Manuel Bottini3689fcd2019-06-14 17:18:12 +0100120 case DataType::QSYMM16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100121 case DataType::F16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100122 return 2;
123 case DataType::F32:
124 case DataType::U32:
125 case DataType::S32:
126 return 4;
127 case DataType::F64:
128 case DataType::U64:
129 case DataType::S64:
130 return 8;
131 case DataType::SIZET:
132 return sizeof(size_t);
133 default:
134 ARM_COMPUTE_ERROR("Invalid data type");
135 return 0;
136 }
137}
138
139/** The size in bytes of the pixel format
140 *
141 * @param[in] format Input format
142 *
143 * @return The size in bytes of the pixel format
144 */
145inline size_t pixel_size_from_format(Format format)
146{
147 switch(format)
148 {
149 case Format::U8:
150 return 1;
151 case Format::U16:
152 case Format::S16:
153 case Format::F16:
154 case Format::UV88:
155 case Format::YUYV422:
156 case Format::UYVY422:
157 return 2;
158 case Format::RGB888:
159 return 3;
160 case Format::RGBA8888:
161 return 4;
162 case Format::U32:
163 case Format::S32:
164 case Format::F32:
165 return 4;
166 //Doesn't make sense for planar formats:
167 case Format::NV12:
168 case Format::NV21:
169 case Format::IYUV:
170 case Format::YUV444:
171 default:
172 ARM_COMPUTE_ERROR("Undefined pixel size for given format");
173 return 0;
174 }
175}
176
177/** The size in bytes of the data type
178 *
179 * @param[in] dt Input data type
180 *
181 * @return The size in bytes of the data type
182 */
183inline size_t element_size_from_data_type(DataType dt)
184{
185 switch(dt)
186 {
187 case DataType::S8:
188 case DataType::U8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100189 case DataType::QSYMM8:
Michel Iwaniec00633802017-10-12 14:14:15 +0100190 case DataType::QASYMM8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100191 case DataType::QSYMM8_PER_CHANNEL:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100192 return 1;
193 case DataType::U16:
194 case DataType::S16:
Manuel Bottini3689fcd2019-06-14 17:18:12 +0100195 case DataType::QSYMM16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100196 case DataType::F16:
197 return 2;
198 case DataType::U32:
199 case DataType::S32:
200 case DataType::F32:
201 return 4;
202 default:
203 ARM_COMPUTE_ERROR("Undefined element size for given data type");
204 return 0;
205 }
206}
207
208/** Return the data type used by a given single-planar pixel format
209 *
210 * @param[in] format Input format
211 *
212 * @return The size in bytes of the pixel format
213 */
214inline DataType data_type_from_format(Format format)
215{
216 switch(format)
217 {
218 case Format::U8:
219 case Format::UV88:
220 case Format::RGB888:
221 case Format::RGBA8888:
222 case Format::YUYV422:
223 case Format::UYVY422:
224 return DataType::U8;
225 case Format::U16:
226 return DataType::U16;
227 case Format::S16:
228 return DataType::S16;
229 case Format::U32:
230 return DataType::U32;
231 case Format::S32:
232 return DataType::S32;
233 case Format::F16:
234 return DataType::F16;
235 case Format::F32:
236 return DataType::F32;
237 //Doesn't make sense for planar formats:
238 case Format::NV12:
239 case Format::NV21:
240 case Format::IYUV:
241 case Format::YUV444:
242 default:
243 ARM_COMPUTE_ERROR("Not supported data_type for given format");
244 return DataType::UNKNOWN;
245 }
246}
247
248/** Return the plane index of a given channel given an input format.
249 *
250 * @param[in] format Input format
251 * @param[in] channel Input channel
252 *
253 * @return The plane index of the specific channel of the specific format
254 */
255inline int plane_idx_from_channel(Format format, Channel channel)
256{
257 switch(format)
258 {
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100259 // Single planar formats have a single plane
260 case Format::U8:
261 case Format::U16:
262 case Format::S16:
263 case Format::U32:
264 case Format::S32:
265 case Format::F16:
266 case Format::F32:
267 case Format::UV88:
268 case Format::RGB888:
269 case Format::RGBA8888:
270 case Format::YUYV422:
271 case Format::UYVY422:
272 return 0;
273 // Multi planar formats
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100274 case Format::NV12:
275 case Format::NV21:
276 {
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100277 // Channel U and V share the same plane of format UV88
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100278 switch(channel)
279 {
280 case Channel::Y:
281 return 0;
282 case Channel::U:
283 case Channel::V:
284 return 1;
285 default:
286 ARM_COMPUTE_ERROR("Not supported channel");
287 return 0;
288 }
289 }
290 case Format::IYUV:
291 case Format::YUV444:
292 {
293 switch(channel)
294 {
295 case Channel::Y:
296 return 0;
297 case Channel::U:
298 return 1;
299 case Channel::V:
300 return 2;
301 default:
302 ARM_COMPUTE_ERROR("Not supported channel");
303 return 0;
304 }
305 }
306 default:
307 ARM_COMPUTE_ERROR("Not supported format");
308 return 0;
309 }
310}
311
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100312/** Return the channel index of a given channel given an input format.
313 *
314 * @param[in] format Input format
315 * @param[in] channel Input channel
316 *
317 * @return The channel index of the specific channel of the specific format
318 */
319inline int channel_idx_from_format(Format format, Channel channel)
320{
321 switch(format)
322 {
323 case Format::RGB888:
324 {
325 switch(channel)
326 {
327 case Channel::R:
328 return 0;
329 case Channel::G:
330 return 1;
331 case Channel::B:
332 return 2;
333 default:
334 ARM_COMPUTE_ERROR("Not supported channel");
335 return 0;
336 }
337 }
338 case Format::RGBA8888:
339 {
340 switch(channel)
341 {
342 case Channel::R:
343 return 0;
344 case Channel::G:
345 return 1;
346 case Channel::B:
347 return 2;
348 case Channel::A:
349 return 3;
350 default:
351 ARM_COMPUTE_ERROR("Not supported channel");
352 return 0;
353 }
354 }
355 case Format::YUYV422:
356 {
357 switch(channel)
358 {
359 case Channel::Y:
360 return 0;
361 case Channel::U:
362 return 1;
363 case Channel::V:
364 return 3;
365 default:
366 ARM_COMPUTE_ERROR("Not supported channel");
367 return 0;
368 }
369 }
370 case Format::UYVY422:
371 {
372 switch(channel)
373 {
374 case Channel::Y:
375 return 1;
376 case Channel::U:
377 return 0;
378 case Channel::V:
379 return 2;
380 default:
381 ARM_COMPUTE_ERROR("Not supported channel");
382 return 0;
383 }
384 }
385 case Format::NV12:
386 {
387 switch(channel)
388 {
389 case Channel::Y:
390 return 0;
391 case Channel::U:
392 return 0;
393 case Channel::V:
394 return 1;
395 default:
396 ARM_COMPUTE_ERROR("Not supported channel");
397 return 0;
398 }
399 }
400 case Format::NV21:
401 {
402 switch(channel)
403 {
404 case Channel::Y:
405 return 0;
406 case Channel::U:
407 return 1;
408 case Channel::V:
409 return 0;
410 default:
411 ARM_COMPUTE_ERROR("Not supported channel");
412 return 0;
413 }
414 }
415 case Format::YUV444:
416 case Format::IYUV:
417 {
418 switch(channel)
419 {
420 case Channel::Y:
421 return 0;
422 case Channel::U:
423 return 0;
424 case Channel::V:
425 return 0;
426 default:
427 ARM_COMPUTE_ERROR("Not supported channel");
428 return 0;
429 }
430 }
431 default:
432 ARM_COMPUTE_ERROR("Not supported format");
433 return 0;
434 }
435}
436
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100437/** Return the number of planes for a given format
438 *
439 * @param[in] format Input format
440 *
441 * @return The number of planes for a given image format.
442 */
443inline size_t num_planes_from_format(Format format)
444{
445 switch(format)
446 {
447 case Format::U8:
448 case Format::S16:
449 case Format::U16:
450 case Format::S32:
451 case Format::U32:
452 case Format::F16:
453 case Format::F32:
454 case Format::RGB888:
455 case Format::RGBA8888:
456 case Format::YUYV422:
457 case Format::UYVY422:
458 return 1;
459 case Format::NV12:
460 case Format::NV21:
461 return 2;
462 case Format::IYUV:
463 case Format::YUV444:
464 return 3;
465 default:
466 ARM_COMPUTE_ERROR("Not supported format");
467 return 0;
468 }
469}
470
471/** Return the number of channels for a given single-planar pixel format
472 *
473 * @param[in] format Input format
474 *
475 * @return The number of channels for a given image format.
476 */
477inline size_t num_channels_from_format(Format format)
478{
479 switch(format)
480 {
481 case Format::U8:
482 case Format::U16:
483 case Format::S16:
484 case Format::U32:
485 case Format::S32:
486 case Format::F16:
487 case Format::F32:
488 return 1;
489 // Because the U and V channels are subsampled
490 // these formats appear like having only 2 channels:
491 case Format::YUYV422:
492 case Format::UYVY422:
493 return 2;
494 case Format::UV88:
495 return 2;
496 case Format::RGB888:
497 return 3;
498 case Format::RGBA8888:
499 return 4;
500 //Doesn't make sense for planar formats:
501 case Format::NV12:
502 case Format::NV21:
503 case Format::IYUV:
504 case Format::YUV444:
505 default:
506 return 0;
507 }
508}
509
Chunosovd621bca2017-11-03 17:33:15 +0700510/** Return the promoted data type of a given data type.
511 *
512 * @note If promoted data type is not supported an error will be thrown
513 *
514 * @param[in] dt Data type to get the promoted type of.
515 *
516 * @return Promoted data type
517 */
518inline DataType get_promoted_data_type(DataType dt)
519{
520 switch(dt)
521 {
522 case DataType::U8:
523 return DataType::U16;
524 case DataType::S8:
525 return DataType::S16;
Chunosovd621bca2017-11-03 17:33:15 +0700526 case DataType::U16:
527 return DataType::U32;
528 case DataType::S16:
529 return DataType::S32;
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100530 case DataType::QSYMM8:
Chunosovd621bca2017-11-03 17:33:15 +0700531 case DataType::QASYMM8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100532 case DataType::QSYMM8_PER_CHANNEL:
Manuel Bottini3689fcd2019-06-14 17:18:12 +0100533 case DataType::QSYMM16:
Chunosovd621bca2017-11-03 17:33:15 +0700534 case DataType::F16:
535 case DataType::U32:
536 case DataType::S32:
537 case DataType::F32:
Chunosovd621bca2017-11-03 17:33:15 +0700538 ARM_COMPUTE_ERROR("Unsupported data type promotions!");
539 default:
540 ARM_COMPUTE_ERROR("Undefined data type!");
541 }
542 return DataType::UNKNOWN;
543}
544
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100545/** Return true if the given format has horizontal subsampling.
546 *
547 * @param[in] format Format to determine subsampling.
548 *
549 * @return True if the format can be subsampled horizontaly.
550 */
551inline bool has_format_horizontal_subsampling(Format format)
552{
553 return (format == Format::YUYV422 || format == Format::UYVY422 || format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
554}
555
556/** Return true if the given format has vertical subsampling.
557 *
558 * @param[in] format Format to determine subsampling.
559 *
560 * @return True if the format can be subsampled verticaly.
561 */
562inline bool has_format_vertical_subsampling(Format format)
563{
564 return (format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
565}
566
/** Separate a 2D convolution into two 1D convolutions
 *
 * @p conv is read row-major with a row stride of @p size. The separation succeeds when, for every row j and
 * column i, conv[i + j * size] == conv_col[j] * conv_row[i]. The column vector is the column of the first-row
 * entry with the smallest non-zero magnitude. Every other column must be an exact integer multiple of it.
 *
 * @note If the first row of @p conv is all zeros the function returns false.
 *
 * @param[in]  conv     2D convolution (size * size coefficients)
 * @param[out] conv_col 1D vertical convolution (size coefficients)
 * @param[out] conv_row 1D horizontal convolution (size coefficients)
 * @param[in]  size     Size of the 2D convolution
 *
 * @return true if the separation was successful. On failure the output arrays may be partially written.
 */
inline bool separate_matrix(const int16_t *conv, int16_t *conv_col, int16_t *conv_row, uint8_t size)
{
    int32_t min_col     = -1;
    int16_t min_col_val = -1;

    // Pick the first-row entry with the smallest non-zero magnitude as the pivot column
    for(int32_t i = 0; i < size; ++i)
    {
        if(conv[i] != 0 && (min_col < 0 || std::abs(min_col_val) > std::abs(conv[i])))
        {
            min_col     = i;
            min_col_val = conv[i];
        }
    }

    if(min_col < 0)
    {
        return false;
    }

    for(uint32_t j = 0; j < size; ++j)
    {
        conv_col[j] = conv[min_col + j * size];
    }

    for(uint32_t i = 0; i < size; i++)
    {
        if(static_cast<int>(i) == min_col)
        {
            conv_row[i] = 1;
        }
        else
        {
            const int16_t coeff = conv[i] / conv[min_col];

            // Start at j = 0 so the first row is verified too: otherwise a truncated integer
            // quotient (e.g. 3 / 2 == 1) would be accepted as an exact factor.
            for(uint32_t j = 0; j < size; ++j)
            {
                if(conv[i + j * size] != (conv_col[j] * coeff))
                {
                    return false;
                }
            }

            conv_row[i] = coeff;
        }
    }

    return true;
}
624
/** Compute the normalisation scale of a square matrix.
 *
 * The scale is |sum of all coefficients|. A matrix whose coefficients sum to 0 gets a scale of 1.
 *
 * @param[in] matrix      Matrix coefficients (matrix_size * matrix_size values)
 * @param[in] matrix_size Number of elements per side of the square matrix.
 *
 * @return The absolute value of the coefficient sum, or 1 if that sum is 0.
 */
inline uint32_t calculate_matrix_scale(const int16_t *matrix, unsigned int matrix_size)
{
    const size_t num_coeffs = matrix_size * matrix_size;

    int sum = 0;
    for(size_t k = 0; k < num_coeffs; ++k)
    {
        sum += matrix[k];
    }

    const int magnitude = std::abs(sum);
    return (magnitude == 0) ? 1 : magnitude;
}
642
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100643/** Adjust tensor shape size if width or height are odd for a given multi-planar format. No modification is done for other formats.
644 *
645 * @note Adding here a few links discussing the issue of odd size and sharing the same solution:
Manuel Bottini581c8982019-02-07 10:31:57 +0000646 * <a href="https://android.googlesource.com/platform/frameworks/base/+/refs/heads/master/graphics/java/android/graphics/YuvImage.java">Android Source</a>
647 * <a href="https://groups.google.com/a/webmproject.org/forum/#!topic/webm-discuss/LaCKpqiDTXM">WebM</a>
648 * <a href="https://bugs.chromium.org/p/libyuv/issues/detail?id=198&amp;can=1&amp;q=odd%20width">libYUV</a>
649 * <a href="https://sourceforge.net/p/raw-yuvplayer/bugs/1/">YUVPlayer</a> *
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100650 *
651 * @param[in, out] shape Tensor shape of 2D size
652 * @param[in] format Format of the tensor
653 *
Alex Gildayc357c472018-03-21 13:54:09 +0000654 * @return The adjusted tensor shape.
Ioan-Cristian Szabo9414f642017-10-27 17:35:40 +0100655 */
656inline TensorShape adjust_odd_shape(const TensorShape &shape, Format format)
657{
658 TensorShape output{ shape };
659
660 // Force width to be even for formats which require subsampling of the U and V channels
661 if(has_format_horizontal_subsampling(format))
662 {
663 output.set(0, output.x() & ~1U);
664 }
665
666 // Force height to be even for formats which require subsampling of the U and V channels
667 if(has_format_vertical_subsampling(format))
668 {
669 output.set(1, output.y() & ~1U);
670 }
671
672 return output;
673}
674
675/** Calculate subsampled shape for a given format and channel
676 *
677 * @param[in] shape Shape of the tensor to calculate the extracted channel.
678 * @param[in] format Format of the tensor.
679 * @param[in] channel Channel to create tensor shape to be extracted.
680 *
681 * @return The subsampled tensor shape.
682 */
683inline TensorShape calculate_subsampled_shape(const TensorShape &shape, Format format, Channel channel = Channel::UNKNOWN)
684{
685 TensorShape output{ shape };
686
687 // Subsample shape only for U or V channel
688 if(Channel::U == channel || Channel::V == channel || Channel::UNKNOWN == channel)
689 {
690 // Subsample width for the tensor shape when channel is U or V
691 if(has_format_horizontal_subsampling(format))
692 {
693 output.set(0, output.x() / 2U);
694 }
695
696 // Subsample height for the tensor shape when channel is U or V
697 if(has_format_vertical_subsampling(format))
698 {
699 output.set(1, output.y() / 2U);
700 }
701 }
702
703 return output;
704}
705
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100706/** Calculate accurary required by the horizontal and vertical convolution computations
707 *
708 * @param[in] conv_col Pointer to the vertical vector of the separated convolution filter
709 * @param[in] conv_row Pointer to the horizontal vector of the convolution filter
710 * @param[in] size Number of elements per vector of the separated matrix
711 *
712 * @return The return type is a pair. The first element of the pair is the biggest data type needed for the first stage. The second
713 * element of the pair is the biggest data type needed for the second stage.
714 */
715inline std::pair<DataType, DataType> data_type_for_convolution(const int16_t *conv_col, const int16_t *conv_row, size_t size)
716{
717 DataType first_stage = DataType::UNKNOWN;
718 DataType second_stage = DataType::UNKNOWN;
719
720 auto gez = [](const int16_t &v)
721 {
722 return v >= 0;
723 };
724
725 auto accu_neg = [](const int &first, const int &second)
726 {
727 return first + (second < 0 ? second : 0);
728 };
729
730 auto accu_pos = [](const int &first, const int &second)
731 {
732 return first + (second > 0 ? second : 0);
733 };
734
735 const bool only_positive_coefficients = std::all_of(conv_row, conv_row + size, gez) && std::all_of(conv_col, conv_col + size, gez);
736
737 if(only_positive_coefficients)
738 {
739 const int max_row_value = std::accumulate(conv_row, conv_row + size, 0) * UINT8_MAX;
740 const int max_value = std::accumulate(conv_col, conv_col + size, 0) * max_row_value;
741
742 first_stage = (max_row_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
743
744 second_stage = (max_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
745 }
746 else
747 {
748 const int min_row_value = std::accumulate(conv_row, conv_row + size, 0, accu_neg) * UINT8_MAX;
749 const int max_row_value = std::accumulate(conv_row, conv_row + size, 0, accu_pos) * UINT8_MAX;
750 const int neg_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_neg);
751 const int pos_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_pos);
752 const int min_value = neg_coeffs_sum * max_row_value + pos_coeffs_sum * min_row_value;
753 const int max_value = neg_coeffs_sum * min_row_value + pos_coeffs_sum * max_row_value;
754
755 first_stage = ((INT16_MIN <= min_row_value) && (max_row_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
756
757 second_stage = ((INT16_MIN <= min_value) && (max_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
758 }
759
760 return std::make_pair(first_stage, second_stage);
761}
762
763/** Calculate the accuracy required by the squared convolution calculation.
764 *
765 *
766 * @param[in] conv Pointer to the squared convolution matrix
767 * @param[in] size The total size of the convolution matrix
768 *
769 * @return The return is the biggest data type needed to do the convolution
770 */
771inline DataType data_type_for_convolution_matrix(const int16_t *conv, size_t size)
772{
773 auto gez = [](const int16_t v)
774 {
775 return v >= 0;
776 };
777
778 const bool only_positive_coefficients = std::all_of(conv, conv + size, gez);
779
780 if(only_positive_coefficients)
781 {
782 const int max_conv_value = std::accumulate(conv, conv + size, 0) * UINT8_MAX;
783 if(max_conv_value <= UINT16_MAX)
784 {
785 return DataType::U16;
786 }
787 else
788 {
789 return DataType::S32;
790 }
791 }
792 else
793 {
794 const int min_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
795 {
796 return b < 0 ? a + b : a;
797 })
798 * UINT8_MAX;
799
800 const int max_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
801 {
802 return b > 0 ? a + b : a;
803 })
804 * UINT8_MAX;
805
806 if((INT16_MIN <= min_value) && (INT16_MAX >= max_value))
807 {
808 return DataType::S16;
809 }
810 else
811 {
812 return DataType::S32;
813 }
814 }
815}
816
Pablo Tello35767bc2018-12-05 17:36:30 +0000817/** Permutes the given dimensions according the permutation vector
818 *
819 * @param[in,out] dimensions Dimensions to be permuted.
820 * @param[in] perm Vector describing the permutation.
821 *
822 */
823template <typename T>
824inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
825{
826 const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
827 for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
828 {
829 T dimension_val = old_dim[i];
830 dimensions.set(perm[i], dimension_val);
831 }
832}
833
Georgios Pinitas4074c992018-01-30 18:13:46 +0000834/** Calculate padding requirements in case of SAME padding
835 *
836 * @param[in] input_shape Input shape
837 * @param[in] weights_shape Weights shape
838 * @param[in] conv_info Convolution information (containing strides)
Isabella Gottardi6a914402019-01-30 15:45:42 +0000839 * @param[in] data_layout (Optional) Data layout of the input and weights tensor
Pablo Tello01bbacb2019-04-30 10:32:42 +0100840 * @param[in] dilation (Optional) Dilation factor used in the convolution.
Georgios Pinitas4074c992018-01-30 18:13:46 +0000841 *
842 * @return PadStrideInfo for SAME padding
843 */
Pablo Tello01bbacb2019-04-30 10:32:42 +0100844PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout = DataLayout::NCHW, const Size2D &dilation = Size2D(1u, 1u));
Georgios Pinitas4074c992018-01-30 18:13:46 +0000845
/** Compute the expected width and height of a deconvolution's output tensor.
 *
 * @param[in] in_width      Width of the input tensor (number of columns)
 * @param[in] in_height     Height of the input tensor (number of rows)
 * @param[in] kernel_width  Kernel width
 * @param[in] kernel_height Kernel height
 * @param[in] padx          Padding along the X axis
 * @param[in] pady          Padding along the Y axis
 * @param[in] stride_x      Input stride along the X axis
 * @param[in] stride_y      Input stride along the Y axis
 *
 * @return A pair holding the output width first and the output height second.
 */
std::pair<unsigned int, unsigned int> deconvolution_output_dimensions(unsigned int in_width, unsigned int in_height,
                                                                      unsigned int kernel_width, unsigned int kernel_height,
                                                                      unsigned int padx, unsigned int pady,
                                                                      unsigned int stride_x, unsigned int stride_y);
Pablo Tellof5f34bb2017-08-22 13:34:13 +0100863
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100864/** Returns expected width and height of output scaled tensor depending on dimensions rounding mode.
865 *
Gian Marco Iodice4e288692017-06-27 11:41:59 +0100866 * @param[in] width Width of input tensor (Number of columns)
867 * @param[in] height Height of input tensor (Number of rows)
868 * @param[in] kernel_width Kernel width.
869 * @param[in] kernel_height Kernel height.
870 * @param[in] pad_stride_info Pad and stride information.
Alex Gilday7da29b62018-03-23 14:16:00 +0000871 * @param[in] dilation (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100872 *
873 * @return A pair with the new width in the first position and the new height in the second.
874 */
Pablo Tello01bbacb2019-04-30 10:32:42 +0100875std::pair<unsigned int, unsigned int> scaled_dimensions(unsigned int width, unsigned int height,
876 unsigned int kernel_width, unsigned int kernel_height,
877 const PadStrideInfo &pad_stride_info,
878 const Size2D &dilation = Size2D(1U, 1U));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100879
880/** Convert a tensor format into a string.
881 *
882 * @param[in] format @ref Format to be translated to string.
883 *
884 * @return The string describing the format.
885 */
886const std::string &string_from_format(Format format);
887
888/** Convert a channel identity into a string.
889 *
890 * @param[in] channel @ref Channel to be translated to string.
891 *
892 * @return The string describing the channel.
893 */
894const std::string &string_from_channel(Channel channel);
Michele Di Giorgiobf3c6622018-03-08 11:52:27 +0000895/** Convert a data layout identity into a string.
896 *
897 * @param[in] dl @ref DataLayout to be translated to string.
898 *
899 * @return The string describing the data layout.
900 */
901const std::string &string_from_data_layout(DataLayout dl);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100902/** Convert a data type identity into a string.
903 *
904 * @param[in] dt @ref DataType to be translated to string.
905 *
906 * @return The string describing the data type.
907 */
908const std::string &string_from_data_type(DataType dt);
909/** Convert a matrix pattern into a string.
910 *
911 * @param[in] pattern @ref MatrixPattern to be translated to string.
912 *
913 * @return The string describing the matrix pattern.
914 */
915const std::string &string_from_matrix_pattern(MatrixPattern pattern);
916/** Translates a given activation function to a string.
917 *
918 * @param[in] act @ref ActivationLayerInfo::ActivationFunction to be translated to string.
919 *
920 * @return The string describing the activation function.
921 */
922const std::string &string_from_activation_func(ActivationLayerInfo::ActivationFunction act);
923/** Translates a given non linear function to a string.
924 *
925 * @param[in] function @ref NonLinearFilterFunction to be translated to string.
926 *
927 * @return The string describing the non linear function.
928 */
929const std::string &string_from_non_linear_filter_function(NonLinearFilterFunction function);
930/** Translates a given interpolation policy to a string.
931 *
932 * @param[in] policy @ref InterpolationPolicy to be translated to string.
933 *
934 * @return The string describing the interpolation policy.
935 */
936const std::string &string_from_interpolation_policy(InterpolationPolicy policy);
937/** Translates a given border mode policy to a string.
938 *
939 * @param[in] border_mode @ref BorderMode to be translated to string.
940 *
941 * @return The string describing the border mode.
942 */
943const std::string &string_from_border_mode(BorderMode border_mode);
944/** Translates a given normalization type to a string.
945 *
946 * @param[in] type @ref NormType to be translated to string.
947 *
948 * @return The string describing the normalization type.
949 */
950const std::string &string_from_norm_type(NormType type);
Georgios Pinitascdf51452017-08-31 14:21:36 +0100951/** Translates a given pooling type to a string.
952 *
953 * @param[in] type @ref PoolingType to be translated to string.
954 *
955 * @return The string describing the pooling type.
956 */
957const std::string &string_from_pooling_type(PoolingType type);
Gian Marco Iodice4b908652018-10-18 10:21:02 +0100958/** Translates a given GEMMLowp output stage to a string.
959 *
960 * @param[in] output_stage @ref GEMMLowpOutputStageInfo to be translated to string.
961 *
962 * @return The string describing the GEMMLowp output stage
963 */
964const std::string &string_from_gemmlowp_output_stage(GEMMLowpOutputStageType output_stage);
Giuseppe Rossinid7647d42018-07-17 18:13:13 +0100965/** Convert a PixelValue to a string, represented through the specific data type
966 *
967 * @param[in] value The PixelValue to convert
968 * @param[in] data_type The type to be used to convert the @p value
969 *
970 * @return String representation of the PixelValue through the given data type.
971 */
972std::string string_from_pixel_value(const PixelValue &value, const DataType data_type);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100973/** Lower a given string.
974 *
975 * @param[in] val Given string to lower.
976 *
977 * @return The lowered string
978 */
979std::string lower_string(const std::string &val);
980
981/** Check if a given data type is of floating point type
982 *
983 * @param[in] dt Input data type.
984 *
985 * @return True if data type is of floating point type, else false.
986 */
987inline bool is_data_type_float(DataType dt)
988{
989 switch(dt)
990 {
991 case DataType::F16:
992 case DataType::F32:
993 return true;
994 default:
995 return false;
996 }
997}
998
Georgios Pinitas05078ec2017-11-02 13:06:59 +0000999/** Check if a given data type is of quantized type
1000 *
1001 * @note Quantized is considered a super-set of fixed-point and asymmetric data types.
1002 *
1003 * @param[in] dt Input data type.
1004 *
1005 * @return True if data type is of quantized type, else false.
1006 */
1007inline bool is_data_type_quantized(DataType dt)
1008{
1009 switch(dt)
1010 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001011 case DataType::QSYMM8:
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001012 case DataType::QASYMM8:
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001013 case DataType::QSYMM8_PER_CHANNEL:
Manuel Bottini3689fcd2019-06-14 17:18:12 +01001014 case DataType::QSYMM16:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001015 return true;
1016 default:
1017 return false;
1018 }
1019}
1020
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001021/** Check if a given data type is of asymmetric quantized type
1022 *
1023 * @param[in] dt Input data type.
1024 *
1025 * @return True if data type is of symmetric quantized type, else false.
1026 */
Anton Lokhmotovaf6204c2017-11-08 09:34:19 +00001027inline bool is_data_type_quantized_asymmetric(DataType dt)
Georgios Pinitas05078ec2017-11-02 13:06:59 +00001028{
1029 switch(dt)
1030 {
1031 case DataType::QASYMM8:
1032 return true;
1033 default:
1034 return false;
1035 }
1036}
1037
/** Create a string with the float in full precision.
 *
 * The value is printed with up to std::numeric_limits<float>::max_digits10 significant digits.
 * If the printed text is a floating-point literal (it contains a decimal point or an exponent),
 * an "f" suffix is appended so it reads as a single-precision literal. Text printed as plain
 * digits (e.g. "3") gets no suffix, because "3f" is not a valid literal.
 *
 * @param[in] val Floating point value
 *
 * @return String with the floating point value.
 */
inline std::string float_to_string_with_full_precision(float val)
{
    std::stringstream ss;
    ss.precision(std::numeric_limits<float>::max_digits10);
    ss << val;

    std::string str = ss.str();

    // Decide on the suffix from the printed text instead of comparing val with static_cast<int>(val):
    // that cast is undefined behaviour for values outside the int range (and for NaN/infinity).
    if(str.find_first_of(".eE") != std::string::npos)
    {
        str += "f";
    }

    return str;
}
1057
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001058/** Returns the number of elements required to go from start to end with the wanted step
1059 *
1060 * @param[in] start start value
1061 * @param[in] end end value
1062 * @param[in] step step value between each number in the wanted sequence
1063 *
1064 * @return number of elements to go from start value to end value using the wanted step
1065 */
1066inline size_t num_of_elements_in_range(const float start, const float end, const float step)
1067{
1068 ARM_COMPUTE_ERROR_ON_MSG(step == 0, "Range Step cannot be 0");
1069 return size_t(std::ceil((end - start) / step));
1070}
1071
1072/** Returns true if the value can be represented by the given data type
1073 *
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001074 * @param[in] val value to be checked
1075 * @param[in] dt data type that is checked
1076 * @param[in] qinfo (Optional) quantization info if the data type is QASYMM8
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001077 *
1078 * @return true if the data type can hold the value.
1079 */
1080template <typename T>
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001081bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = QuantizationInfo())
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001082{
1083 switch(dt)
1084 {
1085 case DataType::U8:
1086 return ((static_cast<uint8_t>(val) == val) && val >= std::numeric_limits<uint8_t>::lowest() && val <= std::numeric_limits<uint8_t>::max());
1087 case DataType::QASYMM8:
1088 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +01001089 double min = static_cast<double>(dequantize_qasymm8(0, qinfo));
1090 double max = static_cast<double>(dequantize_qasymm8(std::numeric_limits<uint8_t>::max(), qinfo));
Michalis Spyrouf63885b2019-01-16 14:18:09 +00001091 return ((double)val >= min && (double)val <= max);
1092 }
1093 case DataType::S8:
1094 return ((static_cast<int8_t>(val) == val) && val >= std::numeric_limits<int8_t>::lowest() && val <= std::numeric_limits<int8_t>::max());
1095 case DataType::U16:
1096 return ((static_cast<uint16_t>(val) == val) && val >= std::numeric_limits<uint16_t>::lowest() && val <= std::numeric_limits<uint16_t>::max());
1097 case DataType::S16:
1098 return ((static_cast<int16_t>(val) == val) && val >= std::numeric_limits<int16_t>::lowest() && val <= std::numeric_limits<int16_t>::max());
1099 case DataType::U32:
1100 return ((static_cast<uint32_t>(val) == val) && val >= std::numeric_limits<uint32_t>::lowest() && val <= std::numeric_limits<uint32_t>::max());
1101 case DataType::S32:
1102 return ((static_cast<int32_t>(val) == val) && val >= std::numeric_limits<int32_t>::lowest() && val <= std::numeric_limits<int32_t>::max());
1103 case DataType::U64:
1104 return (val >= std::numeric_limits<uint64_t>::lowest() && val <= std::numeric_limits<uint64_t>::max());
1105 case DataType::S64:
1106 return (val >= std::numeric_limits<int64_t>::lowest() && val <= std::numeric_limits<int64_t>::max());
1107 case DataType::F16:
1108 return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
1109 case DataType::F32:
1110 return (val >= std::numeric_limits<float>::lowest() && val <= std::numeric_limits<float>::max());
1111 case DataType::F64:
1112 return (val >= std::numeric_limits<double>::lowest() && val <= std::numeric_limits<double>::max());
1113 case DataType::SIZET:
1114 return ((static_cast<size_t>(val) == val) && val >= std::numeric_limits<size_t>::lowest() && val <= std::numeric_limits<size_t>::max());
1115 default:
1116 ARM_COMPUTE_ERROR("Data type not supported");
1117 return false;
1118 }
1119}
1120
giuros01edc21e42018-11-16 14:45:31 +00001121#ifdef ARM_COMPUTE_ASSERTS_ENABLED
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001122/** Print consecutive elements to an output stream.
1123 *
1124 * @param[out] s Output stream to print the elements to.
1125 * @param[in] ptr Pointer to print the elements from.
1126 * @param[in] n Number of elements to print.
1127 * @param[in] stream_width (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1128 * @param[in] element_delim (Optional) Delimeter among the consecutive elements. Defaults to space delimeter
1129 */
1130template <typename T>
1131void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int n, int stream_width = 0, const std::string &element_delim = " ")
1132{
1133 using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1134
1135 for(unsigned int i = 0; i < n; ++i)
1136 {
1137 // Set stream width as it is not a "sticky" stream manipulator
1138 if(stream_width != 0)
1139 {
1140 s.width(stream_width);
1141 }
Anthony Barbier7068f992017-10-26 15:23:08 +01001142
1143 if(std::is_same<typename std::decay<T>::type, half>::value)
1144 {
1145 // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
1146 s << std::right << static_cast<T>(ptr[i]) << element_delim;
1147 }
1148 else
1149 {
1150 s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
1151 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001152 }
1153}
1154
1155/** Identify the maximum width of n consecutive elements.
1156 *
1157 * @param[in] s The output stream which will be used to print the elements. Used to extract the stream format.
1158 * @param[in] ptr Pointer to the elements.
1159 * @param[in] n Number of elements.
1160 *
1161 * @return The maximum width of the elements.
1162 */
1163template <typename T>
1164int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, unsigned int n)
1165{
1166 using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1167
1168 int max_width = -1;
1169 for(unsigned int i = 0; i < n; ++i)
1170 {
1171 std::stringstream ss;
1172 ss.copyfmt(s);
Anthony Barbier7068f992017-10-26 15:23:08 +01001173
1174 if(std::is_same<typename std::decay<T>::type, half>::value)
1175 {
1176 // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
1177 ss << static_cast<T>(ptr[i]);
1178 }
1179 else
1180 {
1181 ss << static_cast<print_type>(ptr[i]);
1182 }
1183
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001184 max_width = std::max<int>(max_width, ss.str().size());
1185 }
1186 return max_width;
1187}
1188
1189/** Print consecutive elements to an output stream.
1190 *
1191 * @param[out] s Output stream to print the elements to.
1192 * @param[in] dt Data type of the elements
1193 * @param[in] ptr Pointer to print the elements from.
1194 * @param[in] n Number of elements to print.
1195 * @param[in] stream_width (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1196 * @param[in] element_delim (Optional) Delimeter among the consecutive elements. Defaults to space delimeter
1197 */
1198void print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim = " ");
1199
1200/** Identify the maximum width of n consecutive elements.
1201 *
1202 * @param[in] s Output stream to print the elements to.
1203 * @param[in] dt Data type of the elements
1204 * @param[in] ptr Pointer to print the elements from.
1205 * @param[in] n Number of elements to print.
1206 *
1207 * @return The maximum width of the elements.
1208 */
1209int max_consecutive_elements_display_width(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n);
giuros01edc21e42018-11-16 14:45:31 +00001210#endif /* ARM_COMPUTE_ASSERTS_ENABLED */
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001211}
1212#endif /*__ARM_COMPUTE_UTILS_H__ */