Laurent Carlier | 749294b | 2020-06-01 09:03:17 +0100 | [diff] [blame] | 1 | // |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 4 | // |
| 5 | #pragma once |
| 6 | |
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <cmath>
#include <ostream>
#include <set>
#include <sstream>
#include <type_traits>
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 13 | |
| 14 | namespace armnn |
| 15 | { |
| 16 | |
David Beck | 9df2d95 | 2018-10-10 15:11:44 +0100 | [diff] [blame] | 17 | constexpr char const* GetStatusAsCString(Status status) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 18 | { |
David Beck | 9df2d95 | 2018-10-10 15:11:44 +0100 | [diff] [blame] | 19 | switch (status) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 20 | { |
| 21 | case armnn::Status::Success: return "Status::Success"; |
| 22 | case armnn::Status::Failure: return "Status::Failure"; |
| 23 | default: return "Unknown"; |
| 24 | } |
| 25 | } |
| 26 | |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 27 | constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation) |
| 28 | { |
| 29 | switch (activation) |
| 30 | { |
| 31 | case ActivationFunction::Sigmoid: return "Sigmoid"; |
| 32 | case ActivationFunction::TanH: return "TanH"; |
| 33 | case ActivationFunction::Linear: return "Linear"; |
| 34 | case ActivationFunction::ReLu: return "ReLu"; |
| 35 | case ActivationFunction::BoundedReLu: return "BoundedReLu"; |
| 36 | case ActivationFunction::SoftReLu: return "SoftReLu"; |
| 37 | case ActivationFunction::LeakyReLu: return "LeakyReLu"; |
| 38 | case ActivationFunction::Abs: return "Abs"; |
| 39 | case ActivationFunction::Sqrt: return "Sqrt"; |
| 40 | case ActivationFunction::Square: return "Square"; |
David Monahan | 3b3c381 | 2020-02-25 09:03:29 +0000 | [diff] [blame] | 41 | case ActivationFunction::Elu: return "Elu"; |
Colm Donelan | 03fbeaf | 2020-02-26 15:39:23 +0000 | [diff] [blame] | 42 | case ActivationFunction::HardSwish: return "HardSwish"; |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 43 | default: return "Unknown"; |
| 44 | } |
| 45 | } |
| 46 | |
Francis Murtagh | b5b3b35 | 2019-11-13 16:58:20 +0000 | [diff] [blame] | 47 | constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function) |
| 48 | { |
| 49 | switch (function) |
| 50 | { |
| 51 | case ArgMinMaxFunction::Max: return "Max"; |
| 52 | case ArgMinMaxFunction::Min: return "Min"; |
| 53 | default: return "Unknown"; |
| 54 | } |
| 55 | } |
| 56 | |
Aron Virginas-Tar | 77bfb5e | 2019-10-16 17:45:38 +0100 | [diff] [blame] | 57 | constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation) |
| 58 | { |
| 59 | switch (operation) |
| 60 | { |
| 61 | case ComparisonOperation::Equal: return "Equal"; |
| 62 | case ComparisonOperation::Greater: return "Greater"; |
| 63 | case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual"; |
| 64 | case ComparisonOperation::Less: return "Less"; |
| 65 | case ComparisonOperation::LessOrEqual: return "LessOrEqual"; |
| 66 | case ComparisonOperation::NotEqual: return "NotEqual"; |
| 67 | default: return "Unknown"; |
| 68 | } |
| 69 | } |
| 70 | |
josh minor | 4a3c610 | 2020-01-06 16:40:46 -0600 | [diff] [blame] | 71 | constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation) |
| 72 | { |
| 73 | switch (operation) |
| 74 | { |
James Conroy | aba90cd | 2020-11-06 16:28:18 +0000 | [diff] [blame] | 75 | case UnaryOperation::Abs: return "Abs"; |
| 76 | case UnaryOperation::Exp: return "Exp"; |
| 77 | case UnaryOperation::Sqrt: return "Sqrt"; |
| 78 | case UnaryOperation::Rsqrt: return "Rsqrt"; |
| 79 | case UnaryOperation::Neg: return "Neg"; |
| 80 | case UnaryOperation::LogicalNot: return "LogicalNot"; |
| 81 | default: return "Unknown"; |
| 82 | } |
| 83 | } |
| 84 | |
| 85 | constexpr char const* GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation) |
| 86 | { |
| 87 | switch (operation) |
| 88 | { |
| 89 | case LogicalBinaryOperation::LogicalAnd: return "LogicalAnd"; |
| 90 | case LogicalBinaryOperation::LogicalOr: return "LogicalOr"; |
| 91 | default: return "Unknown"; |
josh minor | 4a3c610 | 2020-01-06 16:40:46 -0600 | [diff] [blame] | 92 | } |
| 93 | } |
| 94 | |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 95 | constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling) |
| 96 | { |
| 97 | switch (pooling) |
| 98 | { |
| 99 | case PoolingAlgorithm::Average: return "Average"; |
| 100 | case PoolingAlgorithm::Max: return "Max"; |
| 101 | case PoolingAlgorithm::L2: return "L2"; |
| 102 | default: return "Unknown"; |
| 103 | } |
| 104 | } |
| 105 | |
| 106 | constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding) |
| 107 | { |
| 108 | switch (rounding) |
| 109 | { |
| 110 | case OutputShapeRounding::Ceiling: return "Ceiling"; |
| 111 | case OutputShapeRounding::Floor: return "Floor"; |
| 112 | default: return "Unknown"; |
| 113 | } |
| 114 | } |
| 115 | |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 116 | constexpr char const* GetPaddingMethodAsCString(PaddingMethod method) |
| 117 | { |
| 118 | switch (method) |
| 119 | { |
| 120 | case PaddingMethod::Exclude: return "Exclude"; |
| 121 | case PaddingMethod::IgnoreValue: return "IgnoreValue"; |
| 122 | default: return "Unknown"; |
| 123 | } |
| 124 | } |
| 125 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 126 | constexpr unsigned int GetDataTypeSize(DataType dataType) |
| 127 | { |
| 128 | switch (dataType) |
| 129 | { |
Narumol Prangnawarat | c3bf6ef | 2020-02-28 12:45:21 +0000 | [diff] [blame] | 130 | case DataType::BFloat16: |
Aron Virginas-Tar | 5edc881 | 2019-11-05 18:00:21 +0000 | [diff] [blame] | 131 | case DataType::Float16: return 2U; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 132 | case DataType::Float32: |
Aron Virginas-Tar | 5edc881 | 2019-11-05 18:00:21 +0000 | [diff] [blame] | 133 | case DataType::Signed32: return 4U; |
Inki Dae | d4619e2 | 2020-09-10 15:33:54 +0900 | [diff] [blame] | 134 | case DataType::Signed64: return 8U; |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 135 | case DataType::QAsymmU8: return 1U; |
Ryan OShea | 9add120 | 2020-02-07 10:06:33 +0000 | [diff] [blame] | 136 | case DataType::QAsymmS8: return 1U; |
Finn Williams | fd27106 | 2019-12-04 14:27:27 +0000 | [diff] [blame] | 137 | case DataType::QSymmS8: return 1U; |
Derek Lamberti | d466a54 | 2020-01-22 15:37:29 +0000 | [diff] [blame] | 138 | ARMNN_NO_DEPRECATE_WARN_BEGIN |
Aron Virginas-Tar | 5edc881 | 2019-11-05 18:00:21 +0000 | [diff] [blame] | 139 | case DataType::QuantizedSymm8PerAxis: return 1U; |
Derek Lamberti | d466a54 | 2020-01-22 15:37:29 +0000 | [diff] [blame] | 140 | ARMNN_NO_DEPRECATE_WARN_END |
| 141 | case DataType::QSymmS16: return 2U; |
Aron Virginas-Tar | 5edc881 | 2019-11-05 18:00:21 +0000 | [diff] [blame] | 142 | case DataType::Boolean: return 1U; |
| 143 | default: return 0U; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 144 | } |
| 145 | } |
| 146 | |
/// Compare a C string against a character array (normally a string literal).
/// Usable in constant expressions.
/// @param strA - Null-terminated string to test; a null pointer compares unequal.
/// @param strB - Array of N characters, including its terminating '\0' when it is a literal.
/// @return true when the first N characters of strA equal those of strB. For a literal
///         strB this means strA matches it exactly.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    // Never dereference a null pointer: it cannot match anything.
    if (strA == nullptr)
    {
        return false;
    }
    // The comparison includes strB's terminator and stops at the first mismatch.
    // So strA is never read past its own '\0': a shorter strA mismatches at its
    // terminator, and a longer one mismatches against strB's terminator.
    bool isEqual = true;
    for (unsigned i = 0; isEqual && (i < N); ++i)
    {
        isEqual = (strA[i] == strB[i]);
    }
    return isEqual;
}
| 157 | |
David Beck | e648871 | 2018-10-16 11:32:20 +0100 | [diff] [blame] | 158 | /// Deprecated function that will be removed together with |
| 159 | /// the Compute enum |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 160 | constexpr armnn::Compute ParseComputeDevice(const char* str) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 161 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 162 | if (armnn::StrEqual(str, "CpuAcc")) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 163 | { |
| 164 | return armnn::Compute::CpuAcc; |
| 165 | } |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 166 | else if (armnn::StrEqual(str, "CpuRef")) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 167 | { |
| 168 | return armnn::Compute::CpuRef; |
| 169 | } |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 170 | else if (armnn::StrEqual(str, "GpuAcc")) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 171 | { |
| 172 | return armnn::Compute::GpuAcc; |
| 173 | } |
| 174 | else |
| 175 | { |
| 176 | return armnn::Compute::Undefined; |
| 177 | } |
| 178 | } |
| 179 | |
| 180 | constexpr const char* GetDataTypeName(DataType dataType) |
| 181 | { |
| 182 | switch (dataType) |
| 183 | { |
Aron Virginas-Tar | b67f957 | 2019-11-04 15:00:19 +0000 | [diff] [blame] | 184 | case DataType::Float16: return "Float16"; |
| 185 | case DataType::Float32: return "Float32"; |
Inki Dae | d4619e2 | 2020-09-10 15:33:54 +0900 | [diff] [blame] | 186 | case DataType::Signed64: return "Signed64"; |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 187 | case DataType::QAsymmU8: return "QAsymmU8"; |
Keith Davis | 0c2eeac | 2020-02-11 16:51:50 +0000 | [diff] [blame] | 188 | case DataType::QAsymmS8: return "QAsymmS8"; |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 189 | case DataType::QSymmS8: return "QSymmS8"; |
Derek Lamberti | d466a54 | 2020-01-22 15:37:29 +0000 | [diff] [blame] | 190 | ARMNN_NO_DEPRECATE_WARN_BEGIN |
Aron Virginas-Tar | b67f957 | 2019-11-04 15:00:19 +0000 | [diff] [blame] | 191 | case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis"; |
Derek Lamberti | d466a54 | 2020-01-22 15:37:29 +0000 | [diff] [blame] | 192 | ARMNN_NO_DEPRECATE_WARN_END |
| 193 | case DataType::QSymmS16: return "QSymm16"; |
Aron Virginas-Tar | b67f957 | 2019-11-04 15:00:19 +0000 | [diff] [blame] | 194 | case DataType::Signed32: return "Signed32"; |
| 195 | case DataType::Boolean: return "Boolean"; |
Narumol Prangnawarat | c3bf6ef | 2020-02-28 12:45:21 +0000 | [diff] [blame] | 196 | case DataType::BFloat16: return "BFloat16"; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 197 | |
| 198 | default: |
| 199 | return "Unknown"; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 200 | } |
| 201 | } |
| 202 | |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 203 | constexpr const char* GetDataLayoutName(DataLayout dataLayout) |
| 204 | { |
| 205 | switch (dataLayout) |
| 206 | { |
| 207 | case DataLayout::NCHW: return "NCHW"; |
| 208 | case DataLayout::NHWC: return "NHWC"; |
| 209 | default: return "Unknown"; |
| 210 | } |
| 211 | } |
| 212 | |
Teresa Charlin | 190a39a | 2020-01-23 11:44:24 +0000 | [diff] [blame] | 213 | constexpr const char* GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel) |
| 214 | { |
| 215 | switch (channel) |
| 216 | { |
| 217 | case NormalizationAlgorithmChannel::Across: return "Across"; |
| 218 | case NormalizationAlgorithmChannel::Within: return "Within"; |
| 219 | default: return "Unknown"; |
| 220 | } |
| 221 | } |
| 222 | |
| 223 | constexpr const char* GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method) |
| 224 | { |
| 225 | switch (method) |
| 226 | { |
| 227 | case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness"; |
| 228 | case NormalizationAlgorithmMethod::LocalContrast: return "LocalContrast"; |
| 229 | default: return "Unknown"; |
| 230 | } |
| 231 | } |
| 232 | |
| 233 | constexpr const char* GetResizeMethodAsCString(ResizeMethod method) |
| 234 | { |
| 235 | switch (method) |
| 236 | { |
| 237 | case ResizeMethod::Bilinear: return "Bilinear"; |
| 238 | case ResizeMethod::NearestNeighbor: return "NearestNeighbour"; |
| 239 | default: return "Unknown"; |
| 240 | } |
| 241 | } |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 242 | |
/// Type trait whose value is true exactly when sizeof(T) is 2 and
/// std::is_floating_point<T> holds.
template<typename T>
struct IsHalfType
    : std::integral_constant<bool, sizeof(T) == 2 && std::is_floating_point<T>::value>
{};
| 247 | |
/// Compile-time check on an element type: any integral T is reported as quantized.
/// Note that this includes bool and wide integers such as int32_t.
template<typename T>
constexpr bool IsQuantizedType()
{
    using IntegralCheck = std::is_integral<T>;
    return IntegralCheck::value;
}
| 253 | |
Keith Davis | 0c2eeac | 2020-02-11 16:51:50 +0000 | [diff] [blame] | 254 | constexpr bool IsQuantized8BitType(DataType dataType) |
Aron Virginas-Tar | e9323ec | 2019-11-26 12:50:34 +0000 | [diff] [blame] | 255 | { |
Derek Lamberti | d466a54 | 2020-01-22 15:37:29 +0000 | [diff] [blame] | 256 | ARMNN_NO_DEPRECATE_WARN_BEGIN |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 257 | return dataType == DataType::QAsymmU8 || |
Ryan OShea | 9add120 | 2020-02-07 10:06:33 +0000 | [diff] [blame] | 258 | dataType == DataType::QAsymmS8 || |
Finn Williams | fd27106 | 2019-12-04 14:27:27 +0000 | [diff] [blame] | 259 | dataType == DataType::QSymmS8 || |
Aron Virginas-Tar | e9323ec | 2019-11-26 12:50:34 +0000 | [diff] [blame] | 260 | dataType == DataType::QuantizedSymm8PerAxis; |
Derek Lamberti | d466a54 | 2020-01-22 15:37:29 +0000 | [diff] [blame] | 261 | ARMNN_NO_DEPRECATE_WARN_END |
Aron Virginas-Tar | e9323ec | 2019-11-26 12:50:34 +0000 | [diff] [blame] | 262 | } |
| 263 | |
Keith Davis | 0c2eeac | 2020-02-11 16:51:50 +0000 | [diff] [blame] | 264 | constexpr bool IsQuantizedType(DataType dataType) |
| 265 | { |
| 266 | return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType); |
| 267 | } |
| 268 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 269 | inline std::ostream& operator<<(std::ostream& os, Status stat) |
| 270 | { |
| 271 | os << GetStatusAsCString(stat); |
| 272 | return os; |
| 273 | } |
| 274 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 275 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 276 | inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape) |
| 277 | { |
| 278 | os << "["; |
| 279 | for (uint32_t i=0; i<shape.GetNumDimensions(); ++i) |
| 280 | { |
| 281 | if (i!=0) |
| 282 | { |
| 283 | os << ","; |
| 284 | } |
| 285 | os << shape[i]; |
| 286 | } |
| 287 | os << "]"; |
| 288 | return os; |
| 289 | } |
| 290 | |
/// Quantize a floating point value into the integer type QuantizedType.
/// @param value - The value to quantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset (zero point).
/// @return - The quantized value calculated as round(value/scale)+offset.
/// NOTE(review): QuantizedType is not limited to 8-bit types by this declaration (16-bit
/// QSymmS16 is also a quantized DataType). Which instantiations exist, and whether the
/// result is clamped to QuantizedType's range, is decided by the out-of-line definition;
/// confirm there.
///
template<typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 299 | |
/// Dequantize a value of the integer type QuantizedType back into a float.
/// @param value - The value to dequantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset (zero point).
/// @return - The dequantized value calculated as (value-offset)*scale.
/// NOTE(review): QuantizedType is not limited to 8-bit types by this declaration; the
/// available instantiations are determined by the out-of-line definition.
///
template <typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 308 | |
keidav01 | 1b3e2ea | 2019-02-21 10:07:37 +0000 | [diff] [blame] | 309 | inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType) |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 310 | { |
keidav01 | 1b3e2ea | 2019-02-21 10:07:37 +0000 | [diff] [blame] | 311 | if (info.GetDataType() != dataType) |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 312 | { |
| 313 | std::stringstream ss; |
| 314 | ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType()) |
keidav01 | 1b3e2ea | 2019-02-21 10:07:37 +0000 | [diff] [blame] | 315 | << " for tensor:" << info.GetShape() |
| 316 | << ". The type expected to be: " << armnn::GetDataTypeName(dataType); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 317 | throw armnn::Exception(ss.str()); |
| 318 | } |
| 319 | } |
| 320 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 321 | } //namespace armnn |