blob: d9109e45657c7162ca13db5cfdb3a4d1dab8146b [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2016-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_TYPES_H__
25#define __ARM_COMPUTE_TYPES_H__
26
27#include "arm_compute/core/Coordinates.h"
Michel Iwaniec5dfeae62017-11-29 10:48:23 +000028#include "arm_compute/core/QAsymm8.h"
29#include "arm_compute/core/Rounding.h"
Isabella Gottardi6e464c32018-01-26 12:32:45 +000030#include "arm_compute/core/Size2D.h"
Georgios Pinitas8795ffb2017-12-01 16:13:40 +000031#include "arm_compute/core/Strides.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010032#include "arm_compute/core/TensorShape.h"
Georgios Pinitas583137c2017-08-31 18:12:42 +010033#include "support/Half.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034
Michel Iwaniec5dfeae62017-11-29 10:48:23 +000035#include <cmath>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include <cstddef>
37#include <cstdint>
38#include <string>
39#include <utility>
40
41namespace arm_compute
42{
Georgios Pinitas583137c2017-08-31 18:12:42 +010043/** 16-bit floating point type */
44using half = half_float::half;
45
Georgios Pinitas8795ffb2017-12-01 16:13:40 +000046/** Permutation vector */
47using PermutationVector = Strides;
48
/** Image colour formats */
enum class Format
{
    UNKNOWN,  /**< Image format is not known */
    U8,       /**< One channel, one U8 per channel */
    S16,      /**< One channel, one S16 per channel */
    U16,      /**< One channel, one U16 per channel */
    S32,      /**< One channel, one S32 per channel */
    U32,      /**< One channel, one U32 per channel */
    F16,      /**< One channel, one F16 per channel */
    F32,      /**< One channel, one F32 per channel */
    UV88,     /**< Two channels, one U8 per channel */
    RGB888,   /**< Three channels, one U8 per channel */
    RGBA8888, /**< Four channels, one U8 per channel */
    YUV444,   /**< Three planes of 8-bit 4:4:4 sampled Y, U, V */
    YUYV422,  /**< Single plane of 32-bit macro pixels: Y0, U0, Y1, V0 bytes */
    NV12,     /**< Two-plane YUV format: Luma (Y) plus interleaved UV data, 4:2:0 sampled */
    NV21,     /**< Two-plane YUV format: Luma (Y) plus interleaved VU data, 4:2:0 sampled */
    IYUV,     /**< Three planes of 8-bit 4:2:0 sampled Y, U, V */
    UYVY422   /**< Single plane of 32-bit macro pixels: U0, Y0, V0, Y1 bytes */
};
70
/** Available data types */
enum class DataType
{
    UNKNOWN, /**< Data type is not known */
    U8,      /**< unsigned 8-bit integer */
    S8,      /**< signed 8-bit integer */
    QASYMM8, /**< quantized, asymmetric fixed-point 8-bit number */
    U16,     /**< unsigned 16-bit integer */
    S16,     /**< signed 16-bit integer */
    U32,     /**< unsigned 32-bit integer */
    S32,     /**< signed 32-bit integer */
    U64,     /**< unsigned 64-bit integer */
    S64,     /**< signed 64-bit integer */
    F16,     /**< half-precision (16-bit) floating-point number */
    F32,     /**< single-precision (32-bit) floating-point number */
    F64,     /**< double-precision (64-bit) floating-point number */
    SIZET    /**< size_t */
};
89
/** Available Sampling Policies */
enum class SamplingPolicy
{
    CENTER,  /**< Sample at the pixel center */
    TOP_LEFT /**< Sample at the pixel top-left corner */
};

/** Value assumed for border pixels when using BorderMode::CONSTANT */
constexpr uint8_t CONSTANT_BORDER_VALUE = 199;

/** Scale factor identifying a half-scale pyramid */
constexpr float SCALE_PYRAMID_HALF = 0.5f;

/** Scale factor identifying an ORB scaled pyramid */
constexpr float SCALE_PYRAMID_ORB = 8.408964152537146130583778358414e-01;

/** Supported tensor data layouts */
enum class DataLayout
{
    UNKNOWN, /**< Data layout is not known */
    NCHW,    /**< Num samples, channels, height, width */
    NHWC     /**< Num samples, height, width, channels */
};

/** Supported tensor data layout dimensions */
enum class DataLayoutDimension
{
    CHANNEL, /**< channel */
    HEIGHT,  /**< height */
    WIDTH,   /**< width */
    BATCHES  /**< batches */
};
122
Michel Iwaniec00633802017-10-12 14:14:15 +0100123/** Quantization settings (used for QASYMM8 data type) */
124struct QuantizationInfo
125{
Alex Gildayc357c472018-03-21 13:54:09 +0000126 /** Default constructor */
Georgios Pinitasf8d8f3a2018-06-06 17:57:04 +0100127 QuantizationInfo() noexcept
128 : scale(0.0f),
129 offset(0)
Michel Iwaniec00633802017-10-12 14:14:15 +0100130 {
131 }
132
Alex Gildayc357c472018-03-21 13:54:09 +0000133 /** Construct quantization info.
134 *
135 * @param[in] scale Scale.
136 * @param[in] offset Offset.
137 */
Michel Iwaniec00633802017-10-12 14:14:15 +0100138 QuantizationInfo(float scale, int offset)
139 : scale(scale), offset(offset)
140 {
141 }
142
Alex Gildayc357c472018-03-21 13:54:09 +0000143 /** Check whether equal to a given quantization info.
144 *
145 * @param[in] other Other quantization info.
146 *
147 * @return True if the given quantization info is the same.
148 */
Daniil Efremoveed841c2017-11-09 19:05:25 +0700149 bool operator==(const QuantizationInfo &other)
150 {
151 return scale == other.scale && offset == other.offset;
152 }
153
Alex Gildayc357c472018-03-21 13:54:09 +0000154 /** Check whether not equal to a given quantization info.
155 *
156 * @param[in] other Other quantization info.
157 *
158 * @return True if the given quantization info is not the same.
159 */
Daniil Efremoveed841c2017-11-09 19:05:25 +0700160 bool operator!=(const QuantizationInfo &other)
161 {
162 return !(*this == other);
163 }
164
Michel Iwaniec00633802017-10-12 14:14:15 +0100165 float scale; /**< scale */
166 int offset; /**< offset */
167
Alex Gildayc357c472018-03-21 13:54:09 +0000168 /** Quantizes a value using the scale/offset in this QuantizationInfo
169 *
170 * @param[in] value Value to quantize.
171 * @param[in] rounding_policy Policy to use when rounding.
172 *
173 * @return the quantized value.
174 */
Michel Iwaniec5dfeae62017-11-29 10:48:23 +0000175 qasymm8_t quantize(float value, RoundingPolicy rounding_policy) const
Michel Iwaniec00633802017-10-12 14:14:15 +0100176 {
177 ARM_COMPUTE_ERROR_ON_MSG(scale == 0, "QuantizationInfo::quantize: scale == 0");
Michel Iwaniec5dfeae62017-11-29 10:48:23 +0000178 return sqcvt_qasymm8_f32(value, scale, offset, rounding_policy);
Michel Iwaniec00633802017-10-12 14:14:15 +0100179 }
180
Alex Gildayc357c472018-03-21 13:54:09 +0000181 /** Dequantizes a value using the scale/offset in this QuantizationInfo
182 *
183 * @param[in] value Value to dequantize.
184 *
185 * @return the original value before quantization.
186 */
Michel Iwaniec5dfeae62017-11-29 10:48:23 +0000187 float dequantize(qasymm8_t value) const
Michel Iwaniec00633802017-10-12 14:14:15 +0100188 {
189 ARM_COMPUTE_ERROR_ON_MSG(scale == 0, "QuantizationInfo::dequantize: scale == 0");
Michel Iwaniec5dfeae62017-11-29 10:48:23 +0000190 return scvt_f32_qasymm8(value, scale, offset);
Michel Iwaniec00633802017-10-12 14:14:15 +0100191 }
192
Alex Gildayc357c472018-03-21 13:54:09 +0000193 /** Indicates whether this QuantizationInfo has valid settings or not
194 *
195 * @return True if the this has invalid settings.
196 */
Michel Iwaniec00633802017-10-12 14:14:15 +0100197 bool empty() const
198 {
199 return scale == 0;
200 }
201};
202
Alex Gildayc357c472018-03-21 13:54:09 +0000203/** Container for valid region of a window */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100204struct ValidRegion
205{
Alex Gildayc357c472018-03-21 13:54:09 +0000206 /** Default constructor */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100207 ValidRegion()
208 : anchor{}, shape{}
209 {
210 }
211
Alex Gildayc357c472018-03-21 13:54:09 +0000212 /** Allow instances of this class to be copy constructed */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100213 ValidRegion(const ValidRegion &) = default;
Alex Gildayc357c472018-03-21 13:54:09 +0000214 /** Allow instances of this class to be move constructed */
215 ValidRegion(ValidRegion &&) = default;
216 /** Allow instances of this class to be copied */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100217 ValidRegion &operator=(const ValidRegion &) = default;
Alex Gildayc357c472018-03-21 13:54:09 +0000218 /** Allow instances of this class to be moved */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100219 ValidRegion &operator=(ValidRegion &&) = default;
Alex Gildayc357c472018-03-21 13:54:09 +0000220 /** Default destructor */
221 ~ValidRegion() = default;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100222
Alex Gildayc357c472018-03-21 13:54:09 +0000223 /** Constructor for a valid region with default number of dimensions
224 *
225 * @param[in] an_anchor Anchor for the start of the valid region.
226 * @param[in] a_shape Shape of the valid region.
227 *
228 */
Diego Lopez Recasbcbc9702017-12-18 11:28:27 +0000229 ValidRegion(const Coordinates &an_anchor, const TensorShape &a_shape)
230 : anchor{ an_anchor }, shape{ a_shape }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100231 {
Diego Lopez Recasbcbc9702017-12-18 11:28:27 +0000232 anchor.set_num_dimensions(std::max(anchor.num_dimensions(), shape.num_dimensions()));
233 }
234
Alex Gildayc357c472018-03-21 13:54:09 +0000235 /** Constructor for a valid region with specified number of dimensions
236 *
237 * @param[in] an_anchor Anchor for the start of the valid region.
238 * @param[in] a_shape Shape of the valid region.
239 * @param[in] num_dimensions Number of dimensions (must be >= number of dimensions of anchor and shape).
240 *
241 */
Diego Lopez Recasbcbc9702017-12-18 11:28:27 +0000242 ValidRegion(const Coordinates &an_anchor, const TensorShape &a_shape, size_t num_dimensions)
243 : anchor{ an_anchor }, shape{ a_shape }
244 {
245 ARM_COMPUTE_ERROR_ON(num_dimensions < std::max(anchor.num_dimensions(), shape.num_dimensions()));
246 anchor.set_num_dimensions(num_dimensions);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100247 }
248
249 /** Return the start of the valid region for the given dimension @p d */
250 int start(unsigned int d) const
251 {
252 return anchor[d];
253 }
254
255 /** Return the end of the valid region for the given dimension @p d */
256 int end(unsigned int d) const
257 {
258 return anchor[d] + shape[d];
259 }
260
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000261 /** Accessor to set the value of anchor and shape for one of the dimensions.
262 *
263 * @param[in] dimension Dimension for which the value is set.
264 * @param[in] start Value to be set in anchor for the dimension.
265 * @param[in] size Value to be set in shape for the dimension.
266 *
267 * @return *this.
268 */
269 ValidRegion &set(size_t dimension, int start, size_t size)
270 {
271 anchor.set(dimension, start);
272 shape.set(dimension, size);
273 return *this;
274 }
275
Alex Gildayc357c472018-03-21 13:54:09 +0000276 Coordinates anchor; /**< Anchor for the start of the valid region. */
277 TensorShape shape; /**< Shape of the valid region. */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100278};
279
/** Methods available to handle borders */
enum class BorderMode
{
    UNDEFINED, /**< Border pixels are left undefined */
    CONSTANT,  /**< Pixels outside the image are treated as having a constant value */
    REPLICATE  /**< Pixels outside the image take the value of the closest image pixel */
};
287
/** Container for 2D border size */
struct BorderSize
{
    /** Empty border, i.e. no border on any side */
    constexpr BorderSize()
        : top{ 0 }, right{ 0 }, bottom{ 0 }, left{ 0 }
    {
    }

    /** Border with equal size around the 2D plane */
    explicit constexpr BorderSize(unsigned int size)
        : top{ size }, right{ size }, bottom{ size }, left{ size }
    {
    }

    /** Border with same size for top/bottom and left/right */
    constexpr BorderSize(unsigned int top_bottom, unsigned int left_right)
        : top{ top_bottom }, right{ left_right }, bottom{ top_bottom }, left{ left_right }
    {
    }

    /** Border with different sizes on each side */
    constexpr BorderSize(unsigned int top, unsigned int right, unsigned int bottom, unsigned int left)
        : top{ top }, right{ right }, bottom{ bottom }, left{ left }
    {
    }

    /** Check if the entire border is zero */
    constexpr bool empty() const
    {
        return top == 0 && right == 0 && bottom == 0 && left == 0;
    }

    /** Check if the border is the same size on all sides */
    constexpr bool uniform() const
    {
        return top == right && top == bottom && top == left;
    }

    /** Scale this border size.
     *
     * @note Each side is multiplied by @p scale and truncated back to unsigned int.
     *
     * @param[in] scale Scale to multiply border size by.
     *
     * @return *this.
     */
    BorderSize &operator*=(float scale)
    {
        top *= scale;
        right *= scale;
        bottom *= scale;
        left *= scale;

        return *this;
    }

    /** Scale a copy of this border size.
     *
     * @note Fixed: const-qualified so a const BorderSize can be scaled.
     *
     * @param[in] scale Scale to multiply border size by.
     *
     * @return a scaled copy of this.
     */
    BorderSize operator*(float scale) const
    {
        BorderSize size = *this;
        size *= scale;

        return size;
    }

    /** Limit this border size so no side exceeds the corresponding side of @p limit.
     *
     * @param[in] limit Border size to limit this border size to.
     */
    void limit(const BorderSize &limit)
    {
        top    = std::min(top, limit.top);
        right  = std::min(right, limit.right);
        bottom = std::min(bottom, limit.bottom);
        left   = std::min(left, limit.left);
    }

    unsigned int top;    /**< top of the border */
    unsigned int right;  /**< right of the border */
    unsigned int bottom; /**< bottom of the border */
    unsigned int left;   /**< left of the border */
};
374
Alex Gildayc357c472018-03-21 13:54:09 +0000375/** Container for 2D padding size */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100376using PaddingSize = BorderSize;
377
/** Policy to handle integer overflow */
enum class ConvertPolicy
{
    WRAP,    /**< Results wrap around */
    SATURATE /**< Results saturate at the type's limits */
};

/** Interpolation method */
enum class InterpolationPolicy
{
    NEAREST_NEIGHBOR, /**< Output matches the source pixel whose center is nearest to the sample position */
    BILINEAR,         /**< Output is a bilinear interpolation between the surrounding pixels */
    AREA,             /**< Output averages the source pixels whose areas fall under the destination pixel's area projected onto the source image */
};

/** Bilinear Interpolation method used by LKTracker */
enum class BilinearInterpolation
{
    BILINEAR_OLD_NEW, /**< Old-new method */
    BILINEAR_SCHARR   /**< Scharr method */
};

/** Threshold mode */
enum class ThresholdType
{
    BINARY, /**< Threshold against a single value */
    RANGE   /**< Threshold against two values */
};

/** Termination criteria */
enum class Termination
{
    TERM_CRITERIA_EPSILON,    /**< Terminate when within epsilon of a threshold */
    TERM_CRITERIA_ITERATIONS, /**< Terminate after a maximum number of iterations */
    TERM_CRITERIA_BOTH        /**< Terminate on whichever of the other conditions occurs first */
};

/** Magnitude calculation type. */
enum class MagnitudeType
{
    L1NORM, /**< L1 normalization type */
    L2NORM  /**< L2 normalization type */
};

/** Phase calculation type.
 *
 * @note When PhaseType == SIGNED, each angle is mapped to the range 0 to 255 inclusive, otherwise angles lie between 0 and 180
 */
enum class PhaseType
{
    SIGNED,  /**< Angle range: [0, 360] */
    UNSIGNED /**< Angle range: [0, 180] */
};
431
/** Keypoint type */
struct KeyPoint
{
    int32_t x{ 0 };               /**< X coordinate */
    int32_t y{ 0 };               /**< Y coordinate */
    float   strength{ 0.f };      /**< Strength of the point */
    float   scale{ 0.f };         /**< Scale, initialized to 0 by the corner detector */
    float   orientation{ 0.f };   /**< Orientation, initialized to 0 by the corner detector */
    int32_t tracking_status{ 0 }; /**< Status, initialized to 1 by the corner detector, reset to 0 once the point is lost */
    float   error{ 0.f };         /**< Tracking error, initialized to 0 by the corner detector */
};

/** Internal key point: x, y, strength */
using InternalKeypoint = std::tuple<float, float, float>;
446
/** Rectangle type */
struct Rectangle
{
    uint16_t x;      /**< Top-left x coordinate */
    uint16_t y;      /**< Top-left y coordinate */
    uint16_t width;  /**< Width of the rectangle */
    uint16_t height; /**< Height of the rectangle */
};

/** 2D coordinate type */
struct Coordinates2D
{
    int32_t x; /**< X coordinate */
    int32_t y; /**< Y coordinate */
};

/** 3D coordinate type */
struct Coordinates3D
{
    uint32_t x; /**< X coordinate */
    uint32_t y; /**< Y coordinate */
    uint32_t z; /**< Z coordinate */
};

/** Region of interest */
struct ROI
{
    Rectangle rect;      /**< Rectangle specifying the region of interest */
    uint16_t  batch_idx; /**< The batch index of the region of interest */
};
477
/** Available channels */
enum class Channel
{
    UNKNOWN, /**< Unknown channel format */
    C0,      /**< First channel (used by formats with unknown channel types). */
    C1,      /**< Second channel (used by formats with unknown channel types). */
    C2,      /**< Third channel (used by formats with unknown channel types). */
    C3,      /**< Fourth channel (used by formats with unknown channel types). */
    R,       /**< Red channel. */
    G,       /**< Green channel. */
    B,       /**< Blue channel. */
    A,       /**< Alpha channel. */
    Y,       /**< Luma channel. */
    U,       /**< Cb/U channel. */
    V        /**< Cr/V/Value channel. */
};
494
/** Available matrix patterns */
enum class MatrixPattern
{
    BOX,   /**< Box pattern matrix. */
    CROSS, /**< Cross pattern matrix. */
    DISK,  /**< Disk pattern matrix. */
    OTHER  /**< Any other matrix pattern. */
};

/** Available non linear functions. */
enum class NonLinearFilterFunction : unsigned
{
    MEDIAN = 0, /**< Non linear median filter. */
    MIN    = 1, /**< Non linear erode. */
    MAX    = 2, /**< Non linear dilate. */
};

/** Available reduction operations */
enum class ReductionOperation
{
    SUM_SQUARE, /**< Sum of squares */
    SUM,        /**< Sum */
};
518
/** The normalization type used for the normalization layer */
enum class NormType
{
    IN_MAP_1D, /**< Normalization applied within the same map in 1D region */
    IN_MAP_2D, /**< Normalization applied within the same map in 2D region */
    CROSS_MAP  /**< Normalization applied across maps */
};

/** Normalization type for Histogram of Oriented Gradients (HOG) */
enum class HOGNormType
{
    L2_NORM    = 1, /**< L2-norm */
    L2HYS_NORM = 2, /**< L2-norm followed by clipping */
    L1_NORM    = 3  /**< L1 norm */
};
534
/** Detection window used for the object detection. The detection window keeps the following information:
 *
 * -# Geometry of the rectangular window (x/y of top-left corner and width/height)
 * -# Index of the class used for evaluating which class the detection window belongs to
 * -# Confidence value (score) obtained with the classifier
 */
struct DetectionWindow
{
    uint16_t x{ 0 };         /**< Top-left x coordinate */
    uint16_t y{ 0 };         /**< Top-left y coordinate */
    uint16_t width{ 0 };     /**< Width of the detection window */
    uint16_t height{ 0 };    /**< Height of the detection window */
    uint16_t idx_class{ 0 }; /**< Index of the class */
    float    score{ 0.f };   /**< Confidence value for the detection window */
};
550
/** Dimension rounding type when down-scaling on CNNs
 * @note Used in pooling and convolution layer
 */
enum class DimensionRoundingType
{
    FLOOR, /**< Round down */
    CEIL   /**< Round up */
};

/** Available pooling types */
enum class PoolingType
{
    MAX, /**< Max Pooling */
    AVG, /**< Average Pooling */
    L2   /**< L2 Pooling */
};
567
568/** Padding and stride information class */
569class PadStrideInfo
570{
571public:
572 /** Constructor
573 *
574 * @param[in] stride_x (Optional) Stride, in elements, across x. Defaults to 1.
575 * @param[in] stride_y (Optional) Stride, in elements, across y. Defaults to 1.
576 * @param[in] pad_x (Optional) Padding, in elements, across x. Defaults to 0.
577 * @param[in] pad_y (Optional) Padding, in elements, across y. Defaults to 0.
578 * @param[in] round (Optional) Dimensions rounding. Defaults to @ref FLOOR.
579 */
580 PadStrideInfo(unsigned int stride_x = 1, unsigned int stride_y = 1,
581 unsigned int pad_x = 0, unsigned int pad_y = 0,
582 DimensionRoundingType round = DimensionRoundingType::FLOOR)
583 : _stride(std::make_pair(stride_x, stride_y)),
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100584 _pad_left(pad_x),
585 _pad_top(pad_y),
586 _pad_right(pad_x),
587 _pad_bottom(pad_y),
588 _round_type(round)
589 {
590 }
591 /** Constructor
592 *
593 * @param[in] stride_x Stride, in elements, across x.
594 * @param[in] stride_y Stride, in elements, across y.
595 * @param[in] pad_left Padding across x on the left, in elements.
596 * @param[in] pad_top Padding across y on the top, in elements.
597 * @param[in] pad_right Padding across x on the right, in elements.
598 * @param[in] pad_bottom Padding across y on the bottom, in elements.
599 * @param[in] round Dimensions rounding.
600 */
601 PadStrideInfo(unsigned int stride_x, unsigned int stride_y,
602 unsigned int pad_left, unsigned int pad_right,
603 unsigned int pad_top, unsigned int pad_bottom,
604 DimensionRoundingType round)
605 : _stride(std::make_pair(stride_x, stride_y)),
606 _pad_left(pad_left),
607 _pad_top(pad_top),
608 _pad_right(pad_right),
609 _pad_bottom(pad_bottom),
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100610 _round_type(round)
611 {
612 }
Alex Gildayc357c472018-03-21 13:54:09 +0000613 /** Get the stride.
614 *
615 * @return a pair: stride x, stride y.
616 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100617 std::pair<unsigned int, unsigned int> stride() const
618 {
619 return _stride;
620 }
Alex Gildayc357c472018-03-21 13:54:09 +0000621 /** Check whether the padding is symmetric.
622 *
623 * @return True if the padding is symmetric.
624 */
Anthony Barbier21f67d62018-02-16 15:17:48 +0000625 bool padding_is_symmetric() const
626 {
627 return (_pad_left == _pad_right) && (_pad_top == _pad_bottom);
628 }
Alex Gildayc357c472018-03-21 13:54:09 +0000629 /** Get the padding.
630 *
631 * @note This should only be used when the padding is symmetric.
632 *
633 * @return a pair: padding left/right, padding top/bottom
634 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100635 std::pair<unsigned int, unsigned int> pad() const
636 {
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100637 //this accessor should be used only when padding is symmetric
Anthony Barbier21f67d62018-02-16 15:17:48 +0000638 ARM_COMPUTE_ERROR_ON(!padding_is_symmetric());
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100639 return std::make_pair(_pad_left, _pad_top);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100640 }
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100641
Alex Gildayc357c472018-03-21 13:54:09 +0000642 /** Get the left padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100643 unsigned int pad_left() const
644 {
645 return _pad_left;
646 }
Alex Gildayc357c472018-03-21 13:54:09 +0000647 /** Get the right padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100648 unsigned int pad_right() const
649 {
650 return _pad_right;
651 }
Alex Gildayc357c472018-03-21 13:54:09 +0000652 /** Get the top padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100653 unsigned int pad_top() const
654 {
655 return _pad_top;
656 }
Alex Gildayc357c472018-03-21 13:54:09 +0000657 /** Get the bottom padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100658 unsigned int pad_bottom() const
659 {
660 return _pad_bottom;
661 }
662
Alex Gildayc357c472018-03-21 13:54:09 +0000663 /** Get the rounding type */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100664 DimensionRoundingType round() const
665 {
666 return _round_type;
667 }
668
Alex Gildayc357c472018-03-21 13:54:09 +0000669 /** Check whether this has any padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100670 bool has_padding() const
671 {
672 return (_pad_left != 0 || _pad_top != 0 || _pad_right != 0 || _pad_bottom != 0);
673 }
674
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100675private:
676 std::pair<unsigned int, unsigned int> _stride;
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100677 unsigned int _pad_left;
678 unsigned int _pad_top;
679 unsigned int _pad_right;
680 unsigned int _pad_bottom;
681
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100682 DimensionRoundingType _round_type;
683};
684
Georgios Pinitas7d66a8e2018-07-17 12:28:42 +0100685/** Fully connected layer info */
686struct FullyConnectedLayerInfo
687{
688 DataLayout weights_trained_layout{ DataLayout::NCHW }; /**< Layout that the weights have been trained with. */
689 bool transpose_weights{ true }; /**< Transpose weights if true. */
690 bool are_weights_reshaped{ false }; /**< Reshape the weights tensor if false. */
691 bool retain_internal_weights{ false }; /**< Retain internal reshaped weights. */
Georgios Pinitasc55cef12018-08-01 15:24:18 +0100692
693 /** Sets the weights trained data layout
694 *
695 * @param[in] layout Data layout that the weights were trained with
696 *
697 * @return Updated object
698 */
699 FullyConnectedLayerInfo &set_weights_trained_layout(DataLayout layout)
700 {
701 weights_trained_layout = layout;
702 return *this;
703 }
Georgios Pinitas195b0ba2018-08-02 17:18:51 +0100704 /** Sets the transpose weights flag
705 *
706 * @param[in] should_transpose_weights Boolean flag indicating if weights should be transposed
707 *
708 * @return Updated object
709 */
710 FullyConnectedLayerInfo &set_transpose_weights(bool should_transpose_weights)
711 {
712 transpose_weights = should_transpose_weights;
713 return *this;
714 }
Georgios Pinitas7d66a8e2018-07-17 12:28:42 +0100715};
716
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100717/** Pooling Layer Information class */
718class PoolingLayerInfo
719{
720public:
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000721 /** Default Constructor */
722 PoolingLayerInfo()
Isabella Gottardi6e464c32018-01-26 12:32:45 +0000723 : _pool_type(PoolingType::MAX), _pool_size(Size2D()), _pad_stride_info(PadStrideInfo()), _exclude_padding(false), _is_global_pooling(false)
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000724 {
725 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100726 /** Default Constructor
727 *
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000728 * @param[in] pool_type Pooling type @ref PoolingType.
729 * @param[in] pool_size Pooling size, in elements, across x and y.
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100730 * @param[in] pad_stride_info (Optional) Padding and stride information @ref PadStrideInfo
Georgios Pinitasadaae7e2017-10-30 15:56:32 +0000731 * @param[in] exclude_padding (Optional) Strategy when accounting padding in calculations.
732 * True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area).
733 * Defaults to false;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100734 */
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000735 explicit PoolingLayerInfo(PoolingType pool_type,
736 unsigned int pool_size,
737 PadStrideInfo pad_stride_info = PadStrideInfo(),
738 bool exclude_padding = false)
Isabella Gottardi6e464c32018-01-26 12:32:45 +0000739 : _pool_type(pool_type), _pool_size(Size2D(pool_size, pool_size)), _pad_stride_info(pad_stride_info), _exclude_padding(exclude_padding), _is_global_pooling(false)
740 {
741 }
742 /** Default Constructor
743 *
744 * @param[in] pool_type Pooling type @ref PoolingType.
745 * @param[in] pool_size Pooling size, in elements, across x and y.
746 * @param[in] pad_stride_info (Optional) Padding and stride information @ref PadStrideInfo
747 * @param[in] exclude_padding (Optional) Strategy when accounting padding in calculations.
748 * True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area).
749 * Defaults to false;
750 */
751 explicit PoolingLayerInfo(PoolingType pool_type,
752 Size2D pool_size,
753 PadStrideInfo pad_stride_info = PadStrideInfo(),
754 bool exclude_padding = false)
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000755 : _pool_type(pool_type), _pool_size(pool_size), _pad_stride_info(pad_stride_info), _exclude_padding(exclude_padding), _is_global_pooling(false)
756 {
757 }
    /** Default Constructor
     *
     * @note This constructor is used for global pooling
     *
     * @param[in] pool_type Pooling type @ref PoolingType.
     */
    explicit PoolingLayerInfo(PoolingType pool_type)
        // No meaningful window is stored here (default Size2D) and stride is fixed to 1 with no
        // padding; presumably the effective window is derived from the input's spatial dimensions
        // at configure time -- confirm against the pooling layer implementation.
        : _pool_type(pool_type), _pool_size(Size2D()), _pad_stride_info(PadStrideInfo(1, 1, 0, 0)), _exclude_padding(false), _is_global_pooling(true)
    {
    }
    /** Get the pooling type
     *
     * @return The pooling type to be used by the layer
     */
    PoolingType pool_type() const
    {
        return _pool_type;
    }
    /** Get the pooling window size
     *
     * @return Width and height of the pooling window, in elements
     */
    const Size2D &pool_size() const
    {
        return _pool_size;
    }
    /** Get the padding and stride information
     *
     * @return The @ref PadStrideInfo used by the layer
     */
    PadStrideInfo pad_stride_info() const
    {
        return _pad_stride_info;
    }
    /** Check if padding is excluded in calculations
     *
     * @return True if padding is excluded when computing the pooling area (used in AVG/L2 pooling), false otherwise
     */
    bool exclude_padding() const
    {
        return _exclude_padding;
    }
    /** Check if the layer performs global pooling
     *
     * @return True if this info was created through the global-pooling constructor, false otherwise
     */
    bool is_global_pooling() const
    {
        return _is_global_pooling;
    }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100793
private:
    PoolingType   _pool_type;         /**< Pooling type to perform */
    Size2D        _pool_size;         /**< Width and height of the pooling window */
    PadStrideInfo _pad_stride_info;   /**< Padding and stride information */
    bool          _exclude_padding;   /**< Exclude padding from the pooling area (AVG/L2 pooling) */
    bool          _is_global_pooling; /**< Set only by the global-pooling constructor */
};
801
/** ROI Pooling Layer Information class */
class ROIPoolingLayerInfo
{
public:
    /** Constructor
     *
     * @param[in] pooled_width  Pooled width of the layer.
     * @param[in] pooled_height Pooled height of the layer.
     * @param[in] spatial_scale Spatial scale to be applied to the ROI coordinates and dimensions.
     */
    ROIPoolingLayerInfo(unsigned int pooled_width, unsigned int pooled_height, float spatial_scale)
        : _pooled_width(pooled_width), _pooled_height(pooled_height), _spatial_scale(spatial_scale)
    {
    }
    /** @return The pooled width of the layer. */
    unsigned int pooled_width() const
    {
        return _pooled_width;
    }
    /** @return The pooled height of the layer. */
    unsigned int pooled_height() const
    {
        return _pooled_height;
    }
    /** @return The spatial scale applied to the ROI coordinates and dimensions. */
    float spatial_scale() const
    {
        return _spatial_scale;
    }

private:
    unsigned int _pooled_width;  /**< Pooled output width */
    unsigned int _pooled_height; /**< Pooled output height */
    float        _spatial_scale; /**< Scale applied to ROI coordinates/dimensions */
};
837
/** Activation Layer Information class */
class ActivationLayerInfo
{
public:
    /** Available activation functions */
    enum class ActivationFunction
    {
        LOGISTIC,        /**< Logistic ( \f$ f(x) = \frac{1}{1 + e^{-x}} \f$ ) */
        TANH,            /**< Hyperbolic tangent ( \f$ f(x) = a \cdot tanh(b \cdot x) \f$ ) */
        RELU,            /**< Rectifier ( \f$ f(x) = max(0,x) \f$ ) */
        BOUNDED_RELU,    /**< Upper Bounded Rectifier ( \f$ f(x) = min(a, max(0,x)) \f$ ) */
        LU_BOUNDED_RELU, /**< Lower and Upper Bounded Rectifier ( \f$ f(x) = min(a, max(b,x)) \f$ ) */
        LEAKY_RELU,      /**< Leaky Rectifier ( \f$ f(x) = \begin{cases} \alpha x & \text{if } x < 0 \\ x & \text{if } x \geq 0 \end{cases} \f$ ).
                              NOTE(review): the previous comment duplicated the SOFT_RELU formula; \f$ \alpha \f$ presumably maps to parameter a -- confirm against the activation kernels. */
        SOFT_RELU,       /**< Soft Rectifier ( \f$ f(x)= log(1+e^x) \f$ ) */
        ABS,             /**< Absolute ( \f$ f(x)= |x| \f$ ) */
        SQUARE,          /**< Square ( \f$ f(x)= x^2 \f$ ) */
        SQRT,            /**< Square root ( \f$ f(x) = \sqrt{x} \f$ ) */
        LINEAR           /**< Linear ( \f$ f(x)= ax + b \f$ ) */
    };

    /** Default constructor: the resulting object reports enabled() == false. */
    ActivationLayerInfo() = default;
    /** Default Constructor
     *
     * @param[in] f The activation function to use.
     * @param[in] a (Optional) The alpha parameter used by some activation functions
     *              (@ref ActivationFunction::BOUNDED_RELU, @ref ActivationFunction::LU_BOUNDED_RELU, @ref ActivationFunction::LINEAR, @ref ActivationFunction::TANH).
     * @param[in] b (Optional) The beta parameter used by some activation functions (@ref ActivationFunction::LINEAR, @ref ActivationFunction::LU_BOUNDED_RELU, @ref ActivationFunction::TANH).
     */
    ActivationLayerInfo(ActivationFunction f, float a = 0.0f, float b = 0.0f)
        : _act(f), _a(a), _b(b), _enabled(true)
    {
    }
    /** Get the type of activation function
     *
     * @return The activation function
     */
    ActivationFunction activation() const
    {
        return _act;
    }
    /** Get the alpha value
     *
     * @return The alpha parameter
     */
    float a() const
    {
        return _a;
    }
    /** Get the beta value
     *
     * @return The beta parameter
     */
    float b() const
    {
        return _b;
    }
    /** Check if initialised
     *
     * @return True only when built with the parameterised constructor; false for a default-constructed instance
     */
    bool enabled() const
    {
        return _enabled;
    }

private:
    ActivationFunction _act = { ActivationLayerInfo::ActivationFunction::LOGISTIC }; /**< Activation function; LOGISTIC is only a placeholder while _enabled is false */
    float              _a = {};                                                      /**< Alpha parameter */
    float              _b = {};                                                      /**< Beta parameter */
    bool               _enabled = { false };                                         /**< Set to true by the parameterised constructor */
};
897
/** Normalization Layer Information class */
class NormalizationLayerInfo
{
public:
    /** Default Constructor
     *
     * @param[in] type      The normalization type. Can be @ref NormType::IN_MAP_1D, @ref NormType::IN_MAP_2D or @ref NormType::CROSS_MAP
     * @param[in] norm_size The normalization size is the number of elements to normalize across. Defaults to 5.
     * @param[in] alpha     (Optional) Alpha parameter used by normalization equation. Defaults to 0.0001.
     * @param[in] beta      (Optional) Beta parameter used by normalization equation. Defaults to 0.5.
     * @param[in] kappa     (Optional) Kappa parameter used by [Krizhevsky 2012] Across Channel Local Brightness Normalization equation.
     * @param[in] is_scaled (Optional) Boolean that specifies if alpha will be scaled by the normalization size or not.
     *                      Should be false to follow [Krizhevsky 2012].
     */
    NormalizationLayerInfo(NormType type, uint32_t norm_size = 5, float alpha = 0.0001f, float beta = 0.5f, float kappa = 1.f, bool is_scaled = true)
        : _type(type), _norm_size(norm_size), _alpha(alpha), _beta(beta), _kappa(kappa), _is_scaled(is_scaled)
    {
    }
    /** Get the normalization type
     *
     * @return The normalization type
     */
    NormType type() const
    {
        return _type;
    }
    /** Get the normalization size
     *
     * @return The number of elements normalized across
     */
    uint32_t norm_size() const
    {
        return _norm_size;
    }
    /** Get the alpha value
     *
     * @return The alpha parameter
     */
    float alpha() const
    {
        return _alpha;
    }
    /** Get the beta value
     *
     * @return The beta parameter
     */
    float beta() const
    {
        return _beta;
    }
    /** Get the kappa value
     *
     * @return The kappa parameter
     */
    float kappa() const
    {
        return _kappa;
    }
    /** Check if normalization is cross map
     *
     * @return True if the type is @ref NormType::CROSS_MAP
     */
    bool is_cross_map() const
    {
        return _type == NormType::CROSS_MAP;
    }
    /** Check if normalization is not cross map
     *
     * @return True for the in-map types (the complement of is_cross_map())
     */
    bool is_in_map() const
    {
        return !is_cross_map();
    }
    /** Return the scaling factor of the normalization function.
     *
     * If is_scaled is set to false then [Krizhevsky 2012] normalization scaling is performed,
     * where alpha is returned plainly, else alpha is scaled by the total number of elements used for the normalization.
     *
     * @return The normalization scaling factor.
     */
    float scale_coeff() const
    {
        // IN_MAP_2D normalizes over a norm_size x norm_size window, hence the squared element count.
        const uint32_t size = (_type == NormType::IN_MAP_2D) ? _norm_size * _norm_size : _norm_size;
        return (_is_scaled) ? (_alpha / size) : _alpha;
    }

private:
    NormType _type;      /**< Normalization type */
    uint32_t _norm_size; /**< Number of elements normalized across */
    float    _alpha;     /**< Alpha parameter */
    float    _beta;      /**< Beta parameter */
    float    _kappa;     /**< Kappa parameter */
    bool     _is_scaled; /**< Whether alpha is divided by the normalization element count in scale_coeff() */
};
972
/** Convolution Layer Weights Information class. This class stores the necessary information to compute convolution layer when the weights are already reshaped */
class WeightsInfo
{
public:
    /** Default constructor: weights are flagged as not reshaped and every dimension is zero. */
    WeightsInfo()
        : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false)
    {
    }
    /** Constructor
     *
     * @param[in] are_reshaped            True if the weights have been reshaped
     * @param[in] kernel_width            Kernel width.
     * @param[in] kernel_height           Kernel height.
     * @param[in] num_kernels             Number of convolution kernels.
     * @param[in] retain_internal_weights (Optional) True if internal reshaped weights must be retained. Used for reconfiguration purposes. Default is false.
     */
    WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false)
        : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights)
    {
    }
    /** Whether the weights tensor has already been reshaped.
     *
     * @return True when the weights tensor is in its reshaped form.
     */
    bool are_reshaped() const { return _are_reshaped; }
    /** Number of convolution kernels described by these weights.
     *
     * @return The kernel count.
     */
    unsigned int num_kernels() const { return _num_kernels; }
    /** Kernel spatial dimensions.
     *
     * @return A (width, height) pair.
     */
    std::pair<unsigned int, unsigned int> kernel_size() const { return std::make_pair(_kernel_width, _kernel_height); }
    /** Whether the internally reshaped weights must be kept around (for reconfiguration purposes).
     *
     * @return True when the internal reshaped weights are retained.
     */
    bool retain_internal_weights() const { return _retain_internal_weights; }

private:
    const bool         _are_reshaped;            /**< Weights already reshaped */
    const unsigned int _kernel_width;            /**< Kernel width */
    const unsigned int _kernel_height;           /**< Kernel height */
    const unsigned int _num_kernels;             /**< Number of convolution kernels */
    const bool         _retain_internal_weights; /**< Keep internal reshaped weights */
};
1030
/** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape.
 *
 * The matrix A can only be reshaped through @ref CLGEMMInterleave4x4Kernel or @ref NEGEMMInterleave4x4Kernel or @ref GCGEMMInterleave4x4Kernel
 * Note: Optionally just for @ref CLGEMMInterleave4x4Kernel is it possible to set mult_interleave4x4_height, the multiplication factor for the height of the 4x4 interleaved block
 *
 * The matrix B can only be reshaped through @ref CLGEMMTranspose1xWKernel or @ref NEGEMMTranspose1xWKernel or @ref GCGEMMTranspose1xWKernel
 * Note: Optionally just for @ref CLGEMMTranspose1xWKernel is it possible to set mult_transpose1xW_width, the multiplication factor for the width of the 1xW transposed block
 */
class GEMMReshapeInfo final
{
public:
    /** Default constructor: unit dimensions and multiplication factors, no 3D reinterpretation of the input. */
    GEMMReshapeInfo()
        : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false)
    {
    }
    /** Constructor
     *
     * @param[in] m                         Number of matrix A rows
     * @param[in] n                         Number of matrix B columns
     * @param[in] k                         Number of matrix A columns or matrix B rows
     * @param[in] mult_transpose1xW_width   (Optional) Multiplication factor for the width of the 1xW transposed block
     * @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleaved block
     * @param[in] depth_output_gemm3d       (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
     * @param[in] reinterpret_input_as_3d   (Optional) Reinterpret the input as a 3D tensor (i.e. set to true when GEMM is used
     *                                      to perform 1x1 convolutions with the NHWC data layout)
     */
    GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false)
        : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d)
    {
    }
    /** @return The number of matrix A rows. */
    int m() const { return _m; }
    /** @return The number of matrix B columns. */
    int n() const { return _n; }
    /** @return The number of matrix A columns (== matrix B rows). */
    int k() const { return _k; }
    /** @return The multiplication factor for the width of the 1xW transposed block. */
    int mult_transpose1xW_width() const { return _mult_transpose1xW_width; }
    /** @return The multiplication factor for the height of the 4x4 interleaved block. */
    int mult_interleave4x4_height() const { return _mult_interleave4x4_height; }
    /** Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
     *
     * @note When the output has to be reinterpreted as a 3D tensor: m = depth_output_gemm3d * output_height
     *
     * @return The depth of the output tensor used by the GEMM3D kernel.
     */
    int depth_output_gemm3d() const { return _depth_output_gemm3d; }
    /** @return True if the input tensor has to be reinterpreted as 3D. */
    bool reinterpret_input_as_3d() const { return _reinterpret_input_as_3d; }

