blob: c63b653ae34f6a3883330f4efdb7197342276e60 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
#include "Types.hpp"
#include "Tensor.hpp"

#include <cmath>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>

#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
13
14namespace armnn
15{
16
17constexpr char const* GetStatusAsCString(Status compute)
18{
19 switch (compute)
20 {
21 case armnn::Status::Success: return "Status::Success";
22 case armnn::Status::Failure: return "Status::Failure";
23 default: return "Unknown";
24 }
25}
26
27constexpr char const* GetComputeDeviceAsCString(Compute compute)
28{
29 switch (compute)
30 {
31 case armnn::Compute::CpuRef: return "CpuRef";
32 case armnn::Compute::CpuAcc: return "CpuAcc";
33 case armnn::Compute::GpuAcc: return "GpuAcc";
34 default: return "Unknown";
35 }
36}
37
surmeh01bceff2f2018-03-29 16:29:27 +010038constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
39{
40 switch (activation)
41 {
42 case ActivationFunction::Sigmoid: return "Sigmoid";
43 case ActivationFunction::TanH: return "TanH";
44 case ActivationFunction::Linear: return "Linear";
45 case ActivationFunction::ReLu: return "ReLu";
46 case ActivationFunction::BoundedReLu: return "BoundedReLu";
47 case ActivationFunction::SoftReLu: return "SoftReLu";
48 case ActivationFunction::LeakyReLu: return "LeakyReLu";
49 case ActivationFunction::Abs: return "Abs";
50 case ActivationFunction::Sqrt: return "Sqrt";
51 case ActivationFunction::Square: return "Square";
52 default: return "Unknown";
53 }
54}
55
56constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
57{
58 switch (pooling)
59 {
60 case PoolingAlgorithm::Average: return "Average";
61 case PoolingAlgorithm::Max: return "Max";
62 case PoolingAlgorithm::L2: return "L2";
63 default: return "Unknown";
64 }
65}
66
67constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
68{
69 switch (rounding)
70 {
71 case OutputShapeRounding::Ceiling: return "Ceiling";
72 case OutputShapeRounding::Floor: return "Floor";
73 default: return "Unknown";
74 }
75}
76
77
78constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
79{
80 switch (method)
81 {
82 case PaddingMethod::Exclude: return "Exclude";
83 case PaddingMethod::IgnoreValue: return "IgnoreValue";
84 default: return "Unknown";
85 }
86}
87
telsoa014fcda012018-03-09 14:13:49 +000088constexpr unsigned int GetDataTypeSize(DataType dataType)
89{
90 switch (dataType)
91 {
92 case DataType::Signed32:
93 case DataType::Float32: return 4U;
94 case DataType::QuantisedAsymm8: return 1U;
95 default: return 0U;
96 }
97}
98
/// constexpr-friendly comparison of a NUL-terminated string against a string literal.
/// All N characters of strB, including its terminating NUL, must match, so only an exact,
/// case-sensitive match returns true. The scan stops at the first mismatch, so a shorter strA
/// is not read beyond its own terminator (assuming strB has no embedded NULs).
/// @param strA The string to test; a null pointer never matches.
/// @param strB The string literal to compare against.
/// @return true if strA equals strB exactly.
template <int N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    // Guard against null input (e.g. from ParseComputeDevice(nullptr)), which previously
    // dereferenced the pointer and invoked undefined behaviour.
    if (strA == nullptr)
    {
        return false;
    }
    for (int i = 0; i < N; ++i)
    {
        if (strA[i] != strB[i])
        {
            return false;
        }
    }
    return true;
}
108}
109
110constexpr Compute ParseComputeDevice(const char* str)
111{
112 if (StrEqual(str, "CpuAcc"))
113 {
114 return armnn::Compute::CpuAcc;
115 }
116 else if (StrEqual(str, "CpuRef"))
117 {
118 return armnn::Compute::CpuRef;
119 }
120 else if (StrEqual(str, "GpuAcc"))
121 {
122 return armnn::Compute::GpuAcc;
123 }
124 else
125 {
126 return armnn::Compute::Undefined;
127 }
128}
129
130constexpr const char* GetDataTypeName(DataType dataType)
131{
132 switch (dataType)
133 {
134 case DataType::Float32: return "Float32";
135 case DataType::QuantisedAsymm8: return "Unsigned8";
136 case DataType::Signed32: return "Signed32";
137 default: return "Unknown";
138 }
139}
140
141template <typename T>
142constexpr DataType GetDataType();
143
144template <>
145constexpr DataType GetDataType<float>()
146{
147 return DataType::Float32;
148}
149
150template <>
151constexpr DataType GetDataType<uint8_t>()
152{
153 return DataType::QuantisedAsymm8;
154}
155
156template <>
157constexpr DataType GetDataType<int32_t>()
158{
159 return DataType::Signed32;
160}
161
/// True when T is an integral type, which is the requirement Quantize/Dequantize enforce with
/// static_assert. std::is_integral also accepts bool and the character types.
template<typename T>
constexpr bool IsQuantizedType()
{
    constexpr bool isIntegral = std::is_integral<T>::value;
    return isIntegral;
}
167
168
169template<DataType DT>
170struct ResolveTypeImpl;
171
172template<>
173struct ResolveTypeImpl<DataType::QuantisedAsymm8>
174{
175 using Type = uint8_t;
176};
177
178template<>
179struct ResolveTypeImpl<DataType::Float32>
180{
181 using Type = float;
182};
183
184template<DataType DT>
185using ResolveType = typename ResolveTypeImpl<DT>::Type;
186
187
188inline std::ostream& operator<<(std::ostream& os, Status stat)
189{
190 os << GetStatusAsCString(stat);
191 return os;
192}
193
194inline std::ostream& operator<<(std::ostream& os, Compute compute)
195{
196 os << GetComputeDeviceAsCString(compute);
197 return os;
198}
199
surmeh013537c2c2018-05-18 16:31:43 +0100200inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
201{
202 os << "[";
203 for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
204 {
205 if (i!=0)
206 {
207 os << ",";
208 }
209 os << shape[i];
210 }
211 os << "]";
212 return os;
213}
214
telsoa014fcda012018-03-09 14:13:49 +0000215/// Quantize a floating point data type into an 8-bit data type
216/// @param value The value to quantize
217/// @param scale The scale (must be non-zero)
218/// @param offset The offset
219/// @return The quantized value calculated as round(value/scale)+offset
220///
221template<typename QuantizedType>
222inline QuantizedType Quantize(float value, float scale, int32_t offset)
223{
224 static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
225 constexpr QuantizedType max = std::numeric_limits<QuantizedType>::max();
226 constexpr QuantizedType min = std::numeric_limits<QuantizedType>::lowest();
227 BOOST_ASSERT(scale != 0.f);
228 int quantized = boost::numeric_cast<int>(round(value / scale)) + offset;
surmeh013537c2c2018-05-18 16:31:43 +0100229 QuantizedType quantizedBits = quantized <= min
230 ? min
231 : quantized >= max
232 ? max
233 : static_cast<QuantizedType>(quantized);
telsoa014fcda012018-03-09 14:13:49 +0000234 return quantizedBits;
235}
236
237/// Dequantize an 8-bit data type into a floating point data type
238/// @param value The value to dequantize
239/// @param scale The scale (must be non-zero)
240/// @param offset The offset
241/// @return The dequantized value calculated as (value-offset)*scale
242///
243template <typename QuantizedType>
244inline float Dequantize(QuantizedType value, float scale, int32_t offset)
245{
246 static_assert(IsQuantizedType<QuantizedType>(), "Not an integer type.");
247 BOOST_ASSERT(scale != 0.f);
248 float dequantized = boost::numeric_cast<float>(value - offset) * scale;
249 return dequantized;
250}
251
surmeh013537c2c2018-05-18 16:31:43 +0100252} //namespace armnn