/*
 * Copyright (c) 2016, 2017, 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef __ARM_COMPUTE_UTILS_H__
#define __ARM_COMPUTE_UTILS_H__

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Rounding.h"
#include "arm_compute/core/Types.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace arm_compute
{
/** Calculate the rounded up quotient of val / m. */
template <typename S, typename T>
constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
{
    return (val + m - 1) / m;
}

/** Computes the smallest number greater than or equal to value that is a multiple of divisor. */
template <typename S, typename T>
inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
    return DIV_CEIL(value, divisor) * divisor;
}

/** Computes the largest number less than or equal to value that is a multiple of divisor. */
template <typename S, typename T>
inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
    return (value / divisor) * divisor;
}
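
// Illustrative sketch: these helpers round an integer quantity up or down to an
// alignment boundary, e.g.
//   DIV_CEIL(10, 3)          == 4
//   ceil_to_multiple(10, 4)  == 12
//   floor_to_multiple(10, 4) == 8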

/** Returns the arm_compute library build information
 *
 * Contains the version number and the build options used to build the library
 *
 * @return The arm_compute library build information
 */
std::string build_information();

/** Load an entire file into memory
 *
 * @param[in] filename Name of the file to read.
 * @param[in] binary   True if the file is binary, false if it is a text file.
 *
 * @return The content of the file.
 */
std::string read_file(const std::string &filename, bool binary);

/** The size in bytes of the data type
 *
 * @param[in] data_type Input data type
 *
 * @return The size in bytes of the data type
 */
inline size_t data_size_from_type(DataType data_type)
{
    switch(data_type)
    {
        case DataType::U8:
        case DataType::S8:
        case DataType::QS8:
        case DataType::QASYMM8:
            return 1;
        case DataType::U16:
        case DataType::S16:
        case DataType::F16:
        case DataType::QS16:
            return 2;
        case DataType::F32:
        case DataType::U32:
        case DataType::S32:
        case DataType::QS32:
            return 4;
        case DataType::F64:
        case DataType::U64:
        case DataType::S64:
            return 8;
        case DataType::SIZET:
            return sizeof(size_t);
        default:
            ARM_COMPUTE_ERROR("Invalid data type");
            return 0;
    }
}
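
// Illustrative usage (values follow directly from the switch above):
//   data_size_from_type(DataType::U8)  == 1
//   data_size_from_type(DataType::F16) == 2
//   data_size_from_type(DataType::F32) == 4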

/** The size in bytes of the pixel format
 *
 * @param[in] format Input format
 *
 * @return The size in bytes of the pixel format
 */
inline size_t pixel_size_from_format(Format format)
{
    switch(format)
    {
        case Format::U8:
            return 1;
        case Format::U16:
        case Format::S16:
        case Format::F16:
        case Format::UV88:
        case Format::YUYV422:
        case Format::UYVY422:
            return 2;
        case Format::RGB888:
            return 3;
        case Format::RGBA8888:
            return 4;
        case Format::U32:
        case Format::S32:
        case Format::F32:
            return 4;
        // Doesn't make sense for planar formats:
        case Format::NV12:
        case Format::NV21:
        case Format::IYUV:
        case Format::YUV444:
        default:
            ARM_COMPUTE_ERROR("Undefined pixel size for given format");
            return 0;
    }
}

/** The size in bytes of the data type
 *
 * @param[in] dt Input data type
 *
 * @return The size in bytes of the data type
 */
inline size_t element_size_from_data_type(DataType dt)
{
    switch(dt)
    {
        case DataType::S8:
        case DataType::U8:
        case DataType::QS8:
        case DataType::QASYMM8:
            return 1;
        case DataType::U16:
        case DataType::S16:
        case DataType::QS16:
        case DataType::F16:
            return 2;
        case DataType::U32:
        case DataType::S32:
        case DataType::F32:
        case DataType::QS32:
            return 4;
        default:
            ARM_COMPUTE_ERROR("Undefined element size for given data type");
            return 0;
    }
}

/** Return the data type used by a given single-planar pixel format
 *
 * @param[in] format Input format
 *
 * @return The data type used by the given single-planar pixel format
 */
inline DataType data_type_from_format(Format format)
{
    switch(format)
    {
        case Format::U8:
        case Format::UV88:
        case Format::RGB888:
        case Format::RGBA8888:
        case Format::YUYV422:
        case Format::UYVY422:
            return DataType::U8;
        case Format::U16:
            return DataType::U16;
        case Format::S16:
            return DataType::S16;
        case Format::U32:
            return DataType::U32;
        case Format::S32:
            return DataType::S32;
        case Format::F16:
            return DataType::F16;
        case Format::F32:
            return DataType::F32;
        // Doesn't make sense for planar formats:
        case Format::NV12:
        case Format::NV21:
        case Format::IYUV:
        case Format::YUV444:
        default:
            ARM_COMPUTE_ERROR("Not supported data_type for given format");
            return DataType::UNKNOWN;
    }
}

/** Return the plane index of a given channel given an input format.
 *
 * @param[in] format  Input format
 * @param[in] channel Input channel
 *
 * @return The plane index of the specific channel of the specific format
 */
inline int plane_idx_from_channel(Format format, Channel channel)
{
    switch(format)
    {
        // Single planar formats have a single plane
        case Format::U8:
        case Format::U16:
        case Format::S16:
        case Format::U32:
        case Format::S32:
        case Format::F16:
        case Format::F32:
        case Format::UV88:
        case Format::RGB888:
        case Format::RGBA8888:
        case Format::YUYV422:
        case Format::UYVY422:
            return 0;
        // Multi planar formats
        case Format::NV12:
        case Format::NV21:
        {
            // Channel U and V share the same plane of format UV88
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                case Channel::V:
                    return 1;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::IYUV:
        case Format::YUV444:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 1;
                case Channel::V:
                    return 2;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        default:
            ARM_COMPUTE_ERROR("Not supported format");
            return 0;
    }
}

/** Return the channel index of a given channel given an input format.
 *
 * @param[in] format  Input format
 * @param[in] channel Input channel
 *
 * @return The channel index of the specific channel of the specific format
 */
inline int channel_idx_from_format(Format format, Channel channel)
{
    switch(format)
    {
        case Format::RGB888:
        {
            switch(channel)
            {
                case Channel::R:
                    return 0;
                case Channel::G:
                    return 1;
                case Channel::B:
                    return 2;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::RGBA8888:
        {
            switch(channel)
            {
                case Channel::R:
                    return 0;
                case Channel::G:
                    return 1;
                case Channel::B:
                    return 2;
                case Channel::A:
                    return 3;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::YUYV422:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 1;
                case Channel::V:
                    return 3;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::UYVY422:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 1;
                case Channel::U:
                    return 0;
                case Channel::V:
                    return 2;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::NV12:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 0;
                case Channel::V:
                    return 1;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::NV21:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 1;
                case Channel::V:
                    return 0;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        case Format::YUV444:
        case Format::IYUV:
        {
            switch(channel)
            {
                case Channel::Y:
                    return 0;
                case Channel::U:
                    return 0;
                case Channel::V:
                    return 0;
                default:
                    ARM_COMPUTE_ERROR("Not supported channel");
                    return 0;
            }
        }
        default:
            ARM_COMPUTE_ERROR("Not supported format");
            return 0;
    }
}

/** Return the number of planes for a given format
 *
 * @param[in] format Input format
 *
 * @return The number of planes for a given image format.
 */
inline size_t num_planes_from_format(Format format)
{
    switch(format)
    {
        case Format::U8:
        case Format::S16:
        case Format::U16:
        case Format::S32:
        case Format::U32:
        case Format::F16:
        case Format::F32:
        case Format::RGB888:
        case Format::RGBA8888:
        case Format::YUYV422:
        case Format::UYVY422:
            return 1;
        case Format::NV12:
        case Format::NV21:
            return 2;
        case Format::IYUV:
        case Format::YUV444:
            return 3;
        default:
            ARM_COMPUTE_ERROR("Not supported format");
            return 0;
    }
}
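
// Illustrative usage of the multi-planar helpers above, assuming an NV12 image
// (plane 0 holds Y, plane 1 holds the interleaved UV pairs):
//   num_planes_from_format(Format::NV12)              == 2
//   plane_idx_from_channel(Format::NV12, Channel::V)  == 1
//   channel_idx_from_format(Format::NV12, Channel::V) == 1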

/** Return the number of channels for a given single-planar pixel format
 *
 * @param[in] format Input format
 *
 * @return The number of channels for a given image format.
 */
inline size_t num_channels_from_format(Format format)
{
    switch(format)
    {
        case Format::U8:
        case Format::U16:
        case Format::S16:
        case Format::U32:
        case Format::S32:
        case Format::F16:
        case Format::F32:
            return 1;
        // Because the U and V channels are subsampled
        // these formats appear to have only 2 channels:
        case Format::YUYV422:
        case Format::UYVY422:
            return 2;
        case Format::UV88:
            return 2;
        case Format::RGB888:
            return 3;
        case Format::RGBA8888:
            return 4;
        // Doesn't make sense for planar formats:
        case Format::NV12:
        case Format::NV21:
        case Format::IYUV:
        case Format::YUV444:
        default:
            return 0;
    }
}

/** Return the promoted data type of a given data type.
 *
 * @note If the promoted data type is not supported, an error is raised.
 *
 * @param[in] dt Data type to get the promoted type of.
 *
 * @return Promoted data type
 */
inline DataType get_promoted_data_type(DataType dt)
{
    switch(dt)
    {
        case DataType::U8:
            return DataType::U16;
        case DataType::S8:
            return DataType::S16;
        case DataType::QS8:
            return DataType::QS16;
        case DataType::U16:
            return DataType::U32;
        case DataType::S16:
            return DataType::S32;
        case DataType::QS16:
            return DataType::QS32;
        case DataType::QASYMM8:
        case DataType::F16:
        case DataType::U32:
        case DataType::S32:
        case DataType::F32:
        case DataType::QS32:
            ARM_COMPUTE_ERROR("Unsupported data type promotions!");
        default:
            ARM_COMPUTE_ERROR("Undefined data type!");
    }
    return DataType::UNKNOWN;
}
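
// Illustrative usage:
//   get_promoted_data_type(DataType::U8)  == DataType::U16
//   get_promoted_data_type(DataType::S16) == DataType::S32
// Types without a wider counterpart (e.g. F32) trigger ARM_COMPUTE_ERROR instead of returning.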

/** Return true if the given format has horizontal subsampling.
 *
 * @param[in] format Format to determine subsampling.
 *
 * @return True if the format can be subsampled horizontally.
 */
inline bool has_format_horizontal_subsampling(Format format)
{
    return (format == Format::YUYV422 || format == Format::UYVY422 || format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88);
}

/** Return true if the given format has vertical subsampling.
 *
 * @param[in] format Format to determine subsampling.
 *
 * @return True if the format can be subsampled vertically.
 */
inline bool has_format_vertical_subsampling(Format format)
{
    return (format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88);
}

/** Separate a 2D convolution into two 1D convolutions
 *
 * @param[in]  conv     2D convolution
 * @param[out] conv_col 1D vertical convolution
 * @param[out] conv_row 1D horizontal convolution
 * @param[in]  size     Size of the 2D convolution
 *
 * @return true if the separation was successful
 */
inline bool separate_matrix(const int16_t *conv, int16_t *conv_col, int16_t *conv_row, uint8_t size)
{
    int32_t min_col     = -1;
    int16_t min_col_val = -1;

    for(int32_t i = 0; i < size; ++i)
    {
        if(conv[i] != 0 && (min_col < 0 || abs(min_col_val) > abs(conv[i])))
        {
            min_col     = i;
            min_col_val = conv[i];
        }
    }

    if(min_col < 0)
    {
        return false;
    }

    for(uint32_t j = 0; j < size; ++j)
    {
        conv_col[j] = conv[min_col + j * size];
    }

    for(uint32_t i = 0; i < size; i++)
    {
        if(static_cast<int>(i) == min_col)
        {
            conv_row[i] = 1;
        }
        else
        {
            int16_t coeff = conv[i] / conv[min_col];

            for(uint32_t j = 1; j < size; ++j)
            {
                if(conv[i + j * size] != (conv_col[j] * coeff))
                {
                    return false;
                }
            }

            conv_row[i] = coeff;
        }
    }

    return true;
}
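
// Illustrative sketch: a separable 3x3 kernel such as
//   { 1, 2, 1,
//     2, 4, 2,
//     1, 2, 1 }
// is split into conv_col = { 1, 2, 1 } and conv_row = { 1, 2, 1 }, while a kernel whose rows
// are not all multiples of a single column vector makes separate_matrix() return false.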

/** Calculate the scale of the given square matrix
 *
 * The scale is the absolute value of the sum of all the coefficients in the matrix.
 *
 * @note If the coefficients add up to 0 then the scale is set to 1.
 *
 * @param[in] matrix      Matrix coefficients
 * @param[in] matrix_size Number of elements per side of the square matrix. (Number of coefficients = matrix_size * matrix_size).
 *
 * @return The absolute value of the sum of the coefficients if they don't add up to 0, otherwise 1.
 */
inline uint32_t calculate_matrix_scale(const int16_t *matrix, unsigned int matrix_size)
{
    const size_t size = matrix_size * matrix_size;

    return std::max(1, std::abs(std::accumulate(matrix, matrix + size, 0)));
}
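
// Illustrative usage: for the separable kernel { 1, 2, 1, 2, 4, 2, 1, 2, 1 } the coefficients
// sum to 16, so calculate_matrix_scale(matrix, 3) == 16; for a Sobel-like kernel whose
// coefficients cancel out, the scale falls back to 1.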

/** Calculate the output shape of the depth concatenate function.
 *
 * @param[in] inputs_vector The vector that stores all the pointers to input.
 *
 * @return the output shape
 */
template <typename T>
TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inputs_vector)
{
    TensorShape out_shape = inputs_vector[0]->info()->tensor_shape();

    size_t max_x = 0;
    size_t max_y = 0;
    size_t depth = 0;

    for(const auto &tensor : inputs_vector)
    {
        ARM_COMPUTE_ERROR_ON(tensor == nullptr);
        const TensorShape shape = tensor->info()->tensor_shape();
        max_x                   = std::max(shape.x(), max_x);
        max_y                   = std::max(shape.y(), max_y);
        depth += shape.z();
    }

    out_shape.set(0, max_x);
    out_shape.set(1, max_y);
    out_shape.set(2, depth);

    return out_shape;
}
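
// Illustrative sketch, assuming two hypothetical input tensors of shape (32, 32, 3) and
// (30, 30, 5): the concatenated output shape is (32, 32, 8), i.e. the maximum spatial
// extents and the summed depth.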

/** Adjust tensor shape size if width or height are odd for a given multi-planar format. No modification is done for other formats.
 *
 * @note Adding here a few links discussing the issue of odd size and sharing the same solution:
 *       Android Source: https://android.googlesource.com/platform/frameworks/base/+/refs/heads/master/graphics/java/android/graphics/YuvImage.java
 *       WebM: https://groups.google.com/a/webmproject.org/forum/#!topic/webm-discuss/LaCKpqiDTXM
 *       libYUV: https://bugs.chromium.org/p/libyuv/issues/detail?id=198&can=1&q=odd%20width
 *       YUVPlayer: https://sourceforge.net/p/raw-yuvplayer/bugs/1/
 *
 * @param[in] shape  Tensor shape of 2D size
 * @param[in] format Format of the tensor
 *
 * @return The adjusted tensor shape.
 */
inline TensorShape adjust_odd_shape(const TensorShape &shape, Format format)
{
    TensorShape output{ shape };

    // Force width to be even for formats which require subsampling of the U and V channels
    if(has_format_horizontal_subsampling(format))
    {
        output.set(0, output.x() & ~1U);
    }

    // Force height to be even for formats which require subsampling of the U and V channels
    if(has_format_vertical_subsampling(format))
    {
        output.set(1, output.y() & ~1U);
    }

    return output;
}

/** Calculate subsampled shape for a given format and channel
 *
 * @param[in] shape   Shape of the tensor to calculate the extracted channel.
 * @param[in] format  Format of the tensor.
 * @param[in] channel Channel to create tensor shape to be extracted.
 *
 * @return The subsampled tensor shape.
 */
inline TensorShape calculate_subsampled_shape(const TensorShape &shape, Format format, Channel channel = Channel::UNKNOWN)
{
    TensorShape output{ shape };

    // Subsample shape only for U or V channel
    if(Channel::U == channel || Channel::V == channel || Channel::UNKNOWN == channel)
    {
        // Subsample width for the tensor shape when channel is U or V
        if(has_format_horizontal_subsampling(format))
        {
            output.set(0, output.x() / 2U);
        }

        // Subsample height for the tensor shape when channel is U or V
        if(has_format_vertical_subsampling(format))
        {
            output.set(1, output.y() / 2U);
        }
    }

    return output;
}
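
// Illustrative usage, assuming a 639x481 NV12 image:
//   adjust_odd_shape(TensorShape(639U, 481U), Format::NV12)                       -> (638, 480)
//   calculate_subsampled_shape(TensorShape(638U, 480U), Format::NV12, Channel::U) -> (319, 240)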

/** Calculate accuracy required by the horizontal and vertical convolution computations
 *
 * @param[in] conv_col Pointer to the vertical vector of the separated convolution filter
 * @param[in] conv_row Pointer to the horizontal vector of the convolution filter
 * @param[in] size     Number of elements per vector of the separated matrix
 *
 * @return The return type is a pair. The first element of the pair is the biggest data type needed for the first stage. The second
 *         element of the pair is the biggest data type needed for the second stage.
 */
inline std::pair<DataType, DataType> data_type_for_convolution(const int16_t *conv_col, const int16_t *conv_row, size_t size)
{
    DataType first_stage  = DataType::UNKNOWN;
    DataType second_stage = DataType::UNKNOWN;

    auto gez = [](const int16_t &v)
    {
        return v >= 0;
    };

    auto accu_neg = [](const int &first, const int &second)
    {
        return first + (second < 0 ? second : 0);
    };

    auto accu_pos = [](const int &first, const int &second)
    {
        return first + (second > 0 ? second : 0);
    };

    const bool only_positive_coefficients = std::all_of(conv_row, conv_row + size, gez) && std::all_of(conv_col, conv_col + size, gez);

    if(only_positive_coefficients)
    {
        const int max_row_value = std::accumulate(conv_row, conv_row + size, 0) * UINT8_MAX;
        const int max_value     = std::accumulate(conv_col, conv_col + size, 0) * max_row_value;

        first_stage = (max_row_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;

        second_stage = (max_value <= UINT16_MAX) ? DataType::U16 : DataType::S32;
    }
    else
    {
        const int min_row_value  = std::accumulate(conv_row, conv_row + size, 0, accu_neg) * UINT8_MAX;
        const int max_row_value  = std::accumulate(conv_row, conv_row + size, 0, accu_pos) * UINT8_MAX;
        const int neg_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_neg);
        const int pos_coeffs_sum = std::accumulate(conv_col, conv_col + size, 0, accu_pos);
        const int min_value      = neg_coeffs_sum * max_row_value + pos_coeffs_sum * min_row_value;
        const int max_value      = neg_coeffs_sum * min_row_value + pos_coeffs_sum * max_row_value;

        first_stage = ((INT16_MIN <= min_row_value) && (max_row_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;

        second_stage = ((INT16_MIN <= min_value) && (max_value <= INT16_MAX)) ? DataType::S16 : DataType::S32;
    }

    return std::make_pair(first_stage, second_stage);
}

/** Calculate the accuracy required by the square convolution calculation.
 *
 * @param[in] conv Pointer to the square convolution matrix
 * @param[in] size The total size of the convolution matrix
 *
 * @return The biggest data type needed to do the convolution
 */
inline DataType data_type_for_convolution_matrix(const int16_t *conv, size_t size)
{
    auto gez = [](const int16_t v)
    {
        return v >= 0;
    };

    const bool only_positive_coefficients = std::all_of(conv, conv + size, gez);

    if(only_positive_coefficients)
    {
        const int max_conv_value = std::accumulate(conv, conv + size, 0) * UINT8_MAX;
        if(max_conv_value <= UINT16_MAX)
        {
            return DataType::U16;
        }
        else
        {
            return DataType::S32;
        }
    }
    else
    {
        const int min_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
        {
            return b < 0 ? a + b : a;
        })
        * UINT8_MAX;

        const int max_value = std::accumulate(conv, conv + size, 0, [](int a, int b)
        {
            return b > 0 ? a + b : a;
        })
        * UINT8_MAX;

        if((INT16_MIN <= min_value) && (INT16_MAX >= max_value))
        {
            return DataType::S16;
        }
        else
        {
            return DataType::S32;
        }
    }
}
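
// Illustrative sketch: for the all-positive kernel { 1, 2, 1, 2, 4, 2, 1, 2, 1 } the worst-case
// accumulator is 16 * UINT8_MAX = 4080, which fits in 16 bits, so
// data_type_for_convolution_matrix(conv, 9) returns DataType::U16.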

/** Calculate padding requirements in case of SAME padding
 *
 * @param[in] input_shape   Input shape
 * @param[in] weights_shape Weights shape
 * @param[in] conv_info     Convolution information (containing strides)
 *
 * @return PadStrideInfo for SAME padding
 */
PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info);

/** Returns expected shape for the deconvolution output tensor.
 *
 * @param[in] out_dims Width and height of the output tensor; these values can be obtained with the function deconvolution_output_dimensions.
 * @param[in] input    Shape of the input tensor.
 * @param[in] weights  Shape of the weights tensor.
 *
 * @return Deconvolution output tensor shape.
 */
TensorShape deconvolution_output_shape(const std::pair<unsigned int, unsigned int> &out_dims, TensorShape input, TensorShape weights);

/** Returns expected width and height of the deconvolution's output tensor.
 *
 * @param[in] in_width           Width of input tensor (Number of columns)
 * @param[in] in_height          Height of input tensor (Number of rows)
 * @param[in] kernel_width       Kernel width.
 * @param[in] kernel_height      Kernel height.
 * @param[in] padx               X axis padding.
 * @param[in] pady               Y axis padding.
 * @param[in] inner_border_right The number of zeros added to right edge of the input.
 * @param[in] inner_border_top   The number of zeros added to top edge of the input.
 * @param[in] stride_x           X axis input stride.
 * @param[in] stride_y           Y axis input stride.
 *
 * @return A pair with the new width in the first position and the new height in the second.
 */
const std::pair<unsigned int, unsigned int> deconvolution_output_dimensions(unsigned int in_width, unsigned int in_height,
                                                                            unsigned int kernel_width, unsigned int kernel_height,
                                                                            unsigned int padx, unsigned int pady, unsigned int inner_border_right, unsigned int inner_border_top,
                                                                            unsigned int stride_x, unsigned int stride_y);

/** Returns expected width and height of output scaled tensor depending on dimensions rounding mode.
 *
 * @param[in] width           Width of input tensor (Number of columns)
 * @param[in] height          Height of input tensor (Number of rows)
 * @param[in] kernel_width    Kernel width.
 * @param[in] kernel_height   Kernel height.
 * @param[in] pad_stride_info Pad and stride information.
 *
 * @return A pair with the new width in the first position and the new height in the second.
 */
const std::pair<unsigned int, unsigned int> scaled_dimensions(unsigned int width, unsigned int height,
                                                              unsigned int kernel_width, unsigned int kernel_height,
                                                              const PadStrideInfo &pad_stride_info);
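
// Illustrative sketch, assuming a 224x224 input, a 3x3 kernel and a PadStrideInfo(2, 2, 1, 1)
// describing stride 2 and padding 1 with the default FLOOR rounding: scaled_dimensions(224, 224, 3, 3, info)
// yields { 112, 112 }.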

/** Convert a tensor format into a string.
 *
 * @param[in] format @ref Format to be translated to string.
 *
 * @return The string describing the format.
 */
const std::string &string_from_format(Format format);

/** Convert a channel identity into a string.
 *
 * @param[in] channel @ref Channel to be translated to string.
 *
 * @return The string describing the channel.
 */
const std::string &string_from_channel(Channel channel);

/** Convert a data type identity into a string.
 *
 * @param[in] dt @ref DataType to be translated to string.
 *
 * @return The string describing the data type.
 */
const std::string &string_from_data_type(DataType dt);
/** Convert a matrix pattern into a string.
 *
 * @param[in] pattern @ref MatrixPattern to be translated to string.
 *
 * @return The string describing the matrix pattern.
 */
const std::string &string_from_matrix_pattern(MatrixPattern pattern);
/** Translates a given activation function to a string.
 *
 * @param[in] act @ref ActivationLayerInfo::ActivationFunction to be translated to string.
 *
 * @return The string describing the activation function.
 */
const std::string &string_from_activation_func(ActivationLayerInfo::ActivationFunction act);
/** Translates a given non linear function to a string.
 *
 * @param[in] function @ref NonLinearFilterFunction to be translated to string.
 *
 * @return The string describing the non linear function.
 */
const std::string &string_from_non_linear_filter_function(NonLinearFilterFunction function);
/** Translates a given interpolation policy to a string.
 *
 * @param[in] policy @ref InterpolationPolicy to be translated to string.
 *
 * @return The string describing the interpolation policy.
 */
const std::string &string_from_interpolation_policy(InterpolationPolicy policy);
/** Translates a given border mode policy to a string.
 *
 * @param[in] border_mode @ref BorderMode to be translated to string.
 *
 * @return The string describing the border mode.
 */
const std::string &string_from_border_mode(BorderMode border_mode);
/** Translates a given normalization type to a string.
 *
 * @param[in] type @ref NormType to be translated to string.
 *
 * @return The string describing the normalization type.
 */
const std::string &string_from_norm_type(NormType type);
/** Translates a given pooling type to a string.
 *
 * @param[in] type @ref PoolingType to be translated to string.
 *
 * @return The string describing the pooling type.
 */
const std::string &string_from_pooling_type(PoolingType type);
/** Lower a given string.
 *
 * @param[in] val Given string to lower.
 *
 * @return The lowered string
 */
std::string lower_string(const std::string &val);

/** Check if a given data type is of floating point type
 *
 * @param[in] dt Input data type.
 *
 * @return True if data type is of floating point type, else false.
 */
inline bool is_data_type_float(DataType dt)
{
    switch(dt)
    {
        case DataType::F16:
        case DataType::F32:
            return true;
        default:
            return false;
    }
}

/** Check if a given data type is of quantized type
 *
 * @note Quantized is considered a super-set of fixed-point and asymmetric data types.
 *
 * @param[in] dt Input data type.
 *
 * @return True if data type is of quantized type, else false.
 */
inline bool is_data_type_quantized(DataType dt)
{
    switch(dt)
    {
        case DataType::QS8:
        case DataType::QASYMM8:
        case DataType::QS16:
        case DataType::QS32:
            return true;
        default:
            return false;
    }
}

/** Check if a given data type is of fixed point type
 *
 * @param[in] dt Input data type.
 *
 * @return True if data type is of fixed point type, else false.
 */
inline bool is_data_type_fixed_point(DataType dt)
{
    switch(dt)
    {
        case DataType::QS8:
        case DataType::QS16:
        case DataType::QS32:
            return true;
        default:
            return false;
    }
}

/** Check if a given data type is of asymmetric quantized type
 *
 * @param[in] dt Input data type.
 *
 * @return True if data type is of asymmetric quantized type, else false.
 */
inline bool is_data_type_quantized_asymmetric(DataType dt)
{
    switch(dt)
    {
        case DataType::QASYMM8:
            return true;
        default:
            return false;
    }
}
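
// Illustrative usage of the type predicates above:
//   is_data_type_float(DataType::F16)                == true
//   is_data_type_quantized(DataType::QASYMM8)        == true
//   is_data_type_fixed_point(DataType::QASYMM8)      == false
//   is_data_type_quantized_asymmetric(DataType::QS8) == false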

/** Create a string with the float in full precision.
 *
 * @param val Floating point value
 *
 * @return String with the floating point value.
 */
inline std::string float_to_string_with_full_precision(float val)
{
    std::stringstream ss;
    ss.precision(std::numeric_limits<float>::digits10 + 1);
    ss << val;
    return ss.str();
}

/** Print consecutive elements to an output stream.
 *
 * @param[out] s             Output stream to print the elements to.
 * @param[in]  ptr           Pointer to print the elements from.
 * @param[in]  n             Number of elements to print.
 * @param[in]  stream_width  (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
 * @param[in]  element_delim (Optional) Delimiter among the consecutive elements. Defaults to space delimiter
 */
template <typename T>
void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int n, int stream_width = 0, const std::string &element_delim = " ")
{
    using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;

    for(unsigned int i = 0; i < n; ++i)
    {
        // Set stream width as it is not a "sticky" stream manipulator
        if(stream_width != 0)
        {
            s.width(stream_width);
        }

        if(std::is_same<typename std::decay<T>::type, half>::value)
        {
            // We use T instead of print_type here because std::is_floating_point<half> returns false and print_type would then become int.
            s << std::right << static_cast<T>(ptr[i]) << element_delim;
        }
        else
        {
            s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
        }
    }
}
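
// Illustrative sketch, assuming a small host-side buffer and that <iostream> is available:
//   const float data[3] = { 1.f, 2.5f, 3.f };
//   print_consecutive_elements_impl(std::cout, data, 3, 6, ", ");
// prints the three values right-aligned in a field of width 6, separated by ", ".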

/** Identify the maximum width of n consecutive elements.
 *
 * @param[in] s   The output stream which will be used to print the elements. Used to extract the stream format.
 * @param[in] ptr Pointer to the elements.
 * @param[in] n   Number of elements.
 *
 * @return The maximum width of the elements.
 */
template <typename T>
int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, unsigned int n)
{
    using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;

    int max_width = -1;
    for(unsigned int i = 0; i < n; ++i)
    {
        std::stringstream ss;
        ss.copyfmt(s);

        if(std::is_same<typename std::decay<T>::type, half>::value)
        {
            // We use T instead of print_type here because std::is_floating_point<half> returns false and print_type would then become int.
            ss << static_cast<T>(ptr[i]);
        }
        else
        {
            ss << static_cast<print_type>(ptr[i]);
        }

        max_width = std::max<int>(max_width, ss.str().size());
    }
    return max_width;
}

/** Print consecutive elements to an output stream.
 *
 * @param[out] s             Output stream to print the elements to.
 * @param[in]  dt            Data type of the elements
 * @param[in]  ptr           Pointer to print the elements from.
 * @param[in]  n             Number of elements to print.
 * @param[in]  stream_width  (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
 * @param[in]  element_delim (Optional) Delimiter among the consecutive elements. Defaults to space delimiter
 */
void print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim = " ");

/** Identify the maximum width of n consecutive elements.
 *
 * @param[in] s   The output stream which will be used to print the elements. Used to extract the stream format.
 * @param[in] dt  Data type of the elements
 * @param[in] ptr Pointer to the elements.
 * @param[in] n   Number of elements.
 *
 * @return The maximum width of the elements.
 */
int max_consecutive_elements_display_width(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n);
}
#endif /*__ARM_COMPUTE_UTILS_H__ */