blob: 73baf789183053247c60cd4af4b5fff592c92122 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2016-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_TYPES_H__
25#define __ARM_COMPUTE_TYPES_H__
26
#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/QAsymm8.h"
#include "arm_compute/core/Rounding.h"
#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Strides.h"
#include "arm_compute/core/TensorShape.h"
#include "support/Half.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <string>
#include <tuple>
#include <utility>
40
41namespace arm_compute
42{
Georgios Pinitas583137c2017-08-31 18:12:42 +010043/** 16-bit floating point type */
44using half = half_float::half;
45
Georgios Pinitas8795ffb2017-12-01 16:13:40 +000046/** Permutation vector */
47using PermutationVector = Strides;
48
/** Image colour formats */
enum class Format
{
    UNKNOWN,  /**< Unknown image format */
    U8,       /**< Single channel holding one U8 value */
    S16,      /**< Single channel holding one S16 value */
    U16,      /**< Single channel holding one U16 value */
    S32,      /**< Single channel holding one S32 value */
    U32,      /**< Single channel holding one U32 value */
    F16,      /**< Single channel holding one F16 value */
    F32,      /**< Single channel holding one F32 value */
    UV88,     /**< Two channels, each holding one U8 value */
    RGB888,   /**< Three channels, each holding one U8 value */
    RGBA8888, /**< Four channels, each holding one U8 value */
    YUV444,   /**< Three planes (Y, U, V) of 8-bit samples with 4:4:4 sampling */
    YUYV422,  /**< Single plane of 32-bit macro pixels made of the bytes Y0, U0, Y1, V0 */
    NV12,     /**< Two-plane YUV: a luma (Y) plane plus an interleaved UV plane, 4:2:0 sampling */
    NV21,     /**< Two-plane YUV: a luma (Y) plane plus an interleaved VU plane, 4:2:0 sampling */
    IYUV,     /**< Three planes (Y, U, V) of 8-bit samples with 4:2:0 sampling */
    UYVY422   /**< Single plane of 32-bit macro pixels made of the bytes U0, Y0, V0, Y1 */
};
70
/** Available data types */
enum class DataType
{
    UNKNOWN, /**< Data type not known */
    U8,      /**< 8-bit unsigned integer */
    S8,      /**< 8-bit signed integer */
    QS8,     /**< 8-bit quantized value, symmetric fixed-point */
    QASYMM8, /**< 8-bit quantized value, asymmetric fixed-point */
    U16,     /**< 16-bit unsigned integer */
    S16,     /**< 16-bit signed integer */
    QS16,    /**< 16-bit quantized value, symmetric fixed-point */
    U32,     /**< 32-bit unsigned integer */
    S32,     /**< 32-bit signed integer */
    QS32,    /**< 32-bit quantized value, symmetric fixed-point */
    U64,     /**< 64-bit unsigned integer */
    S64,     /**< 64-bit signed integer */
    F16,     /**< 16-bit floating-point value */
    F32,     /**< 32-bit floating-point value */
    F64,     /**< 64-bit floating-point value */
    SIZET    /**< Value of type size_t */
};
92
/** Available Sampling Policies */
enum class SamplingPolicy
{
    CENTER,  /**< Sample taken at the center of each pixel */
    TOP_LEFT /**< Sample taken at the top-left corner of each pixel */
};
99
/** Value given to border pixels when BorderMode::CONSTANT is selected */
constexpr uint8_t CONSTANT_BORDER_VALUE = 199;

/** Scale factor identifying a half-scale pyramid */
constexpr float SCALE_PYRAMID_HALF = 0.5f;

/** Scale factor identifying an ORB-scaled pyramid */
constexpr float SCALE_PYRAMID_ORB = 8.408964152537146130583778358414e-01;
108
/** Supported tensor data layouts */
enum class DataLayout
{
    UNKNOWN, /**< Layout not known */
    NCHW,    /**< Ordering: batch (N), channels (C), height (H), width (W) */
    NHWC     /**< Ordering: batch (N), height (H), width (W), channels (C) */
};
116
/** Supported tensor data layout dimensions */
enum class DataLayoutDimension
{
    CHANNEL, /**< Channel dimension */
    HEIGHT,  /**< Height dimension */
    WIDTH,   /**< Width dimension */
    BATCHES  /**< Batch dimension */
};
125
Michel Iwaniec00633802017-10-12 14:14:15 +0100126/** Quantization settings (used for QASYMM8 data type) */
127struct QuantizationInfo
128{
Alex Gildayc357c472018-03-21 13:54:09 +0000129 /** Default constructor */
Michel Iwaniec00633802017-10-12 14:14:15 +0100130 QuantizationInfo()
131 : scale(0.0f), offset(0)
132 {
133 }
134
Alex Gildayc357c472018-03-21 13:54:09 +0000135 /** Construct quantization info.
136 *
137 * @param[in] scale Scale.
138 * @param[in] offset Offset.
139 */
Michel Iwaniec00633802017-10-12 14:14:15 +0100140 QuantizationInfo(float scale, int offset)
141 : scale(scale), offset(offset)
142 {
143 }
144
Alex Gildayc357c472018-03-21 13:54:09 +0000145 /** Check whether equal to a given quantization info.
146 *
147 * @param[in] other Other quantization info.
148 *
149 * @return True if the given quantization info is the same.
150 */
Daniil Efremoveed841c2017-11-09 19:05:25 +0700151 bool operator==(const QuantizationInfo &other)
152 {
153 return scale == other.scale && offset == other.offset;
154 }
155
Alex Gildayc357c472018-03-21 13:54:09 +0000156 /** Check whether not equal to a given quantization info.
157 *
158 * @param[in] other Other quantization info.
159 *
160 * @return True if the given quantization info is not the same.
161 */
Daniil Efremoveed841c2017-11-09 19:05:25 +0700162 bool operator!=(const QuantizationInfo &other)
163 {
164 return !(*this == other);
165 }
166
Michel Iwaniec00633802017-10-12 14:14:15 +0100167 float scale; /**< scale */
168 int offset; /**< offset */
169
Alex Gildayc357c472018-03-21 13:54:09 +0000170 /** Quantizes a value using the scale/offset in this QuantizationInfo
171 *
172 * @param[in] value Value to quantize.
173 * @param[in] rounding_policy Policy to use when rounding.
174 *
175 * @return the quantized value.
176 */
Michel Iwaniec5dfeae62017-11-29 10:48:23 +0000177 qasymm8_t quantize(float value, RoundingPolicy rounding_policy) const
Michel Iwaniec00633802017-10-12 14:14:15 +0100178 {
179 ARM_COMPUTE_ERROR_ON_MSG(scale == 0, "QuantizationInfo::quantize: scale == 0");
Michel Iwaniec5dfeae62017-11-29 10:48:23 +0000180 return sqcvt_qasymm8_f32(value, scale, offset, rounding_policy);
Michel Iwaniec00633802017-10-12 14:14:15 +0100181 }
182
Alex Gildayc357c472018-03-21 13:54:09 +0000183 /** Dequantizes a value using the scale/offset in this QuantizationInfo
184 *
185 * @param[in] value Value to dequantize.
186 *
187 * @return the original value before quantization.
188 */
Michel Iwaniec5dfeae62017-11-29 10:48:23 +0000189 float dequantize(qasymm8_t value) const
Michel Iwaniec00633802017-10-12 14:14:15 +0100190 {
191 ARM_COMPUTE_ERROR_ON_MSG(scale == 0, "QuantizationInfo::dequantize: scale == 0");
Michel Iwaniec5dfeae62017-11-29 10:48:23 +0000192 return scvt_f32_qasymm8(value, scale, offset);
Michel Iwaniec00633802017-10-12 14:14:15 +0100193 }
194
Alex Gildayc357c472018-03-21 13:54:09 +0000195 /** Indicates whether this QuantizationInfo has valid settings or not
196 *
197 * @return True if the this has invalid settings.
198 */
Michel Iwaniec00633802017-10-12 14:14:15 +0100199 bool empty() const
200 {
201 return scale == 0;
202 }
203};
204
Alex Gildayc357c472018-03-21 13:54:09 +0000205/** Container for valid region of a window */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100206struct ValidRegion
207{
Alex Gildayc357c472018-03-21 13:54:09 +0000208 /** Default constructor */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100209 ValidRegion()
210 : anchor{}, shape{}
211 {
212 }
213
Alex Gildayc357c472018-03-21 13:54:09 +0000214 /** Allow instances of this class to be copy constructed */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100215 ValidRegion(const ValidRegion &) = default;
Alex Gildayc357c472018-03-21 13:54:09 +0000216 /** Allow instances of this class to be move constructed */
217 ValidRegion(ValidRegion &&) = default;
218 /** Allow instances of this class to be copied */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100219 ValidRegion &operator=(const ValidRegion &) = default;
Alex Gildayc357c472018-03-21 13:54:09 +0000220 /** Allow instances of this class to be moved */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100221 ValidRegion &operator=(ValidRegion &&) = default;
Alex Gildayc357c472018-03-21 13:54:09 +0000222 /** Default destructor */
223 ~ValidRegion() = default;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100224
Alex Gildayc357c472018-03-21 13:54:09 +0000225 /** Constructor for a valid region with default number of dimensions
226 *
227 * @param[in] an_anchor Anchor for the start of the valid region.
228 * @param[in] a_shape Shape of the valid region.
229 *
230 */
Diego Lopez Recasbcbc9702017-12-18 11:28:27 +0000231 ValidRegion(const Coordinates &an_anchor, const TensorShape &a_shape)
232 : anchor{ an_anchor }, shape{ a_shape }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100233 {
Diego Lopez Recasbcbc9702017-12-18 11:28:27 +0000234 anchor.set_num_dimensions(std::max(anchor.num_dimensions(), shape.num_dimensions()));
235 }
236
Alex Gildayc357c472018-03-21 13:54:09 +0000237 /** Constructor for a valid region with specified number of dimensions
238 *
239 * @param[in] an_anchor Anchor for the start of the valid region.
240 * @param[in] a_shape Shape of the valid region.
241 * @param[in] num_dimensions Number of dimensions (must be >= number of dimensions of anchor and shape).
242 *
243 */
Diego Lopez Recasbcbc9702017-12-18 11:28:27 +0000244 ValidRegion(const Coordinates &an_anchor, const TensorShape &a_shape, size_t num_dimensions)
245 : anchor{ an_anchor }, shape{ a_shape }
246 {
247 ARM_COMPUTE_ERROR_ON(num_dimensions < std::max(anchor.num_dimensions(), shape.num_dimensions()));
248 anchor.set_num_dimensions(num_dimensions);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100249 }
250
251 /** Return the start of the valid region for the given dimension @p d */
252 int start(unsigned int d) const
253 {
254 return anchor[d];
255 }
256
257 /** Return the end of the valid region for the given dimension @p d */
258 int end(unsigned int d) const
259 {
260 return anchor[d] + shape[d];
261 }
262
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000263 /** Accessor to set the value of anchor and shape for one of the dimensions.
264 *
265 * @param[in] dimension Dimension for which the value is set.
266 * @param[in] start Value to be set in anchor for the dimension.
267 * @param[in] size Value to be set in shape for the dimension.
268 *
269 * @return *this.
270 */
271 ValidRegion &set(size_t dimension, int start, size_t size)
272 {
273 anchor.set(dimension, start);
274 shape.set(dimension, size);
275 return *this;
276 }
277
Alex Gildayc357c472018-03-21 13:54:09 +0000278 Coordinates anchor; /**< Anchor for the start of the valid region. */
279 TensorShape shape; /**< Shape of the valid region. */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100280};
281
/** Methods available to handle borders */
enum class BorderMode
{
    UNDEFINED, /**< Border pixels are not defined */
    CONSTANT,  /**< Pixels outside the image take a constant value */
    REPLICATE  /**< Pixels outside the image copy the value of the nearest image pixel */
};
289
/** Container for 2D border size */
struct BorderSize
{
    /** Empty border, i.e. no border */
    constexpr BorderSize()
        : top{ 0 }, right{ 0 }, bottom{ 0 }, left{ 0 }
    {
    }

    /** Border with equal size around the 2D plane */
    explicit constexpr BorderSize(unsigned int size)
        : top{ size }, right{ size }, bottom{ size }, left{ size }
    {
    }

    /** Border with same size for top/bottom and left/right */
    constexpr BorderSize(unsigned int top_bottom, unsigned int left_right)
        : top{ top_bottom }, right{ left_right }, bottom{ top_bottom }, left{ left_right }
    {
    }

    /** Border with different sizes */
    constexpr BorderSize(unsigned int top, unsigned int right, unsigned int bottom, unsigned int left)
        : top{ top }, right{ right }, bottom{ bottom }, left{ left }
    {
    }

    /** Check if the entire border is zero */
    constexpr bool empty() const
    {
        return top == 0 && right == 0 && bottom == 0 && left == 0;
    }

    /** Check if the border is the same size on all sides */
    constexpr bool uniform() const
    {
        return top == right && top == bottom && top == left;
    }

    /** Scale this border size in place.
     *
     * Each side is multiplied by @p scale in floating point and the product is
     * truncated back to unsigned int (the conversion is spelled out explicitly).
     *
     * @param[in] scale Scale to multiply border size by.
     *
     * @return *this.
     */
    BorderSize &operator*=(float scale)
    {
        top    = static_cast<unsigned int>(top * scale);
        right  = static_cast<unsigned int>(right * scale);
        bottom = static_cast<unsigned int>(bottom * scale);
        left   = static_cast<unsigned int>(left * scale);

        return *this;
    }

    /** Scale a copy of this border size.
     *
     * const-qualified: the receiver is not modified, so this also works on const objects.
     *
     * @param[in] scale Scale to multiply border size by.
     *
     * @return a scaled copy of this.
     */
    BorderSize operator*(float scale) const
    {
        BorderSize size = *this;
        size *= scale;

        return size;
    }

    /** Limit this border size.
     *
     * Each side is clamped to at most the corresponding side of @p limit.
     *
     * @param[in] limit Border size to limit this border size to.
     */
    void limit(const BorderSize &limit)
    {
        top    = std::min(top, limit.top);
        right  = std::min(right, limit.right);
        bottom = std::min(bottom, limit.bottom);
        left   = std::min(left, limit.left);
    }

    unsigned int top;    /**< top of the border */
    unsigned int right;  /**< right of the border */
    unsigned int bottom; /**< bottom of the border */
    unsigned int left;   /**< left of the border */
};
376
Alex Gildayc357c472018-03-21 13:54:09 +0000377/** Container for 2D padding size */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100378using PaddingSize = BorderSize;
379
/** Policy to handle overflow */
enum class ConvertPolicy
{
    WRAP,    /**< Values wrap around on overflow */
    SATURATE /**< Values saturate on overflow */
};
386
/** Interpolation method */
enum class InterpolationPolicy
{
    NEAREST_NEIGHBOR, /**< Each output value matches the source pixel whose center is closest to the sample position */
    BILINEAR,         /**< Each output value is a bilinear interpolation between source pixels */
    AREA              /**< Each output value averages the source pixels whose areas fall under the destination pixel's area projected onto the source image */
};
394
/** Bilinear Interpolation method used by LKTracker */
enum class BilinearInterpolation
{
    BILINEAR_OLD_NEW, /**< Old/new method */
    BILINEAR_SCHARR   /**< Scharr method */
};
401
/** Threshold mode */
enum class ThresholdType
{
    BINARY, /**< Threshold using a single value */
    RANGE   /**< Threshold using a pair of values */
};
408
/** Termination criteria */
enum class Termination
{
    TERM_CRITERIA_EPSILON,    /**< Stop once within epsilon of a threshold */
    TERM_CRITERIA_ITERATIONS, /**< Stop after a maximum number of iterations */
    TERM_CRITERIA_BOTH        /**< Stop on whichever of the two conditions above occurs first */
};
416
/** Magnitude calculation type. */
enum class MagnitudeType
{
    L1NORM, /**< L1 norm */
    L2NORM  /**< L2 norm */
};
423
/** Phase calculation type.
 *
 * @note When PhaseType == SIGNED, each angle is mapped to the range 0 to 255 inclusive;
 *       otherwise angles lie between 0 and 180.
 */
enum class PhaseType
{
    SIGNED,  /**< Angle range: [0, 360] */
    UNSIGNED /**< Angle range: [0, 180] */
};
433
/** Keypoint type; every field defaults to zero */
struct KeyPoint
{
    int32_t x{ 0 };               /**< X coordinate */
    int32_t y{ 0 };               /**< Y coordinate */
    float   strength{ 0.f };      /**< Strength of the point */
    float   scale{ 0.f };         /**< Scale; the corner detector initializes it to 0 */
    float   orientation{ 0.f };   /**< Orientation; the corner detector initializes it to 0 */
    int32_t tracking_status{ 0 }; /**< Status; the corner detector sets it to 1, and it becomes 0 when the point is lost */
    float   error{ 0.f };         /**< Tracking error; the corner detector initializes it to 0 */
};
445
/** Internal key point stored as the triple (x, y, strength) */
using InternalKeypoint = std::tuple<float, float, float>;
448
/** Rectangle type: a top-left corner plus an extent */
struct Rectangle
{
    uint16_t x;      /**< X coordinate of the top-left corner */
    uint16_t y;      /**< Y coordinate of the top-left corner */
    uint16_t width;  /**< Rectangle width */
    uint16_t height; /**< Rectangle height */
};
457
/** Signed 2D coordinate type */
struct Coordinates2D
{
    int32_t x; /**< X coordinate */
    int32_t y; /**< Y coordinate */
};
464
/** Unsigned 3D coordinate type */
struct Coordinates3D
{
    uint32_t x; /**< X coordinate */
    uint32_t y; /**< Y coordinate */
    uint32_t z; /**< Z coordinate */
};
472
Georgios Pinitas7b7858d2017-06-21 16:44:24 +0100473/** Region of interest */
474struct ROI
475{
476 Rectangle rect; /**< Rectangle specifying the region of interest */
477 uint16_t batch_idx; /**< The batch index of the region of interest */
478};
479
/** Available channels */
enum class Channel
{
    UNKNOWN, /**< Unknown channel format */
    C0,      /**< First channel (used by formats with unknown channel types). */
    C1,      /**< Second channel (used by formats with unknown channel types). */
    C2,      /**< Third channel (used by formats with unknown channel types). */
    C3,      /**< Fourth channel (used by formats with unknown channel types). */
    R,       /**< Red channel. */
    G,       /**< Green channel. */
    B,       /**< Blue channel. */
    A,       /**< Alpha channel. */
    Y,       /**< Luma channel. */
    U,       /**< Cb/U channel. */
    V        /**< Cr/V/Value channel. */
};
496
/** Available matrix patterns */
enum class MatrixPattern
{
    BOX,   /**< Box-shaped matrix pattern. */
    CROSS, /**< Cross-shaped matrix pattern. */
    DISK,  /**< Disk-shaped matrix pattern. */
    OTHER  /**< Any pattern not listed above. */
};
505
/** Available non linear functions (unsigned underlying type, explicit values). */
enum class NonLinearFilterFunction : unsigned
{
    MEDIAN = 0, /**< Median filter. */
    MIN    = 1, /**< Erode (minimum) filter. */
    MAX    = 2  /**< Dilate (maximum) filter. */
};
513
/** Available reduction operations */
enum class ReductionOperation
{
    SUM_SQUARE, /**< Sum of the squared values */
    SUM         /**< Plain sum */
};
520
/** The normalization type used for the normalization layer */
enum class NormType
{
    IN_MAP_1D, /**< Normalize within a single map over a 1D region */
    IN_MAP_2D, /**< Normalize within a single map over a 2D region */
    CROSS_MAP  /**< Normalize across maps */
};
528
/** Normalization type for Histogram of Oriented Gradients (HOG); values start at 1 */
enum class HOGNormType
{
    L2_NORM    = 1, /**< L2 norm */
    L2HYS_NORM = 2, /**< L2 norm followed by clipping */
    L1_NORM    = 3  /**< L1 norm */
};
536
/** Detection window used for object detection.
 *
 * It records:
 * -# the rectangular geometry (top-left x/y plus width/height),
 * -# the index of the class the window is evaluated against,
 * -# the confidence score produced by the classifier.
 *
 * Every field defaults to zero.
 */
struct DetectionWindow
{
    uint16_t x{ 0 };         /**< X coordinate of the top-left corner */
    uint16_t y{ 0 };         /**< Y coordinate of the top-left corner */
    uint16_t width{ 0 };     /**< Window width */
    uint16_t height{ 0 };    /**< Window height */
    uint16_t idx_class{ 0 }; /**< Class index */
    float    score{ 0.f };   /**< Confidence score of the window */
};
552
/** Rounding applied to dimensions when down-scaling in CNNs
 * @note Used by the pooling and convolution layers
 */
enum class DimensionRoundingType
{
    FLOOR, /**< Round down */
    CEIL   /**< Round up */
};
561
/** Available pooling types */
enum class PoolingType
{
    MAX, /**< Maximum pooling */
    AVG, /**< Average pooling */
    L2   /**< L2 pooling */
};
569
570/** Padding and stride information class */
571class PadStrideInfo
572{
573public:
574 /** Constructor
575 *
576 * @param[in] stride_x (Optional) Stride, in elements, across x. Defaults to 1.
577 * @param[in] stride_y (Optional) Stride, in elements, across y. Defaults to 1.
578 * @param[in] pad_x (Optional) Padding, in elements, across x. Defaults to 0.
579 * @param[in] pad_y (Optional) Padding, in elements, across y. Defaults to 0.
580 * @param[in] round (Optional) Dimensions rounding. Defaults to @ref FLOOR.
581 */
582 PadStrideInfo(unsigned int stride_x = 1, unsigned int stride_y = 1,
583 unsigned int pad_x = 0, unsigned int pad_y = 0,
584 DimensionRoundingType round = DimensionRoundingType::FLOOR)
585 : _stride(std::make_pair(stride_x, stride_y)),
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100586 _pad_left(pad_x),
587 _pad_top(pad_y),
588 _pad_right(pad_x),
589 _pad_bottom(pad_y),
590 _round_type(round)
591 {
592 }
593 /** Constructor
594 *
595 * @param[in] stride_x Stride, in elements, across x.
596 * @param[in] stride_y Stride, in elements, across y.
597 * @param[in] pad_left Padding across x on the left, in elements.
598 * @param[in] pad_top Padding across y on the top, in elements.
599 * @param[in] pad_right Padding across x on the right, in elements.
600 * @param[in] pad_bottom Padding across y on the bottom, in elements.
601 * @param[in] round Dimensions rounding.
602 */
603 PadStrideInfo(unsigned int stride_x, unsigned int stride_y,
604 unsigned int pad_left, unsigned int pad_right,
605 unsigned int pad_top, unsigned int pad_bottom,
606 DimensionRoundingType round)
607 : _stride(std::make_pair(stride_x, stride_y)),
608 _pad_left(pad_left),
609 _pad_top(pad_top),
610 _pad_right(pad_right),
611 _pad_bottom(pad_bottom),
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100612 _round_type(round)
613 {
614 }
Alex Gildayc357c472018-03-21 13:54:09 +0000615 /** Get the stride.
616 *
617 * @return a pair: stride x, stride y.
618 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100619 std::pair<unsigned int, unsigned int> stride() const
620 {
621 return _stride;
622 }
Alex Gildayc357c472018-03-21 13:54:09 +0000623 /** Check whether the padding is symmetric.
624 *
625 * @return True if the padding is symmetric.
626 */
Anthony Barbier21f67d62018-02-16 15:17:48 +0000627 bool padding_is_symmetric() const
628 {
629 return (_pad_left == _pad_right) && (_pad_top == _pad_bottom);
630 }
Alex Gildayc357c472018-03-21 13:54:09 +0000631 /** Get the padding.
632 *
633 * @note This should only be used when the padding is symmetric.
634 *
635 * @return a pair: padding left/right, padding top/bottom
636 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100637 std::pair<unsigned int, unsigned int> pad() const
638 {
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100639 //this accessor should be used only when padding is symmetric
Anthony Barbier21f67d62018-02-16 15:17:48 +0000640 ARM_COMPUTE_ERROR_ON(!padding_is_symmetric());
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100641 return std::make_pair(_pad_left, _pad_top);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100642 }
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100643
Alex Gildayc357c472018-03-21 13:54:09 +0000644 /** Get the left padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100645 unsigned int pad_left() const
646 {
647 return _pad_left;
648 }
Alex Gildayc357c472018-03-21 13:54:09 +0000649 /** Get the right padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100650 unsigned int pad_right() const
651 {
652 return _pad_right;
653 }
Alex Gildayc357c472018-03-21 13:54:09 +0000654 /** Get the top padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100655 unsigned int pad_top() const
656 {
657 return _pad_top;
658 }
Alex Gildayc357c472018-03-21 13:54:09 +0000659 /** Get the bottom padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100660 unsigned int pad_bottom() const
661 {
662 return _pad_bottom;
663 }
664
Alex Gildayc357c472018-03-21 13:54:09 +0000665 /** Get the rounding type */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100666 DimensionRoundingType round() const
667 {
668 return _round_type;
669 }
670
Alex Gildayc357c472018-03-21 13:54:09 +0000671 /** Check whether this has any padding */
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100672 bool has_padding() const
673 {
674 return (_pad_left != 0 || _pad_top != 0 || _pad_right != 0 || _pad_bottom != 0);
675 }
676
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100677private:
678 std::pair<unsigned int, unsigned int> _stride;
Jaroslaw Rzepeckia1ed41f2017-10-13 11:13:58 +0100679 unsigned int _pad_left;
680 unsigned int _pad_top;
681 unsigned int _pad_right;
682 unsigned int _pad_bottom;
683
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100684 DimensionRoundingType _round_type;
685};
686
687/** Pooling Layer Information class */
688class PoolingLayerInfo
689{
690public:
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000691 /** Default Constructor */
692 PoolingLayerInfo()
Isabella Gottardi6e464c32018-01-26 12:32:45 +0000693 : _pool_type(PoolingType::MAX), _pool_size(Size2D()), _pad_stride_info(PadStrideInfo()), _exclude_padding(false), _is_global_pooling(false)
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000694 {
695 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100696 /** Default Constructor
697 *
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000698 * @param[in] pool_type Pooling type @ref PoolingType.
699 * @param[in] pool_size Pooling size, in elements, across x and y.
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100700 * @param[in] pad_stride_info (Optional) Padding and stride information @ref PadStrideInfo
Georgios Pinitasadaae7e2017-10-30 15:56:32 +0000701 * @param[in] exclude_padding (Optional) Strategy when accounting padding in calculations.
702 * True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area).
703 * Defaults to false;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100704 */
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000705 explicit PoolingLayerInfo(PoolingType pool_type,
706 unsigned int pool_size,
707 PadStrideInfo pad_stride_info = PadStrideInfo(),
708 bool exclude_padding = false)
Isabella Gottardi6e464c32018-01-26 12:32:45 +0000709 : _pool_type(pool_type), _pool_size(Size2D(pool_size, pool_size)), _pad_stride_info(pad_stride_info), _exclude_padding(exclude_padding), _is_global_pooling(false)
710 {
711 }
712 /** Default Constructor
713 *
714 * @param[in] pool_type Pooling type @ref PoolingType.
715 * @param[in] pool_size Pooling size, in elements, across x and y.
716 * @param[in] pad_stride_info (Optional) Padding and stride information @ref PadStrideInfo
717 * @param[in] exclude_padding (Optional) Strategy when accounting padding in calculations.
718 * True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area).
719 * Defaults to false;
720 */
721 explicit PoolingLayerInfo(PoolingType pool_type,
722 Size2D pool_size,
723 PadStrideInfo pad_stride_info = PadStrideInfo(),
724 bool exclude_padding = false)
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000725 : _pool_type(pool_type), _pool_size(pool_size), _pad_stride_info(pad_stride_info), _exclude_padding(exclude_padding), _is_global_pooling(false)
726 {
727 }
728 /** Default Constructor
729 *
730 * @note This constructor is used for global pooling
731 *
732 * @param[in] pool_type Pooling type @ref PoolingType.
733 */
734 explicit PoolingLayerInfo(PoolingType pool_type)
Isabella Gottardi6e464c32018-01-26 12:32:45 +0000735 : _pool_type(pool_type), _pool_size(Size2D()), _pad_stride_info(PadStrideInfo(1, 1, 0, 0)), _exclude_padding(false), _is_global_pooling(true)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100736 {
737 }
Alex Gildayc357c472018-03-21 13:54:09 +0000738 /** Get the pooling type */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100739 PoolingType pool_type() const
740 {
741 return _pool_type;
742 }
Alex Gildayc357c472018-03-21 13:54:09 +0000743 /** Get the pooling size */
Isabella Gottardi6e464c32018-01-26 12:32:45 +0000744 const Size2D &pool_size() const
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100745 {
746 return _pool_size;
747 }
Alex Gildayc357c472018-03-21 13:54:09 +0000748 /** Get the padding and stride */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100749 PadStrideInfo pad_stride_info() const
750 {
751 return _pad_stride_info;
752 }
Alex Gildayc357c472018-03-21 13:54:09 +0000753 /** Check if padding is excluded in calculations */
Georgios Pinitasadaae7e2017-10-30 15:56:32 +0000754 bool exclude_padding() const
755 {
756 return _exclude_padding;
757 }
Alex Gildayc357c472018-03-21 13:54:09 +0000758 /** Check if is global pooling */
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000759 bool is_global_pooling() const
760 {
761 return _is_global_pooling;
762 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100763
764private:
765 PoolingType _pool_type;
Isabella Gottardi6e464c32018-01-26 12:32:45 +0000766 Size2D _pool_size;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100767 PadStrideInfo _pad_stride_info;
Georgios Pinitasadaae7e2017-10-30 15:56:32 +0000768 bool _exclude_padding;
Georgios Pinitas4c2dd542017-11-13 12:58:41 +0000769 bool _is_global_pooling;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100770};
771
/** ROI Pooling Layer Information class
 *
 * Value holder for the pooled output size of an ROI pooling layer and the scale
 * applied to the ROI coordinates and dimensions.
 */
class ROIPoolingLayerInfo
{
public:
    /** Constructor
     *
     * @param[in] pooled_width  Pooled width of the layer.
     * @param[in] pooled_height Pooled height of the layer.
     * @param[in] spatial_scale Spatial scale to be applied to the ROI coordinates and dimensions.
     */
    ROIPoolingLayerInfo(unsigned int pooled_width, unsigned int pooled_height, float spatial_scale)
        : _pooled_width{ pooled_width },
          _pooled_height{ pooled_height },
          _spatial_scale{ spatial_scale }
    {
    }
    /** Pooled width accessor
     *
     * @return The pooled width of the layer.
     */
    unsigned int pooled_width() const
    {
        return this->_pooled_width;
    }
    /** Pooled height accessor
     *
     * @return The pooled height of the layer.
     */
    unsigned int pooled_height() const
    {
        return this->_pooled_height;
    }
    /** Spatial scale accessor
     *
     * @return The scale applied to the ROI coordinates and dimensions.
     */
    float spatial_scale() const
    {
        return this->_spatial_scale;
    }

private:
    unsigned int _pooled_width;
    unsigned int _pooled_height;
    float        _spatial_scale;
};
807
/** Activation Layer Information class
 *
 * A default-constructed object is disabled (@ref enabled() returns false); every object
 * created through the parameterized constructor is enabled.
 */
class ActivationLayerInfo
{
public:
    /** Available activation functions */
    enum class ActivationFunction
    {
        LOGISTIC,        /**< Logistic ( \f$ f(x) = \frac{1}{1 + e^{-x}} \f$ ) */
        TANH,            /**< Hyperbolic tangent ( \f$ f(x) = a \cdot tanh(b \cdot x) \f$ ) */
        RELU,            /**< Rectifier ( \f$ f(x) = max(0,x) \f$ ) */
        BOUNDED_RELU,    /**< Upper Bounded Rectifier ( \f$ f(x) = min(a, max(0,x)) \f$ ) */
        LU_BOUNDED_RELU, /**< Lower and Upper Bounded Rectifier ( \f$ f(x) = min(a, max(b,x)) \f$ ) */
        LEAKY_RELU,      /**< Leaky Rectifier ( \f$ f(x) = \begin{cases} a \cdot x & \text{if } x < 0 \\ x & \text{if } x \geq 0 \end{cases} \f$ ) */
        SOFT_RELU,       /**< Soft Rectifier ( \f$ f(x)= log(1+e^x) \f$ ) */
        ABS,             /**< Absolute ( \f$ f(x)= |x| \f$ ) */
        SQUARE,          /**< Square ( \f$ f(x)= x^2 \f$ ) */
        SQRT,            /**< Square root ( \f$ f(x) = \sqrt{x} \f$ ) */
        LINEAR           /**< Linear ( \f$ f(x)= ax + b \f$ ) */
    };

    /** Default constructor: creates a disabled activation (@ref enabled() returns false). */
    ActivationLayerInfo() = default;
    /** Constructor
     *
     * @param[in] f The activation function to use.
     * @param[in] a (Optional) The alpha parameter used by some activation functions
     *              (@ref ActivationFunction::BOUNDED_RELU, @ref ActivationFunction::LU_BOUNDED_RELU, @ref ActivationFunction::LEAKY_RELU,
     *              @ref ActivationFunction::LINEAR, @ref ActivationFunction::TANH).
     * @param[in] b (Optional) The beta parameter used by some activation functions
     *              (@ref ActivationFunction::LINEAR, @ref ActivationFunction::LU_BOUNDED_RELU, @ref ActivationFunction::TANH).
     */
    ActivationLayerInfo(ActivationFunction f, float a = 0.0f, float b = 0.0f)
        : _act(f), _a(a), _b(b), _enabled(true)
    {
    }
    /** Get the type of activation function
     *
     * @return The activation function (LOGISTIC for a default-constructed object).
     */
    ActivationFunction activation() const
    {
        return _act;
    }
    /** Get the alpha value
     *
     * @return The alpha parameter (0 unless set at construction).
     */
    float a() const
    {
        return _a;
    }
    /** Get the beta value
     *
     * @return The beta parameter (0 unless set at construction).
     */
    float b() const
    {
        return _b;
    }
    /** Check if the activation is enabled
     *
     * @return True if the object was built with an activation function, false if default-constructed.
     */
    bool enabled() const
    {
        return _enabled;
    }

private:
    ActivationFunction _act     = { ActivationLayerInfo::ActivationFunction::LOGISTIC }; // Value reported while disabled
    float              _a       = {};
    float              _b       = {};
    bool               _enabled = { false }; // Only the parameterized constructor enables the activation
};
867
868/** Normalization Layer Information class */
869class NormalizationLayerInfo
870{
871public:
872 /** Default Constructor
873 *
874 * @param[in] type The normalization type. Can be @ref NormType::IN_MAP_1D, @ref NormType::IN_MAP_2D or @ref NORM_TYPE::CROSS_MAP
875 * @param[in] norm_size The normalization size is the number of elements to normalize across. Defaults to 5.
Georgios Pinitas41caa622017-11-16 14:37:08 +0000876 * @param[in] alpha (Optional) Alpha parameter used by normalization equation. Defaults to 0.0001.
877 * @param[in] beta (Optional) Beta parameter used by normalization equation. Defaults to 0.5.
878 * @param[in] kappa (Optional) Kappa parameter used by [Krichevksy 2012] Across Channel Local Brightness Normalization equation.
879 * @param[in] is_scaled (Optional) Boolean that specifies if alpha will be scaled by the normalization size or not.
880 * Should be false to follow [Krichevksy 2012].
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100881 */
Georgios Pinitas41caa622017-11-16 14:37:08 +0000882 NormalizationLayerInfo(NormType type, uint32_t norm_size = 5, float alpha = 0.0001f, float beta = 0.5f, float kappa = 1.f, bool is_scaled = true)
883 : _type(type), _norm_size(norm_size), _alpha(alpha), _beta(beta), _kappa(kappa), _is_scaled(is_scaled)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100884 {
885 }
Alex Gildayc357c472018-03-21 13:54:09 +0000886 /** Get the normalization type */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100887 NormType type() const
888 {
889 return _type;
890 }
Alex Gildayc357c472018-03-21 13:54:09 +0000891 /** Get the normalization size */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100892 uint32_t norm_size() const
893 {
894 return _norm_size;
895 }
Alex Gildayc357c472018-03-21 13:54:09 +0000896 /** Get the alpha value */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100897 float alpha() const
898 {
899 return _alpha;
900 }
Alex Gildayc357c472018-03-21 13:54:09 +0000901 /** Get the beta value */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100902 float beta() const
903 {
904 return _beta;
905 }
Alex Gildayc357c472018-03-21 13:54:09 +0000906 /** Get the kappa value */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100907 float kappa() const
908 {
909 return _kappa;
910 }
Alex Gildayc357c472018-03-21 13:54:09 +0000911 /** Check if normalization is cross map */
Georgios Pinitas41caa622017-11-16 14:37:08 +0000912 bool is_cross_map() const
913 {
914 return _type == NormType::CROSS_MAP;
915 }
Alex Gildayc357c472018-03-21 13:54:09 +0000916 /** Check if normalization is not cross map */
Georgios Pinitas41caa622017-11-16 14:37:08 +0000917 bool is_in_map() const
918 {
919 return !is_cross_map();
920 }
921 /** Return the scaling factor of the normalization function.
922 *
923 * If is_scaled is set to false then [Krichevksy 2012] normalization scaling is performed,
924 * where alpha is returned plainly, else alpha is scaled by the total number of elements used for the normalization.
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100925 *
926 * @return The normalization scaling factor.
927 */
928 float scale_coeff() const
929 {
930 const uint32_t size = (_type == NormType::IN_MAP_2D) ? _norm_size * _norm_size : _norm_size;
Georgios Pinitas41caa622017-11-16 14:37:08 +0000931 return (_is_scaled) ? (_alpha / size) : _alpha;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100932 }
933
934private:
935 NormType _type;
936 uint32_t _norm_size;
937 float _alpha;
938 float _beta;
939 float _kappa;
Georgios Pinitas41caa622017-11-16 14:37:08 +0000940 bool _is_scaled;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100941};
942
/** Convolution Layer Weights Information class. This class stores the necessary information to compute convolution layer when the weights are already reshaped */
class WeightsInfo
{
public:
    /** Default constructor
     *
     * Describes weights that have not been reshaped, with kernel width, kernel height and
     * number of kernels all set to 0.
     */
    WeightsInfo()
        : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0)
    {
    }
    /** Constructor
     *
     * @param[in] are_reshaped  True if the weights have been reshaped
     * @param[in] kernel_width  Kernel width.
     * @param[in] kernel_height Kernel height.
     * @param[in] num_kernels   Number of convolution kernels.
     */
    WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels)
        : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels)
    {
    }
    /** Flag which specifies if the weights tensor has been reshaped.
     *
     * @return True if the weights tensor has been reshaped
     */
    bool are_reshaped() const
    {
        return _are_reshaped;
    }
    /** Return the number of convolution kernels
     *
     * @return The number of convolution kernels
     */
    unsigned int num_kernels() const
    {
        return _num_kernels;
    }
    /** Return the width and height of the kernel
     *
     * @return The width (first) and height (second) of the kernel
     */
    std::pair<unsigned int, unsigned int> kernel_size() const
    {
        return std::make_pair(_kernel_width, _kernel_height);
    }

private:
    // Intentionally non-const: const data members would implicitly delete the copy/move
    // assignment operators, making WeightsInfo objects impossible to reassign.
    bool         _are_reshaped;
    unsigned int _kernel_width;
    unsigned int _kernel_height;
    unsigned int _num_kernels;
};
994
/** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape.
 *
 * The matrix A can only be reshaped through @ref CLGEMMInterleave4x4Kernel or @ref NEGEMMInterleave4x4Kernel or @ref GCGEMMInterleave4x4Kernel
 * Note: mult_interleave4x4_height, the multiplication factor for the height of the 4x4 interleaved block, can optionally be set only for @ref CLGEMMInterleave4x4Kernel
 *
 * The matrix B can only be reshaped through @ref CLGEMMTranspose1xWKernel or @ref NEGEMMTranspose1xWKernel or @ref GCGEMMTranspose1xWKernel
 * Note: mult_transpose1xW_width, the multiplication factor for the width of the 1xW transposed block, can optionally be set only for @ref CLGEMMTranspose1xWKernel
 *
 */
