blob: 1012fcfa226e73d9659e722f291e4cbdc4208383 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <cmath>
#include <cstdint>
#include <ostream>
#include <set>
#include <sstream>
#include <type_traits>
telsoa014fcda012018-03-09 14:13:49 +000013
14namespace armnn
15{
16
David Beck9df2d952018-10-10 15:11:44 +010017constexpr char const* GetStatusAsCString(Status status)
telsoa014fcda012018-03-09 14:13:49 +000018{
David Beck9df2d952018-10-10 15:11:44 +010019 switch (status)
telsoa014fcda012018-03-09 14:13:49 +000020 {
21 case armnn::Status::Success: return "Status::Success";
22 case armnn::Status::Failure: return "Status::Failure";
23 default: return "Unknown";
24 }
25}
26
surmeh01bceff2f2018-03-29 16:29:27 +010027constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
28{
29 switch (activation)
30 {
31 case ActivationFunction::Sigmoid: return "Sigmoid";
32 case ActivationFunction::TanH: return "TanH";
33 case ActivationFunction::Linear: return "Linear";
34 case ActivationFunction::ReLu: return "ReLu";
35 case ActivationFunction::BoundedReLu: return "BoundedReLu";
36 case ActivationFunction::SoftReLu: return "SoftReLu";
37 case ActivationFunction::LeakyReLu: return "LeakyReLu";
38 case ActivationFunction::Abs: return "Abs";
39 case ActivationFunction::Sqrt: return "Sqrt";
40 case ActivationFunction::Square: return "Square";
David Monahan3b3c3812020-02-25 09:03:29 +000041 case ActivationFunction::Elu: return "Elu";
Colm Donelan03fbeaf2020-02-26 15:39:23 +000042 case ActivationFunction::HardSwish: return "HardSwish";
surmeh01bceff2f2018-03-29 16:29:27 +010043 default: return "Unknown";
44 }
45}
46
Francis Murtaghb5b3b352019-11-13 16:58:20 +000047constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
48{
49 switch (function)
50 {
51 case ArgMinMaxFunction::Max: return "Max";
52 case ArgMinMaxFunction::Min: return "Min";
53 default: return "Unknown";
54 }
55}
56
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +010057constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
58{
59 switch (operation)
60 {
61 case ComparisonOperation::Equal: return "Equal";
62 case ComparisonOperation::Greater: return "Greater";
63 case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
64 case ComparisonOperation::Less: return "Less";
65 case ComparisonOperation::LessOrEqual: return "LessOrEqual";
66 case ComparisonOperation::NotEqual: return "NotEqual";
67 default: return "Unknown";
68 }
69}
70
josh minor4a3c6102020-01-06 16:40:46 -060071constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation)
72{
73 switch (operation)
74 {
James Conroyaba90cd2020-11-06 16:28:18 +000075 case UnaryOperation::Abs: return "Abs";
76 case UnaryOperation::Exp: return "Exp";
77 case UnaryOperation::Sqrt: return "Sqrt";
78 case UnaryOperation::Rsqrt: return "Rsqrt";
79 case UnaryOperation::Neg: return "Neg";
80 case UnaryOperation::LogicalNot: return "LogicalNot";
81 default: return "Unknown";
82 }
83}
84
85constexpr char const* GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)
86{
87 switch (operation)
88 {
89 case LogicalBinaryOperation::LogicalAnd: return "LogicalAnd";
90 case LogicalBinaryOperation::LogicalOr: return "LogicalOr";
91 default: return "Unknown";
josh minor4a3c6102020-01-06 16:40:46 -060092 }
93}
94
surmeh01bceff2f2018-03-29 16:29:27 +010095constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
96{
97 switch (pooling)
98 {
99 case PoolingAlgorithm::Average: return "Average";
100 case PoolingAlgorithm::Max: return "Max";
101 case PoolingAlgorithm::L2: return "L2";
102 default: return "Unknown";
103 }
104}
105
106constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
107{
108 switch (rounding)
109 {
110 case OutputShapeRounding::Ceiling: return "Ceiling";
111 case OutputShapeRounding::Floor: return "Floor";
112 default: return "Unknown";
113 }
114}
115
surmeh01bceff2f2018-03-29 16:29:27 +0100116constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
117{
118 switch (method)
119 {
120 case PaddingMethod::Exclude: return "Exclude";
121 case PaddingMethod::IgnoreValue: return "IgnoreValue";
122 default: return "Unknown";
123 }
124}
125
telsoa014fcda012018-03-09 14:13:49 +0000126constexpr unsigned int GetDataTypeSize(DataType dataType)
127{
128 switch (dataType)
129 {
Narumol Prangnawaratc3bf6ef2020-02-28 12:45:21 +0000130 case DataType::BFloat16:
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000131 case DataType::Float16: return 2U;
telsoa01c577f2c2018-08-31 09:22:23 +0100132 case DataType::Float32:
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000133 case DataType::Signed32: return 4U;
Inki Daed4619e22020-09-10 15:33:54 +0900134 case DataType::Signed64: return 8U;
Derek Lambertif90c56d2020-01-10 17:14:08 +0000135 case DataType::QAsymmU8: return 1U;
Ryan OShea9add1202020-02-07 10:06:33 +0000136 case DataType::QAsymmS8: return 1U;
Finn Williamsfd271062019-12-04 14:27:27 +0000137 case DataType::QSymmS8: return 1U;
Derek Lambertid466a542020-01-22 15:37:29 +0000138 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000139 case DataType::QuantizedSymm8PerAxis: return 1U;
Derek Lambertid466a542020-01-22 15:37:29 +0000140 ARMNN_NO_DEPRECATE_WARN_END
141 case DataType::QSymmS16: return 2U;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000142 case DataType::Boolean: return 1U;
143 default: return 0U;
telsoa014fcda012018-03-09 14:13:49 +0000144 }
145}
146
/// Compares a NUL-terminated C string with a string literal; usable in constant expressions.
/// All N characters of strB, including its terminating NUL, are compared, so strA must match
/// strB exactly: a shorter or longer strA compares unequal. Comparison stops at the first
/// mismatch, so (provided strB has no embedded NUL) strA is never read past its own terminator.
/// @param strA - The string to test. A null pointer compares unequal to every literal.
/// @param strB - The string literal (N chars including the terminator) to compare against.
/// @return - true when strA holds exactly the characters of strB.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    if (strA == nullptr)
    {
        // Indexing a null pointer is undefined behaviour (and not a constant expression).
        return false;
    }
    for (unsigned i = 0; i < N; ++i)
    {
        if (strA[i] != strB[i])
        {
            return false;
        }
    }
    return true;
}
157
David Becke6488712018-10-16 11:32:20 +0100158/// Deprecated function that will be removed together with
159/// the Compute enum
telsoa01c577f2c2018-08-31 09:22:23 +0100160constexpr armnn::Compute ParseComputeDevice(const char* str)
telsoa014fcda012018-03-09 14:13:49 +0000161{
telsoa01c577f2c2018-08-31 09:22:23 +0100162 if (armnn::StrEqual(str, "CpuAcc"))
telsoa014fcda012018-03-09 14:13:49 +0000163 {
164 return armnn::Compute::CpuAcc;
165 }
telsoa01c577f2c2018-08-31 09:22:23 +0100166 else if (armnn::StrEqual(str, "CpuRef"))
telsoa014fcda012018-03-09 14:13:49 +0000167 {
168 return armnn::Compute::CpuRef;
169 }
telsoa01c577f2c2018-08-31 09:22:23 +0100170 else if (armnn::StrEqual(str, "GpuAcc"))
telsoa014fcda012018-03-09 14:13:49 +0000171 {
172 return armnn::Compute::GpuAcc;
173 }
174 else
175 {
176 return armnn::Compute::Undefined;
177 }
178}
179
180constexpr const char* GetDataTypeName(DataType dataType)
181{
182 switch (dataType)
183 {
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000184 case DataType::Float16: return "Float16";
185 case DataType::Float32: return "Float32";
Inki Daed4619e22020-09-10 15:33:54 +0900186 case DataType::Signed64: return "Signed64";
Derek Lambertif90c56d2020-01-10 17:14:08 +0000187 case DataType::QAsymmU8: return "QAsymmU8";
Keith Davis0c2eeac2020-02-11 16:51:50 +0000188 case DataType::QAsymmS8: return "QAsymmS8";
Derek Lambertif90c56d2020-01-10 17:14:08 +0000189 case DataType::QSymmS8: return "QSymmS8";
Derek Lambertid466a542020-01-22 15:37:29 +0000190 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000191 case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
Derek Lambertid466a542020-01-22 15:37:29 +0000192 ARMNN_NO_DEPRECATE_WARN_END
193 case DataType::QSymmS16: return "QSymm16";
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000194 case DataType::Signed32: return "Signed32";
195 case DataType::Boolean: return "Boolean";
Narumol Prangnawaratc3bf6ef2020-02-28 12:45:21 +0000196 case DataType::BFloat16: return "BFloat16";
telsoa01c577f2c2018-08-31 09:22:23 +0100197
198 default:
199 return "Unknown";
telsoa014fcda012018-03-09 14:13:49 +0000200 }
201}
202
Matteo Martincigh49124022019-01-11 13:25:59 +0000203constexpr const char* GetDataLayoutName(DataLayout dataLayout)
204{
205 switch (dataLayout)
206 {
207 case DataLayout::NCHW: return "NCHW";
208 case DataLayout::NHWC: return "NHWC";
209 default: return "Unknown";
210 }
211}
212
Teresa Charlin190a39a2020-01-23 11:44:24 +0000213constexpr const char* GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
214{
215 switch (channel)
216 {
217 case NormalizationAlgorithmChannel::Across: return "Across";
218 case NormalizationAlgorithmChannel::Within: return "Within";
219 default: return "Unknown";
220 }
221}
222
223constexpr const char* GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
224{
225 switch (method)
226 {
227 case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
228 case NormalizationAlgorithmMethod::LocalContrast: return "LocalContrast";
229 default: return "Unknown";
230 }
231}
232
233constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
234{
235 switch (method)
236 {
237 case ResizeMethod::Bilinear: return "Bilinear";
238 case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
239 default: return "Unknown";
240 }
241}
telsoa01c577f2c2018-08-31 09:22:23 +0100242
/// Type trait whose ::value is true when T is a floating-point type occupying exactly 2 bytes.
/// NOTE(review): a type only qualifies if std::is_floating_point<T> is true for it. Whether armnn's
/// half-precision type satisfies that depends on its definition elsewhere; confirm.
template<typename T>
struct IsHalfType
    : std::integral_constant<bool, (sizeof(T) == 2) && std::is_floating_point<T>::value>
{
};
247
/// Compile-time check on a C++ element type: every integral type T (including bool) is
/// treated as quantized. Floating-point and non-arithmetic types are not.
template<typename T>
constexpr bool IsQuantizedType()
{
    constexpr bool isIntegral = std::is_integral<T>::value;
    return isIntegral;
}
253
Keith Davis0c2eeac2020-02-11 16:51:50 +0000254constexpr bool IsQuantized8BitType(DataType dataType)
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +0000255{
Derek Lambertid466a542020-01-22 15:37:29 +0000256 ARMNN_NO_DEPRECATE_WARN_BEGIN
Derek Lambertif90c56d2020-01-10 17:14:08 +0000257 return dataType == DataType::QAsymmU8 ||
Ryan OShea9add1202020-02-07 10:06:33 +0000258 dataType == DataType::QAsymmS8 ||
Finn Williamsfd271062019-12-04 14:27:27 +0000259 dataType == DataType::QSymmS8 ||
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +0000260 dataType == DataType::QuantizedSymm8PerAxis;
Derek Lambertid466a542020-01-22 15:37:29 +0000261 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +0000262}
263
Keith Davis0c2eeac2020-02-11 16:51:50 +0000264constexpr bool IsQuantizedType(DataType dataType)
265{
266 return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
267}
268
telsoa014fcda012018-03-09 14:13:49 +0000269inline std::ostream& operator<<(std::ostream& os, Status stat)
270{
271 os << GetStatusAsCString(stat);
272 return os;
273}
274
telsoa014fcda012018-03-09 14:13:49 +0000275
surmeh013537c2c2018-05-18 16:31:43 +0100276inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
277{
278 os << "[";
279 for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
280 {
281 if (i!=0)
282 {
283 os << ",";
284 }
285 os << shape[i];
286 }
287 os << "]";
288 return os;
289}
290
/// Quantize a floating point value into the integer type QuantizedType.
/// The quantized data types are not only 8-bit (QSymmS16 also counts as quantized). Which
/// QuantizedType instantiations are actually provided is defined elsewhere; TODO confirm.
/// @tparam QuantizedType - The integer type to quantize into.
/// @param value - The value to quantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset, i.e. the quantized value that represents 0.0f.
/// @return - The quantized value calculated as round(value/scale)+offset.
///           NOTE(review): the definition may also clamp the result to the range of QuantizedType; confirm.
///
template<typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);
telsoa014fcda012018-03-09 14:13:49 +0000299
/// Dequantize a value of the integer type QuantizedType into a float.
/// The quantized data types are not only 8-bit (QSymmS16 also counts as quantized). Which
/// QuantizedType instantiations are actually provided is defined elsewhere; TODO confirm.
/// @tparam QuantizedType - The integer type being dequantized.
/// @param value - The value to dequantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset, i.e. the quantized value that represents 0.0f.
/// @return - The dequantized value calculated as (value-offset)*scale.
///
template <typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
telsoa014fcda012018-03-09 14:13:49 +0000308
keidav011b3e2ea2019-02-21 10:07:37 +0000309inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
telsoa01c577f2c2018-08-31 09:22:23 +0100310{
keidav011b3e2ea2019-02-21 10:07:37 +0000311 if (info.GetDataType() != dataType)
telsoa01c577f2c2018-08-31 09:22:23 +0100312 {
313 std::stringstream ss;
314 ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
keidav011b3e2ea2019-02-21 10:07:37 +0000315 << " for tensor:" << info.GetShape()
316 << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100317 throw armnn::Exception(ss.str());
318 }
319}
320
surmeh013537c2c2018-05-18 16:31:43 +0100321} //namespace armnn