private:
    const int  _m;                         /**< Matrix A rows */
    const int  _n;                         /**< Matrix B columns */
    const int  _k;                         /**< Matrix A columns / matrix B rows */
    const int  _mult_transpose1xW_width;   /**< Width factor for the 1xW transposed block */
    const int  _mult_interleave4x4_height; /**< Height factor for the 4x4 interleaved block */
    const int  _depth_output_gemm3d;       /**< Output depth for the GEMM3D kernel */
    const bool _reinterpret_input_as_3d;   /**< Treat the input as a 3D tensor */
};
1133
/** GEMM information class. This class stores the necessary information to compute GEMM functions
 *
 * This object also contains the information about how matrix A and matrix B have been reshaped
 */
class GEMMInfo
{
public:
    /** Default constructor: nothing reshaped, output depth of 1, no 3D reinterpretation of the input. */
    GEMMInfo()
        : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false)
    {
    }
    /** Constructor
     *
     * @param[in] is_a_reshaped               True if the matrix A has been reshaped
     * @param[in] is_b_reshaped               True if the matrix B has been reshaped
     * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
     * @param[in] depth_output_gemm3d         (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
     * @param[in] reinterpret_input_as_3d     (Optional) Reinterpret the input as a 3D tensor (i.e. set to true when GEMM is used
     *                                        to perform 1x1 convolutions with the NHWC data layout)
     */
    GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false)
        : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d)
    {
    }
    /** @return True if the matrix A has been reshaped. */
    bool is_a_reshaped() const { return _is_a_reshaped; }
    /** @return True if the matrix B has been reshaped. */
    bool is_b_reshaped() const { return _is_b_reshaped; }
    /** Whether the reshape of matrix B should be executed only for the first run.
     *
     * @note This flag can be set to true when GEMM is used to accelerate convolution layers.
     *
     * @return True if matrix B is reshaped only on the first run.
     */
    bool reshape_b_only_on_first_run() const { return _reshape_b_only_on_first_run; }
    /** Depth of the output when the GEMM output is reinterpreted as a 3D tensor.
     *
     * @return The depth of the output tensor.
     */
    int depth_output_gemm3d() const { return _depth_output_gemm3d; }
    /** @return True if the input tensor has to be reinterpreted as 3D. */
    bool reinterpret_input_as_3d() const { return _reinterpret_input_as_3d; }