class GEMMReshapeInfo final
{
public:
    /** Default constructor: m, n, k and both multiplication factors are set to 1. */
    GEMMReshapeInfo()
        : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1)
    {
    }
    /** Constructor
     *
     * @param[in] m                         Number of matrix A rows
     * @param[in] n                         Number of matrix B columns
     * @param[in] k                         Number of matrix A columns or matrix B rows
     * @param[in] mult_transpose1xW_width   (Optional) Multiplication factor for the width of the 1xW transposed block. Defaults to 1.
     * @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleaved block. Defaults to 1.
     */
    GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1)
        : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height)
    {
    }
    /** Number of matrix A rows
     *
     * @return the number of matrix A rows
     */
    int m() const
    {
        return _m;
    }
    /** Number of matrix B columns
     *
     * @return the number of matrix B columns
     */
    int n() const
    {
        return _n;
    }
    /** Number of matrix A columns or matrix B rows
     *
     * @return the number of matrix A columns or matrix B rows
     */
    int k() const
    {
        return _k;
    }
    /** Multiplication factor for the width of the 1xW transposed block
     *
     * @return the multiplication factor for the width of the 1xW transposed block
     */
    int mult_transpose1xW_width() const
    {
        return _mult_transpose1xW_width;
    }
    /** Multiplication factor for the height of the 4x4 interleaved block
     *
     * @return the multiplication factor for the height of the 4x4 interleaved block
     */
    int mult_interleave4x4_height() const
    {
        return _mult_interleave4x4_height;
    }

