/*
 * Copyright (c) 2016, 2017, 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef __ARM_COMPUTE_UTILS_H__
#define __ARM_COMPUTE_UTILS_H__

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Rounding.h"
#include "arm_compute/core/Types.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace arm_compute
{
/** Calculate the rounded up quotient of val / m.
 *
 * @param[in] val Value to divide and round up.
 * @param[in] m   Value to divide by.
 *
 * @return the result.
 */
template <typename S, typename T>
constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
{
    return (val + m - 1) / m;
}
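
// Illustrative usage (a minimal sketch, not part of the library API). Since DIV_CEIL is
// constexpr, the rounding behaviour can be checked at compile time:
//
//   static_assert(DIV_CEIL(10, 3) == 4, "rounds up");
//   static_assert(DIV_CEIL(12, 3) == 4, "exact quotients are unchanged");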

/** Computes the smallest number greater than or equal to value that is a multiple of divisor.
 *
 * @param[in] value   Lower bound value
 * @param[in] divisor Value to compute multiple of.
 *
 * @return the result.
 */
template <typename S, typename T>
inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
    return DIV_CEIL(value, divisor) * divisor;
}

/** Computes the largest number smaller than or equal to value that is a multiple of divisor.
 *
 * @param[in] value   Upper bound value
 * @param[in] divisor Value to compute multiple of.
 *
 * @return the result.
 */
template <typename S, typename T>
inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
    return (value / divisor) * divisor;
}
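
// Illustrative examples (values assumed, not part of the library API):
//
//   ceil_to_multiple(13, 4);  // -> 16 : smallest multiple of 4 that is >= 13
//   floor_to_multiple(13, 4); // -> 12 : largest multiple of 4 that is <= 13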

/** Returns the arm_compute library build information
 *
 * Contains the version number and the build options used to build the library
 *
 * @return The arm_compute library build information
 */
std::string build_information();

/** Load an entire file in memory
 *
 * @param[in] filename Name of the file to read.
 * @param[in] binary   Is it a binary file?
 *
 * @return The content of the file.
 */
std::string read_file(const std::string &filename, bool binary);

/** The size in bytes of the data type
 *
 * @param[in] data_type Input data type
 *
 * @return The size in bytes of the data type
 */
inline size_t data_size_from_type(DataType data_type)
{
    switch(data_type)
    {
        case DataType::U8:
        case DataType::S8:
        case DataType::QS8:
        case DataType::QASYMM8:
            return 1;
        case DataType::U16:
        case DataType::S16:
        case DataType::F16:
        case DataType::QS16:
            return 2;
        case DataType::F32:
        case DataType::U32:
        case DataType::S32:
        case DataType::QS32:
            return 4;
        case DataType::F64:
        case DataType::U64:
        case DataType::S64:
            return 8;
        case DataType::SIZET:
            return sizeof(size_t);
        default:
            ARM_COMPUTE_ERROR("Invalid data type");
            return 0;
    }
}

/** The size in bytes of the pixel format
 *
 * @param[in] format Input format
 *
 * @return The size in bytes of the pixel format
 */
inline size_t pixel_size_from_format(Format format)
{
    switch(format)
    {
        case Format::U8:
            return 1;
        case Format::U16:
        case Format::S16:
        case Format::F16:
        case Format::UV88:
        case Format::YUYV422:
        case Format::UYVY422:
            return 2;
        case Format::RGB888:
            return 3;
        case Format::RGBA8888:
            return 4;
        case Format::U32:
        case Format::S32:
        case Format::F32:
            return 4;
        // Doesn't make sense for planar formats:
        case Format::NV12:
        case Format::NV21:
        case Format::IYUV:
        case Format::YUV444:
        default:
            ARM_COMPUTE_ERROR("Undefined pixel size for given format");
            return 0;
    }
}

/** The size in bytes of the data type
 *
 * @param[in] dt Input data type
 *
 * @return The size in bytes of the data type
 */
inline size_t element_size_from_data_type(DataType dt)
{
    switch(dt)
    {
        case DataType::S8:
        case DataType::U8:
        case DataType::QS8:
        case DataType::QASYMM8:
            return 1;
        case DataType::U16:
        case DataType::S16:
        case DataType::QS16:
        case DataType::F16:
            return 2;
        case DataType::U32:
        case DataType::S32:
        case DataType::F32:
        case DataType::QS32:
            return 4;
        default:
            ARM_COMPUTE_ERROR("Undefined element size for given data type");
            return 0;
    }
}

/** Return the data type used by a given single-planar pixel format
 *
 * @param[in] format Input format
 *
 * @return The data type used by the pixel format
 */
inline DataType data_type_from_format(Format format)
{
    switch(format)
    {
        case Format::U8:
        case Format::UV88:
        case Format::RGB888:
        case Format::RGBA8888:
        case Format::YUYV422:
        case Format::UYVY422:
            return DataType::U8;
        case Format::U16:
            return DataType::U16;
        case Format::S16:
            return DataType::S16;
        case Format::U32:
            return DataType::U32;
        case Format::S32:
            return DataType::S32;
        case Format::F16:
            return DataType::F16;
        case Format::F32:
            return DataType::F32;
        // Doesn't make sense for planar formats:
        case Format::NV12:
        case Format::NV21:
        case Format::IYUV:
        case Format::YUV444:
        default:
            ARM_COMPUTE_ERROR("Not supported data_type for given format");
            return DataType::UNKNOWN;
    }
}

/** Return the plane index of a given channel given an input format.
 *
 * @param[in] format  Input format
 * @param[in] channel Input channel
 *
 * @return The plane index of the specific channel of the specific format
 */
inline int plane_idx_from_channel(Format format, Channel channel)
{
    switch(format)
    {
        // Single planar formats have a single plane
        case Format::U8:
        case Format::U16:
        case Format::S16:
        case Format::U32:
        case Format::S32:
        case Format::F16:
        case Format::F32:
        case Format::UV88:
        case Format::RGB888:
        case Format::RGBA8888:
        case Format::YUYV422:
        case Format::UYVY422:
            return 0;
        // Multi planar formats
        case Format::NV12:
        case Format::NV21:
        {
            // Channel U and V share the same plane of format UV88
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                case Channel::V:
                    return 1;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::IYUV:
        case Format::YUV444:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 1;
                case Channel::V:
                    return 2;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        default:
            ARM_COMPUTE_ERROR("Not supported format");
            return 0;
    }
}
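
// Illustrative example (a sketch, not part of the library API): for the semi-planar
// NV12 format the interleaved U/V samples live in the second plane.
//
//   plane_idx_from_channel(Format::NV12, Channel::Y); // -> 0
//   plane_idx_from_channel(Format::NV12, Channel::V); // -> 1
//   plane_idx_from_channel(Format::IYUV, Channel::V); // -> 2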

/** Return the channel index of a given channel given an input format.
 *
 * @param[in] format  Input format
 * @param[in] channel Input channel
 *
 * @return The channel index of the specific channel of the specific format
 */
inline int channel_idx_from_format(Format format, Channel channel)
{
    switch(format)
    {
        case Format::RGB888:
        {
            switch(channel)
            {
                case Channel::R:
                    return 0;
                case Channel::G:
                    return 1;
                case Channel::B:
                    return 2;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::RGBA8888:
        {
            switch(channel)
            {
                case Channel::R:
                    return 0;
                case Channel::G:
                    return 1;
                case Channel::B:
                    return 2;
                case Channel::A:
                    return 3;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::YUYV422:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 1;
                case Channel::V:
                    return 3;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::UYVY422:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 1;
                case Channel::U:
                    return 0;
                case Channel::V:
                    return 2;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::NV12:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 0;
                case Channel::V:
                    return 1;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::NV21:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 1;
                case Channel::V:
                    return 0;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::YUV444:
        case Format::IYUV:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 0;
                case Channel::V:
                    return 0;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        default:
            ARM_COMPUTE_ERROR("Not supported format");
            return 0;
    }
}
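
// Illustrative example (a sketch, not part of the library API): the index is the
// channel's offset inside one pixel (or inside its plane for multi-planar formats).
//
//   channel_idx_from_format(Format::YUYV422, Channel::V); // -> 3 (layout Y0 U Y1 V)
//   channel_idx_from_format(Format::NV21, Channel::U);    // -> 1 (V is stored first)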

/** Return the number of planes for a given format
 *
 * @param[in] format Input format
 *
 * @return The number of planes for a given image format.
 */
inline size_t num_planes_from_format(Format format)
{
    switch(format)
    {
        case Format::U8:
        case Format::S16:
        case Format::U16:
        case Format::S32:
        case Format::U32:
        case Format::F16:
        case Format::F32:
        case Format::RGB888:
        case Format::RGBA8888:
        case Format::YUYV422:
        case Format::UYVY422:
            return 1;
        case Format::NV12:
        case Format::NV21:
            return 2;
        case Format::IYUV:
        case Format::YUV444:
            return 3;
        default:
            ARM_COMPUTE_ERROR("Not supported format");
            return 0;
    }
}

/** Return the number of channels for a given single-planar pixel format
 *
 * @param[in] format Input format
 *
 * @return The number of channels for a given image format.
 */
inline size_t num_channels_from_format(Format format)
{
    switch(format)
    {
        case Format::U8:
        case Format::U16:
        case Format::S16:
        case Format::U32:
        case Format::S32:
        case Format::F16:
        case Format::F32:
            return 1;
        // Because the U and V channels are subsampled,
        // these formats appear to have only 2 channels:
        case Format::YUYV422:
        case Format::UYVY422:
            return 2;
        case Format::UV88:
            return 2;
        case Format::RGB888:
            return 3;
        case Format::RGBA8888:
            return 4;
        // Doesn't make sense for planar formats:
        case Format::NV12:
        case Format::NV21:
        case Format::IYUV:
        case Format::YUV444:
        default:
            return 0;
    }
}

/** Return the promoted data type of a given data type.
 *
 * @note If the promoted data type is not supported an error will be thrown
 *
 * @param[in] dt Data type to get the promoted type of.
 *
 * @return Promoted data type
 */
inline DataType get_promoted_data_type(DataType dt)
{
    switch(dt)
    {
        case DataType::U8:
            return DataType::U16;
        case DataType::S8:
            return DataType::S16;
        case DataType::QS8:
            return DataType::QS16;
        case DataType::U16:
            return DataType::U32;
        case DataType::S16:
            return DataType::S32;
        case DataType::QS16:
            return DataType::QS32;
        case DataType::QASYMM8:
        case DataType::F16:
        case DataType::U32:
        case DataType::S32:
        case DataType::F32:
        case DataType::QS32:
            ARM_COMPUTE_ERROR("Unsupported data type promotions!");
        default:
            ARM_COMPUTE_ERROR("Undefined data type!");
    }
    return DataType::UNKNOWN;
}

/** Return true if the given format has horizontal subsampling.
 *
 * @param[in] format Format to determine subsampling.
 *
 * @return True if the format can be subsampled horizontally.
 */
inline bool has_format_horizontal_subsampling(Format format)
{
    return format == Format::YUYV422 || format == Format::UYVY422 || format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88;
}

/** Return true if the given format has vertical subsampling.
 *
 * @param[in] format Format to determine subsampling.
 *
 * @return True if the format can be subsampled vertically.
 */
inline bool has_format_vertical_subsampling(Format format)
{
    return format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88;
}

/** Separate a 2D convolution into two 1D convolutions
 *
 * @param[in]  conv     2D convolution
 * @param[out] conv_col 1D vertical convolution
 * @param[out] conv_row 1D horizontal convolution
 * @param[in]  size     Size of the 2D convolution
 *
 * @return true if the separation was successful
 */
inline bool separate_matrix(const int16_t *conv, int16_t *conv_col, int16_t *conv_row, uint8_t size)
{
    int32_t min_col     = -1;
    int16_t min_col_val = -1;

    for(int32_t i = 0; i < size; ++i)
    {
        if(conv[i] != 0 && (min_col < 0 || abs(min_col_val) > abs(conv[i])))
        {
            min_col     = i;
            min_col_val = conv[i];
        }
    }

    if(min_col < 0)
    {
        return false;
    }

    for(uint32_t j = 0; j < size; ++j)
    {
        conv_col[j] = conv[min_col + j * size];
    }

    for(uint32_t i = 0; i < size; i++)
    {
        if(static_cast<int>(i) == min_col)
        {
            conv_row[i] = 1;
        }
        else
        {
            int16_t coeff = conv[i] / conv[min_col];

            for(uint32_t j = 1; j < size; ++j)
            {
                if(conv[i + j * size] != (conv_col[j] * coeff))
                {
                    return false;
                }
            }

            conv_row[i] = coeff;
        }
    }

    return true;
}
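
// Illustrative example (a sketch, not part of the library API): a separable 3x3
// Gaussian kernel splits into identical column and row vectors.
//
//   const int16_t gauss3x3[9] = { 1, 2, 1,
//                                 2, 4, 2,
//                                 1, 2, 1 };
//   int16_t col[3] = {};
//   int16_t row[3] = {};
//   separate_matrix(gauss3x3, col, row, 3); // -> true, col = row = { 1, 2, 1 }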

/** Calculate the scale of the given square matrix
 *
 * The scale is the absolute value of the sum of all the coefficients in the matrix.
 *
 * @note If the coefficients add up to 0 then the scale is set to 1.
 *
 * @param[in] matrix      Matrix coefficients
 * @param[in] matrix_size Number of elements per side of the square matrix. (Number of coefficients = matrix_size * matrix_size).
 *
 * @return The absolute value of the sum of the coefficients if they don't add up to 0, otherwise 1.
 */
inline uint32_t calculate_matrix_scale(const int16_t *matrix, unsigned int matrix_size)
{
    const size_t size = matrix_size * matrix_size;

    return std::max(1, std::abs(std::accumulate(matrix, matrix + size, 0)));
}
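
// Illustrative example (values assumed): for the 3x3 Gaussian kernel shown above the
// coefficients sum to 16, so the scale used to normalise the convolution is 16.
//
//   calculate_matrix_scale(gauss3x3, 3); // -> 16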

/** Calculate the output shape of the depth concatenate function.
 *
 * @param[in] inputs_vector The vector that stores all the pointers to the input tensors.
 *
 * @return the output shape
 */
template <typename T>
TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inputs_vector)
{
    TensorShape out_shape = inputs_vector[0]->info()->tensor_shape();

    size_t max_x = 0;
    size_t max_y = 0;
    size_t depth = 0;

    for(const auto &tensor : inputs_vector)
    {
        ARM_COMPUTE_ERROR_ON(tensor == nullptr);
        const TensorShape shape = tensor->info()->tensor_shape();
        max_x                   = std::max(shape.x(), max_x);
        max_y                   = std::max(shape.y(), max_y);
        depth += shape.z();
    }

    out_shape.set(0, max_x);
    out_shape.set(1, max_y);
    out_shape.set(2, depth);

    return out_shape;
}

/** Adjust tensor shape size if width or height are odd for a given multi-planar format. No modification is done for other formats.
 *
 * @note A few links discussing the issue of odd sizes and sharing the same solution:
 *       Android Source: https://android.googlesource.com/platform/frameworks/base/+/refs/heads/master/graphics/java/android/graphics/YuvImage.java
 *       WebM: https://groups.google.com/a/webmproject.org/forum/#!topic/webm-discuss/LaCKpqiDTXM
 *       libYUV: https://bugs.chromium.org/p/libyuv/issues/detail?id=198&can=1&q=odd%20width
 *       YUVPlayer: https://sourceforge.net/p/raw-yuvplayer/bugs/1/
 *
 * @param[in] shape  Tensor shape of 2D size
 * @param[in] format Format of the tensor
 *
 * @return The adjusted tensor shape.
 */
inline TensorShape adjust_odd_shape(const TensorShape &shape, Format format)
{
    TensorShape output{ shape };

    // Force width to be even for formats which require subsampling of the U and V channels
    if(has_format_horizontal_subsampling(format))
    {
        output.set(0, output.x() & ~1U);
    }

    // Force height to be even for formats which require subsampling of the U and V channels
    if(has_format_vertical_subsampling(format))
    {
        output.set(1, output.y() & ~1U);
    }

    return output;
}
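
// Illustrative example (values assumed, not part of the library API): NV12 requires
// even dimensions, so odd sizes are rounded down.
//
//   TensorShape odd_shape(321U, 239U);
//   adjust_odd_shape(odd_shape, Format::NV12);   // -> (320, 238)
//   adjust_odd_shape(odd_shape, Format::RGB888); // -> (321, 239), unchanged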

/** Calculate subsampled shape for a given format and channel
 *
 * @param[in] shape   Shape of the tensor to calculate the extracted channel.
 * @param[in] format  Format of the tensor.
 * @param[in] channel Channel to create tensor shape to be extracted.
 *
 * @return The subsampled tensor shape.
 */
inline TensorShape calculate_subsampled_shape(const TensorShape &shape, Format format, Channel channel = Channel::UNKNOWN)
{
    TensorShape output{ shape };

    // Subsample shape only for U or V channel
    if(Channel::U == channel || Channel::V == channel || Channel::UNKNOWN == channel)
    {
        // Subsample width for the tensor shape when channel is U or V
        if(has_format_horizontal_subsampling(format))
        {
            output.set(0, output.x() / 2U);
        }

        // Subsample height for the tensor shape when channel is U or V
        if(has_format_vertical_subsampling(format))
        {
            output.set(1, output.y() / 2U);
        }
    }

    return output;
}
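
// Illustrative example (values assumed): the U and V planes of NV12 are subsampled by
// two in both dimensions, while the Y plane keeps the full resolution.
//
//   TensorShape full(320U, 240U);
//   calculate_subsampled_shape(full, Format::NV12, Channel::U); // -> (160, 120)
//   calculate_subsampled_shape(full, Format::NV12, Channel::Y); // -> (320, 240)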

/** Calculate the accuracy required by the horizontal and vertical convolution computations
 *
 * @param[in] conv_col Pointer to the vertical vector of the separated convolution filter
 * @param[in] conv_row Pointer to the horizontal vector of the convolution filter
 * @param[in] size     Number of elements per vector of the separated matrix
 *
 * @return The return type is a pair. The first element of the pair is the biggest data type needed for the first stage. The second
 *         element of the pair is the biggest data type needed for the second stage.
 */
inline std::pair<DataType, DataType> data_type_for_convolution(const int16_t *conv_col, const int16_t *conv_row, size_t size)
{
    DataType first_stage  = DataType::UNKNOWN;
    DataType second_stage = DataType::UNKNOWN;

    auto gez = [](const int16_t &v)
    {
        return v >= 0;
    };

    auto accu_neg = [](const int &first, const int &second)
    {
        return first + (second < 0 ? second : 0);
    };

    auto accu_pos = [](const int &first, const int &second)
    {
        return first + (second > 0 ? second : 0);
    };

    const bool only_positive_coefficients = std::all_of(conv_row, conv_row + size, gez) && std::all_of(conv_col, conv_col + size, gez);

    if(only_positive_coefficients)
    {
        const int max_row_value = std::accumulate(conv_row, conv_row + size, 0) * UINT8_MAX;
        const int max_value     = std::accumulate(conv_col, conv_col + size, 0) * max_row_value;

        first_stage = (max_row_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;

        second_stage = (max_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
    }
    else
    {
        const int min_row_value  = std::accumulate(conv_row, conv_row + size, 0, accu_neg) * UINT8_MAX;
        const int max_row_value  = std::accumulate(conv_row, conv_row + size, 0, accu_pos) * UINT8_MAX;
        const int neg_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_neg);
        const int pos_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_pos);
        const int min_value      = neg_coeffs_sum * max_row_value + pos_coeffs_sum * min_row_value;
        const int max_value      = neg_coeffs_sum * min_row_value + pos_coeffs_sum * max_row_value;

        first_stage = ((INT16_MIN <= min_row_value) && (max_row_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;

        second_stage = ((INT16_MIN <= min_value) && (max_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
    }

    return std::make_pair(first_stage, second_stage);
}

/** Calculate the accuracy required by the square convolution matrix calculation.
 *
 * @param[in] conv Pointer to the square convolution matrix
 * @param[in] size The total size of the convolution matrix
 *
 * @return The biggest data type needed to do the convolution
 */
inline DataType data_type_for_convolution_matrix(const int16_t *conv, size_t size)
{
    auto gez = [](const int16_t v)
    {
        return v >= 0;
    };

    const bool only_positive_coefficients = std::all_of(conv, conv + size, gez);

    if(only_positive_coefficients)
    {
        const int max_conv_value = std::accumulate(conv, conv + size, 0) * UINT8_MAX;
        if(max_conv_value <= UINT16_MAX)
        {
            return DataType::U16;
        }
        else
        {
            return DataType::S32;
        }
    }
    else
    {
        const int min_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
        {
            return b < 0 ? a + b : a;
        })
        * UINT8_MAX;

        const int max_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
        {
            return b > 0 ? a + b : a;
        })
        * UINT8_MAX;

        if((INT16_MIN <= min_value) && (INT16_MAX >= max_value))
        {
            return DataType::S16;
        }
        else
        {
            return DataType::S32;
        }
    }
}
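
// Illustrative example (values assumed): for an all-positive 5x5 box filter the
// worst-case intermediate value is 25 * 255 = 6375, which fits in U16.
//
//   const int16_t box5x5[25] = { 1, 1, 1, 1, 1,   1, 1, 1, 1, 1,   1, 1, 1, 1, 1,
//                                1, 1, 1, 1, 1,   1, 1, 1, 1, 1 };
//   data_type_for_convolution_matrix(box5x5, 25); // -> DataType::U16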

/** Calculate padding requirements in case of SAME padding
 *
 * @param[in] input_shape   Input shape
 * @param[in] weights_shape Weights shape
 * @param[in] conv_info     Convolution information (containing strides)
 *
 * @return PadStrideInfo for SAME padding
 */
PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info);

/** Returns expected shape for the deconvolution output tensor.
 *
 * @param[in] out_dims Width and height of the output tensor; these values can be obtained with the function deconvolution_output_dimensions.
 * @param[in] input    Shape of the input tensor.
 * @param[in] weights  Shape of the weights tensor.
 *
 * @return Deconvolution output tensor shape.
 */
TensorShape deconvolution_output_shape(const std::pair<unsigned int, unsigned int> &out_dims, TensorShape input, TensorShape weights);

/** Returns expected width and height of the deconvolution's output tensor.
 *
 * @param[in] in_width           Width of input tensor (Number of columns)
 * @param[in] in_height          Height of input tensor (Number of rows)
 * @param[in] kernel_width       Kernel width.
 * @param[in] kernel_height      Kernel height.
 * @param[in] padx               X axis padding.
 * @param[in] pady               Y axis padding.
 * @param[in] inner_border_right The number of zeros added to right edge of the input.
 * @param[in] inner_border_top   The number of zeros added to top edge of the input.
 * @param[in] stride_x           X axis input stride.
 * @param[in] stride_y           Y axis input stride.
 *
 * @return A pair with the new width in the first position and the new height in the second.
 */
const std::pair<unsigned int, unsigned int> deconvolution_output_dimensions(unsigned int in_width, unsigned int in_height,
                                                                            unsigned int kernel_width, unsigned int kernel_height,
                                                                            unsigned int padx, unsigned int pady, unsigned int inner_border_right, unsigned int inner_border_top,
                                                                            unsigned int stride_x, unsigned int stride_y);

/** Returns expected width and height of output scaled tensor depending on dimensions rounding mode.
 *
 * @param[in] width           Width of input tensor (Number of columns)
 * @param[in] height          Height of input tensor (Number of rows)
 * @param[in] kernel_width    Kernel width.
 * @param[in] kernel_height   Kernel height.
 * @param[in] pad_stride_info Pad and stride information.
 * @param[in] dilation        (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
 *
 * @return A pair with the new width in the first position and the new height in the second.
 */
const std::pair<unsigned int, unsigned int> scaled_dimensions(unsigned int width, unsigned int height,
                                                              unsigned int kernel_width, unsigned int kernel_height,
                                                              const PadStrideInfo &pad_stride_info,
                                                              const Size2D &dilation = Size2D(1U, 1U));
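
// Illustrative example (a sketch; the exact result also depends on the rounding mode
// carried by PadStrideInfo). A 3x3 kernel with stride 1 and padding 1 preserves the
// spatial dimensions:
//
//   scaled_dimensions(224U, 224U, 3U, 3U, PadStrideInfo(1, 1, 1, 1)); // -> { 224, 224 }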

/** Convert a tensor format into a string.
 *
 * @param[in] format @ref Format to be translated to string.
 *
 * @return The string describing the format.
 */
const std::string &string_from_format(Format format);

/** Convert a channel identity into a string.
 *
 * @param[in] channel @ref Channel to be translated to string.
 *
 * @return The string describing the channel.
 */
const std::string &string_from_channel(Channel channel);
/** Convert a data layout identity into a string.
 *
 * @param[in] dl @ref DataLayout to be translated to string.
 *
 * @return The string describing the data layout.
 */
const std::string &string_from_data_layout(DataLayout dl);
/** Convert a data type identity into a string.
 *
 * @param[in] dt @ref DataType to be translated to string.
 *
 * @return The string describing the data type.
 */
const std::string &string_from_data_type(DataType dt);
/** Convert a matrix pattern into a string.
 *
 * @param[in] pattern @ref MatrixPattern to be translated to string.
 *
 * @return The string describing the matrix pattern.
 */
const std::string &string_from_matrix_pattern(MatrixPattern pattern);
/** Translates a given activation function to a string.
 *
 * @param[in] act @ref ActivationLayerInfo::ActivationFunction to be translated to string.
 *
 * @return The string describing the activation function.
 */
const std::string &string_from_activation_func(ActivationLayerInfo::ActivationFunction act);
/** Translates a given non linear function to a string.
 *
 * @param[in] function @ref NonLinearFilterFunction to be translated to string.
 *
 * @return The string describing the non linear function.
 */
const std::string &string_from_non_linear_filter_function(NonLinearFilterFunction function);
/** Translates a given interpolation policy to a string.
 *
 * @param[in] policy @ref InterpolationPolicy to be translated to string.
 *
 * @return The string describing the interpolation policy.
 */
const std::string &string_from_interpolation_policy(InterpolationPolicy policy);
/** Translates a given border mode policy to a string.
 *
 * @param[in] border_mode @ref BorderMode to be translated to string.
 *
 * @return The string describing the border mode.
 */
const std::string &string_from_border_mode(BorderMode border_mode);
/** Translates a given normalization type to a string.
 *
 * @param[in] type @ref NormType to be translated to string.
 *
 * @return The string describing the normalization type.
 */
const std::string &string_from_norm_type(NormType type);
/** Translates a given pooling type to a string.
 *
 * @param[in] type @ref PoolingType to be translated to string.
 *
 * @return The string describing the pooling type.
 */
const std::string &string_from_pooling_type(PoolingType type);
/** Lower a given string.
 *
 * @param[in] val Given string to lower.
 *
 * @return The lowered string
 */
std::string lower_string(const std::string &val);

/** Check if a given data type is of floating point type
 *
 * @param[in] dt Input data type.
 *
 * @return True if data type is of floating point type, else false.
 */
inline bool is_data_type_float(DataType dt)
{
    switch(dt)
    {
        case DataType::F16:
        case DataType::F32:
            return true;
        default:
            return false;
    }
}

/** Check if a given data type is of quantized type
 *
 * @note Quantized is considered a super-set of fixed-point and asymmetric data types.
 *
 * @param[in] dt Input data type.
 *
 * @return True if data type is of quantized type, else false.
 */
inline bool is_data_type_quantized(DataType dt)
{
    switch(dt)
    {
        case DataType::QS8:
        case DataType::QASYMM8:
        case DataType::QS16:
        case DataType::QS32:
            return true;
        default:
            return false;
    }
}

/** Check if a given data type is of fixed point type
 *
 * @param[in] dt Input data type.
 *
 * @return True if data type is of fixed point type, else false.
 */
inline bool is_data_type_fixed_point(DataType dt)
{
    switch(dt)
    {
        case DataType::QS8:
        case DataType::QS16:
        case DataType::QS32:
            return true;
        default:
            return false;
    }
}

/** Check if a given data type is of asymmetric quantized type
 *
 * @param[in] dt Input data type.
 *
 * @return True if data type is of asymmetric quantized type, else false.
 */
inline bool is_data_type_quantized_asymmetric(DataType dt)
{
    switch(dt)
    {
        case DataType::QASYMM8:
            return true;
        default:
            return false;
    }
}

/** Create a string with the float in full precision.
 *
 * @param val Floating point value
 *
 * @return String with the floating point value.
 */
inline std::string float_to_string_with_full_precision(float val)
{
    std::stringstream ss;
    ss.precision(std::numeric_limits<float>::digits10 + 1);
    ss << val;
    return ss.str();
}
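
// Illustrative example (a sketch): with digits10 + 1 significant digits the value
// round-trips without losing float precision.
//
//   float_to_string_with_full_precision(1.0f / 3.0f); // -> "0.3333333"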

/** Print consecutive elements to an output stream.
 *
 * @param[out] s             Output stream to print the elements to.
 * @param[in]  ptr           Pointer to print the elements from.
 * @param[in]  n             Number of elements to print.
 * @param[in]  stream_width  (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
 * @param[in]  element_delim (Optional) Delimiter among the consecutive elements. Defaults to space delimiter
 */
template <typename T>
void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int n, int stream_width = 0, const std::string &element_delim = " ")
{
    using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;

    for(unsigned int i = 0; i < n; ++i)
    {
        // Set stream width as it is not a "sticky" stream manipulator
        if(stream_width != 0)
        {
            s.width(stream_width);
        }

        if(std::is_same<typename std::decay<T>::type, half>::value)
        {
            // We use T instead of print_type here because std::is_floating_point<half> returns false and print_type would then become int.
            s << std::right << static_cast<T>(ptr[i]) << element_delim;
        }
        else
        {
            s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
        }
    }
}
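
// Illustrative example (a sketch, assuming a small local array):
//
//   const float data[3] = { 1.f, 2.5f, 3.f };
//   print_consecutive_elements_impl(std::cout, data, 3, 6, ", ");
//   // prints each value right-aligned in a field of width 6, separated by ", "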

/** Identify the maximum width of n consecutive elements.
 *
 * @param[in] s   The output stream which will be used to print the elements. Used to extract the stream format.
 * @param[in] ptr Pointer to the elements.
 * @param[in] n   Number of elements.
 *
 * @return The maximum width of the elements.
 */
template <typename T>
int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, unsigned int n)
{
    using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;

    int max_width = -1;
    for(unsigned int i = 0; i < n; ++i)
    {
        std::stringstream ss;
        ss.copyfmt(s);

        if(std::is_same<typename std::decay<T>::type, half>::value)
        {
            // We use T instead of print_type here because std::is_floating_point<half> returns false and print_type would then become int.
            ss << static_cast<T>(ptr[i]);
        }
        else
        {
            ss << static_cast<print_type>(ptr[i]);
        }

        max_width = std::max<int>(max_width, ss.str().size());
    }
    return max_width;
}

/** Print consecutive elements to an output stream.
 *
 * @param[out] s             Output stream to print the elements to.
 * @param[in]  dt            Data type of the elements
 * @param[in]  ptr           Pointer to print the elements from.
 * @param[in]  n             Number of elements to print.
 * @param[in]  stream_width  (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
 * @param[in]  element_delim (Optional) Delimiter among the consecutive elements. Defaults to space delimiter
 */
void print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim = " ");

/** Identify the maximum width of n consecutive elements.
 *
 * @param[in] s   Output stream used to print the elements. Used to extract the stream format.
 * @param[in] dt  Data type of the elements
 * @param[in] ptr Pointer to print the elements from.
 * @param[in] n   Number of elements to print.
 *
 * @return The maximum width of the elements.
 */
int max_consecutive_elements_display_width(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n);
} // namespace arm_compute
#endif /* __ARM_COMPUTE_UTILS_H__ */