private:
    const bool _is_a_reshaped;               /**< Matrix A already reshaped */
    const bool _is_b_reshaped;               /**< Matrix B already reshaped */
    const bool _reshape_b_only_on_first_run; /**< Reshape B only on the first run */
    const int  _depth_output_gemm3d;         /**< Output depth for the GEMM3D kernel */
    const bool _reinterpret_input_as_3d;     /**< Treat the input as a 3D tensor */
};
1212
/** Winograd information */
struct WinogradInfo
{
    /** Default constructor
     *
     * @param[in] output_tile_sz Width and height of the output tile
     * @param[in] kernel_sz      Width and height of the kernel
     * @param[in] input_dims     Width and height of the input tensor before the convolution is applied
     * @param[in] conv_info      Convolution info (Pads, strides)
     * @param[in] data_layout    Data layout to use for the output tensor once the convolution has been applied
     */
    WinogradInfo(Size2D output_tile_sz, Size2D kernel_sz, Size2D input_dims, PadStrideInfo conv_info, DataLayout data_layout)
        : output_tile_size(output_tile_sz), kernel_size(kernel_sz), input_dimensions(input_dims), convolution_info(conv_info), output_data_layout(data_layout)
    {
    }

    // NOTE(review): the user-provided constructor suppresses the implicit default constructor, so the
    // in-class initializers below only take effect if that constructor is ever removed or defaulted.
    Size2D output_tile_size{};  /**< Width and height of the output tile */
    Size2D kernel_size{};       /**< Width and height of the kernel*/
    Size2D input_dimensions{};  /**< Width and height of the input tensor before the convolution is applied */
    PadStrideInfo convolution_info{}; /**< Convolution info (Pads, strides,...) */
    DataLayout output_data_layout{ DataLayout::NCHW }; /**< Data layout to use for the output tensor once the convolution has been applied (NCHW or NHWC) */
};
1235
/** IO formatting information class*/
struct IOFormatInfo
{
    /** Precision type used when printing floating point numbers */
    enum class PrecisionType
    {
        Default, /**< Default precision to the one that the current stream has */
        Custom,  /**< Custom precision specified by the user using the precision parameter */
        Full     /**< The maximum precision of the floating point representation */
    };