private:
    // Intentionally non-const: const data members would implicitly delete the copy/move
    // assignment operators of this class (and of any class holding it by value).
    int _m;
    int _n;
    int _k;
    int _mult_transpose1xW_width;
    int _mult_interleave4x4_height;
};
1072
1073/** GEMM information class. This class stores the necessary information to compute GEMM functions
1074 *
1075 * This object also contains the information about how matrix A and matrix B have been reshaped
1076 *
1077 */
Chunosov5124be52017-11-22 20:42:13 +07001078class GEMMInfo
1079{
1080public:
1081 /** Default constructor */
1082 GEMMInfo()
Gian Marco36a0a462018-01-12 10:21:40 +00001083 : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _reshape_info()
Chunosov5124be52017-11-22 20:42:13 +07001084 {
1085 }
1086 /** Constructor
1087 *
1088 * @param[in] is_a_reshaped True if the matrix A has been reshaped
1089 * @param[in] is_b_reshaped True if the matrix B has been reshaped
1090 * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
Gian Marco36a0a462018-01-12 10:21:40 +00001091 * @param[in] reshape_info (Optional) GEMM reshape information object
Chunosov5124be52017-11-22 20:42:13 +07001092 */
Gian Marco36a0a462018-01-12 10:21:40 +00001093 GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo())
1094 : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _reshape_info(reshape_info)
Chunosov5124be52017-11-22 20:42:13 +07001095 {
1096 }
1097 /** Flag which specifies if the matrix A has been reshaped
1098 *
1099 * @return True if the matrix A has been reshaped
1100 */
1101 bool is_a_reshaped() const
1102 {
1103 return _is_a_reshaped;
1104 };
1105 /** Flag which specifies if the matrix B has been reshaped
1106 *
1107 * @return True if the matrix B has been reshaped
1108 */
1109 bool is_b_reshaped() const
1110 {
1111 return _is_b_reshaped;
1112 };
1113 /** Flag which specifies if the reshape of matrix B should executed only for the first
1114 *
1115 * @note This flag could be set to TRUE when GEMM is used to accelerate convolution layer
1116 *
1117 * @return True if the reshaped of matrix B happens only for the first run
1118 */
1119 bool reshape_b_only_on_first_run() const
1120 {
1121 return _reshape_b_only_on_first_run;
1122 };
Gian Marco36a0a462018-01-12 10:21:40 +00001123 /** GEMMReshapeInfo object which stores the necessary information to understand how the matrix A and matrix B have been reshaped
1124 *
1125 * @return the GEMMReshapeInfo object
1126 */
1127 const GEMMReshapeInfo &reshape_info() const
1128 {
1129 return _reshape_info;
1130 }
Chunosov5124be52017-11-22 20:42:13 +07001131
1132private:
Gian Marco36a0a462018-01-12 10:21:40 +00001133 const bool _is_a_reshaped;
1134 const bool _is_b_reshaped;
1135 const bool _reshape_b_only_on_first_run;
1136 GEMMReshapeInfo _reshape_info;
Chunosov5124be52017-11-22 20:42:13 +07001137};
1138
/** IO formatting information class */
struct IOFormatInfo
{
    /** Precision type used when printing floating point numbers */
    enum class PrecisionType
    {
        Default, /**< Default precision to the one that the current stream has */
        Custom,  /**< Custom precision specified by the user using the precision parameter */
        Full     /**< The maximum precision of the floating point representation */
    };

