blob: 04202ada90873b96deedaa4fab60d45df292df08 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005#include <aclCommon/ArmComputeTensorUtils.hpp>
6#include <aclCommon/ArmComputeUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00007
Francis Murtagh351d13d2018-09-24 15:01:18 +01008#include "armnn/Exceptions.hpp"
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Descriptors.hpp>
10
11namespace armnn
12{
13namespace armcomputetensorutils
14{
15
Derek Lambertid466a542020-01-22 15:37:29 +000016arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales)
telsoa014fcda012018-03-09 14:13:49 +000017{
18 switch(dataType)
19 {
Mike Kelly130ec602019-11-08 12:08:35 +000020 case armnn::DataType::Boolean:
21 return arm_compute::DataType::U8;
telsoa01c577f2c2018-08-31 09:22:23 +010022 case armnn::DataType::Float16:
23 return arm_compute::DataType::F16;
telsoa014fcda012018-03-09 14:13:49 +000024 case armnn::DataType::Float32:
telsoa014fcda012018-03-09 14:13:49 +000025 return arm_compute::DataType::F32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000026 case armnn::DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000027 return arm_compute::DataType::QASYMM8;
Derek Lambertif90c56d2020-01-10 17:14:08 +000028 case armnn::DataType::QSymmS16:
Aron Virginas-Tar7a3e2fe2019-06-27 18:54:47 +010029 return arm_compute::DataType::QSYMM16;
Finn Williamsfd271062019-12-04 14:27:27 +000030 case armnn::DataType::QSymmS8:
Derek Lambertid466a542020-01-22 15:37:29 +000031 {
32 return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
33 }
34 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kelly130ec602019-11-08 12:08:35 +000035 case armnn::DataType::QuantizedSymm8PerAxis:
36 return arm_compute::DataType::QSYMM8_PER_CHANNEL;
Derek Lambertid466a542020-01-22 15:37:29 +000037 ARMNN_NO_DEPRECATE_WARN_END
telsoa014fcda012018-03-09 14:13:49 +000038 case armnn::DataType::Signed32:
telsoa014fcda012018-03-09 14:13:49 +000039 return arm_compute::DataType::S32;
telsoa014fcda012018-03-09 14:13:49 +000040 default:
telsoa014fcda012018-03-09 14:13:49 +000041 BOOST_ASSERT_MSG(false, "Unknown data type");
42 return arm_compute::DataType::UNKNOWN;
telsoa014fcda012018-03-09 14:13:49 +000043 }
44}
45
Matthew Benthamfd899962018-12-31 15:49:42 +000046arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions,
47 unsigned int originalInputRank,
48 const std::vector<unsigned int>& armnnAxes)
49{
50 arm_compute::Coordinates outAclCoords;
51
52 if (armnnAxes.empty())
53 {
54 // If no reduction axes were provided, then the input must be reduced along all dimensions.
55 // Since Compute Library does not accept an empty vector as the reduction dimensions, we then
56 // manually create a vector including all the input dimensions (in reversed order) as:
57 //
58 // { inputDimensions - 1, inputDimensions - 2, ..., 1, 0 }
59 //
60 outAclCoords.set_num_dimensions(inputDimensions);
61 std::generate(outAclCoords.begin(), outAclCoords.end(), [d = inputDimensions - 1] () mutable { return d--; });
62 }
63 else
64 {
65 // Create a vector of reduction dimensions (in reversed order) with the given reduction axes.
66 //
67 // Adjust the given reduction axes according to the original rank of the input tensor (before ACL applied any
68 // dimension correction).
69 // For example, if the input tensor originally had 4 dimensions, and one of the reduction axes was 2, then the
70 // new value for that reduction axis should be 1.
71 //
72 // Example:
73 // ArmNN input shape = { 1, 1, 3, 2 } -> ACL input shape = { 2, 3 }
74 // ArmNN reduction axis = { 2 } -> ACL reduction axis = { 1 }
75 // ArmNN reduction axis = { 3 } -> ACL reduction axis = { 0 }
76 //
77 // The transformation: ACL reduction axis index = original rank - ArmNN reduction axis index - 1
78 //
79 outAclCoords.set_num_dimensions(armnnAxes.size());
80 std::transform(armnnAxes.begin(), armnnAxes.end(),
81 outAclCoords.begin(),
82 [originalInputRank](unsigned int i){ return originalInputRank - i - 1; });
83 }
84
85 return outAclCoords;
86}
87
telsoa014fcda012018-03-09 14:13:49 +000088arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
89{
90 arm_compute::TensorShape shape;
91
telsoa01c577f2c2018-08-31 09:22:23 +010092 // armnn tensors are (batch, channels, height, width).
93 // arm_compute tensors are (width, height, channels, batch).
telsoa014fcda012018-03-09 14:13:49 +000094 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
95 {
telsoa01c577f2c2018-08-31 09:22:23 +010096 // Note that our dimensions are stored in the opposite order to ACL's.
Matthew Bentham89105282018-11-20 14:33:33 +000097 shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i], false);
telsoa014fcda012018-03-09 14:13:49 +000098
99 // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
telsoa01c577f2c2018-08-31 09:22:23 +0100100 // arm_compute tensors expect this.
telsoa014fcda012018-03-09 14:13:49 +0000101 }
102
103 // prevent arm_compute issue where tensor is flattened to nothing
104 if (shape.num_dimensions() == 0)
105 {
106 shape.set_num_dimensions(1);
107 }
108
109 return shape;
110}
111
112// Utility function used to build a TensorInfo object, that can be used to initialise
113// ARM Compute Tensor and CLTensor allocators.
114arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
115{
Derek Lambertid466a542020-01-22 15:37:29 +0000116 bool multiScales = tensorInfo.HasMultipleQuantizationScales();
telsoa014fcda012018-03-09 14:13:49 +0000117 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
Derek Lambertid466a542020-01-22 15:37:29 +0000118 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +0000119
Derek Lambertid466a542020-01-22 15:37:29 +0000120 const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +0000121 arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
122 arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
telsoa014fcda012018-03-09 14:13:49 +0000123
124 return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
125}
126
Francis Murtagh351d13d2018-09-24 15:01:18 +0100127arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
128 armnn::DataLayout dataLayout)
129{
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +0000130 arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo);
131 aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
Francis Murtagh351d13d2018-09-24 15:01:18 +0100132
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +0000133 return aclTensorInfo;
Francis Murtagh351d13d2018-09-24 15:01:18 +0100134}
135
Matteo Martincigh747ef822018-12-18 09:26:39 +0000136arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
137{
138 switch(dataLayout)
139 {
140 case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
141
142 case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
143
144 default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
145 std::to_string(static_cast<int>(dataLayout)) + "]");
146 }
147}
148
Sadik Armagana3600ba2019-10-10 10:43:20 +0100149arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor,
150 bool fpMixedPrecision)
telsoa014fcda012018-03-09 14:13:49 +0000151{
152 using arm_compute::PoolingType;
153 using arm_compute::DimensionRoundingType;
154 using arm_compute::PadStrideInfo;
155 using arm_compute::PoolingLayerInfo;
surmeh01bceff2f2018-03-29 16:29:27 +0100156 using arm_compute::Size2D;
telsoa014fcda012018-03-09 14:13:49 +0000157
telsoa01c577f2c2018-08-31 09:22:23 +0100158 // Resolve ARM Compute layer parameters.
telsoa014fcda012018-03-09 14:13:49 +0000159 const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
telsoa01c577f2c2018-08-31 09:22:23 +0100160
161 bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
162 //use specific constructor if global pooling
163 if(isGlobalPooling)
164 {
165 return arm_compute::PoolingLayerInfo(poolingType);
166 }
167
telsoa014fcda012018-03-09 14:13:49 +0000168 const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
169 descriptor.m_OutputShapeRounding);
telsoa014fcda012018-03-09 14:13:49 +0000170 const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
171 descriptor.m_StrideY,
172 descriptor.m_PadLeft,
173 descriptor.m_PadRight,
174 descriptor.m_PadTop,
175 descriptor.m_PadBottom,
176 rounding);
177
178 const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
179
surmeh01bceff2f2018-03-29 16:29:27 +0100180 const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
181
Sadik Armagana3600ba2019-10-10 10:43:20 +0100182 return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding, fpMixedPrecision);
telsoa014fcda012018-03-09 14:13:49 +0000183}
184
185arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
186{
187 const arm_compute::NormType normType =
188 ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
189 return arm_compute::NormalizationLayerInfo(normType,
190 descriptor.m_NormSize,
191 descriptor.m_Alpha,
192 descriptor.m_Beta,
193 descriptor.m_K,
194 false);
195}
196
197arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
198{
199 arm_compute::PermutationVector aclPerm;
200
201 unsigned int start = 0;
surmeh01bceff2f2018-03-29 16:29:27 +0100202 while ((start < perm.GetSize()) && (start == perm[start]))
telsoa014fcda012018-03-09 14:13:49 +0000203 {
204 ++start;
205 }
206
207 for (unsigned int i = start; i < perm.GetSize(); ++i)
208 {
209 aclPerm.set(i - start, perm[i] - start);
210 }
211
212 return aclPerm;
213}
214
Sadik Armaganf4464322018-12-20 16:19:12 +0000215arm_compute::Size2D BuildArmComputeSize2D(const unsigned int width, const unsigned int height)
216{
217 return arm_compute::Size2D(width, height);
218}
219
Mike Kelly0a08ec62019-07-25 08:39:31 +0100220arm_compute::PixelValue GetPixelValue(arm_compute::ITensor& input, float pixelValue)
221{
222 switch (input.info()->data_type())
223 {
Mike Kelly0a08ec62019-07-25 08:39:31 +0100224 case arm_compute::DataType::F16:
225 return arm_compute::PixelValue(static_cast<Half>(pixelValue));
226 case arm_compute::DataType::F32:
227 return arm_compute::PixelValue(pixelValue);
Mike Kelly130ec602019-11-08 12:08:35 +0000228 case arm_compute::DataType::QASYMM8:
229 return arm_compute::PixelValue(static_cast<uint8_t>(pixelValue));
230 case arm_compute::DataType::QSYMM16:
231 return arm_compute::PixelValue(static_cast<int16_t>(pixelValue));
232 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
233 return arm_compute::PixelValue(static_cast<int8_t>(pixelValue));
Mike Kelly0a08ec62019-07-25 08:39:31 +0100234 default:
235 throw InvalidArgumentException("Unsupported DataType: [" +
236 std::to_string(static_cast<int>(input.info()->data_type())) + "]");
237 }
238}
239
Aron Virginas-Tar710f6642019-11-27 14:48:32 +0000240bool IsQuantMultiplierSupported(const TensorInfo& input,
241 const TensorInfo& output,
242 const TensorInfo& weights)
243{
244 constexpr float maxQuantMultiplier = 1.0f;
245 if (weights.HasMultipleQuantizationScales())
246 {
247 for (float weightScale : weights.GetQuantizationScales())
248 {
249 if ((input.GetQuantizationScale() * weightScale) / output.GetQuantizationScale() > maxQuantMultiplier)
250 {
251 return false;
252 }
253 }
254 }
255 else
256 {
257 if ((input.GetQuantizationScale() * weights.GetQuantizationScale()) /
258 output.GetQuantizationScale() > maxQuantMultiplier)
259 {
260 return false;
261 }
262 }
263
264 return true;
265}
266
telsoa014fcda012018-03-09 14:13:49 +0000267} // namespace armcomputetensorutils
268} // namespace armnn