    /** Specifies the area to be printed, used by Tensor objects */
    enum class PrintRegion
    {
        ValidRegion, /**< Prints the valid region of the Tensor object */
        NoPadding,   /**< Prints the Tensor object without the padding */
        Full         /**< Print the tensor object including padding */
    };

    /** Construct a set of IO formatting information.
     *
     * @param[in] print_region   Area to be printed. Used by Tensor objects. Default: ValidRegion.
     * @param[in] precision_type Precision type for floating point numbers. Default: stream default.
     * @param[in] precision      Precision value for float point numbers. Default: 10.
     * @param[in] align_columns  Whether to align columns when printed. Default: true.
     * @param[in] element_delim  Delimiter between elements. Default: " ".
     * @param[in] row_delim      Delimiter between rows. Default: "\n".
     */
    IOFormatInfo(PrintRegion   print_region   = PrintRegion::ValidRegion,
                 PrecisionType precision_type = PrecisionType::Default,
                 unsigned int  precision      = 10,
                 bool          align_columns  = true,
                 std::string   element_delim  = " ",
                 std::string   row_delim      = "\n")
        : print_region(print_region), precision_type(precision_type), precision(precision), element_delim(element_delim), row_delim(row_delim), align_columns(align_columns)
    {
    }

    PrintRegion   print_region;   /**< Area of the tensor to be printed */
    PrecisionType precision_type; /**< Floating point precision type */
    unsigned int  precision;      /**< Floating point precision */
    std::string   element_delim;  /**< Delimiter between elements */
    std::string   row_delim;      /**< Delimiter between rows */
    bool          align_columns;  /**< Align columns when printing */
};
Isabella Gottardif07d28d2018-02-06 14:52:43 +00001292
/** Available methods for computing a convolution layer */
enum class ConvolutionMethod
{
    GEMM,    /**< Convolution using GEMM */
    DIRECT,  /**< Direct convolution */
    WINOGRAD /**< Convolution using Winograd */
};
Georgios Pinitasd8734b52017-12-22 15:27:52 +00001300} // namespace arm_compute
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001301#endif /* __ARM_COMPUTE_TYPES_H__ */