    /** Specifies the area to be printed, used by Tensor objects */
    enum class PrintRegion
    {
        ValidRegion, /**< Prints the valid region of the Tensor object */
        NoPadding,   /**< Prints the Tensor object without the padding */
        Full         /**< Print the tensor object including padding */
    };

    /** Construct a set of IO formatting information.
     *
     * @param[in] print_region   Area to be printed. Used by Tensor objects. Default: ValidRegion.
     * @param[in] precision_type Precision type for floating point numbers. Default: stream default.
     * @param[in] precision      Precision value for floating point numbers. Default: 10.
     * @param[in] align_columns  Whether to align columns when printed. Default: true.
     * @param[in] element_delim  Delimiter between elements. Default: " ".
     * @param[in] row_delim      Delimiter between rows. Default: "\n".
     */
    IOFormatInfo(PrintRegion   print_region   = PrintRegion::ValidRegion,
                 PrecisionType precision_type = PrecisionType::Default,
                 unsigned int  precision      = 10,
                 bool          align_columns  = true,
                 std::string   element_delim  = " ",
                 std::string   row_delim      = "\n")
        : print_region(print_region),
          precision_type(precision_type),
          precision(precision),
          // The delimiters are by-value sink parameters: move them instead of copying a second time
          element_delim(std::move(element_delim)),
          row_delim(std::move(row_delim)),
          align_columns(align_columns)
    {
    }

    /** Area to be printed by Tensor objects */
    PrintRegion print_region;
    /** Floating point precision type */
    PrecisionType precision_type;
    /** Floating point precision */
    unsigned int precision;
    /** Element delimiter */
    std::string element_delim;
    /** Row delimiter */
    std::string row_delim;
    /** Align columns */
    bool align_columns;
};
Isabella Gottardif07d28d2018-02-06 14:52:43 +00001195
/** Method used to compute a convolution */
enum class ConvolutionMethod
{
    GEMM     = 0, /**< Convolution using GEMM */
    DIRECT   = 1, /**< Direct convolution */
    WINOGRAD = 2  /**< Convolution using Winograd */
};
Georgios Pinitasd8734b52017-12-22 15:27:52 +00001203} // namespace arm_compute
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001204#endif /* __ARM_COMPUTE_TYPES_H__ */