blob: c7d250a7067e18412dbae58268013b81ab644a86 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005#include <aclCommon/ArmComputeTensorUtils.hpp>
6#include <aclCommon/ArmComputeUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00007
Francis Murtagh351d13d2018-09-24 15:01:18 +01008#include "armnn/Exceptions.hpp"
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Descriptors.hpp>
10
11namespace armnn
12{
13namespace armcomputetensorutils
14{
15
16arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
17{
18 switch(dataType)
19 {
telsoa01c577f2c2018-08-31 09:22:23 +010020 case armnn::DataType::Float16:
21 return arm_compute::DataType::F16;
telsoa014fcda012018-03-09 14:13:49 +000022 case armnn::DataType::Float32:
telsoa014fcda012018-03-09 14:13:49 +000023 return arm_compute::DataType::F32;
telsoa014fcda012018-03-09 14:13:49 +000024 case armnn::DataType::QuantisedAsymm8:
telsoa014fcda012018-03-09 14:13:49 +000025 return arm_compute::DataType::QASYMM8;
Aron Virginas-Tar7a3e2fe2019-06-27 18:54:47 +010026 case armnn::DataType::QuantisedSymm16:
27 return arm_compute::DataType::QSYMM16;
telsoa014fcda012018-03-09 14:13:49 +000028 case armnn::DataType::Signed32:
telsoa014fcda012018-03-09 14:13:49 +000029 return arm_compute::DataType::S32;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +000030 case armnn::DataType::Boolean:
31 return arm_compute::DataType::U8;
telsoa014fcda012018-03-09 14:13:49 +000032 default:
telsoa014fcda012018-03-09 14:13:49 +000033 BOOST_ASSERT_MSG(false, "Unknown data type");
34 return arm_compute::DataType::UNKNOWN;
telsoa014fcda012018-03-09 14:13:49 +000035 }
36}
37
Matthew Benthamfd899962018-12-31 15:49:42 +000038arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions,
39 unsigned int originalInputRank,
40 const std::vector<unsigned int>& armnnAxes)
41{
42 arm_compute::Coordinates outAclCoords;
43
44 if (armnnAxes.empty())
45 {
46 // If no reduction axes were provided, then the input must be reduced along all dimensions.
47 // Since Compute Library does not accept an empty vector as the reduction dimensions, we then
48 // manually create a vector including all the input dimensions (in reversed order) as:
49 //
50 // { inputDimensions - 1, inputDimensions - 2, ..., 1, 0 }
51 //
52 outAclCoords.set_num_dimensions(inputDimensions);
53 std::generate(outAclCoords.begin(), outAclCoords.end(), [d = inputDimensions - 1] () mutable { return d--; });
54 }
55 else
56 {
57 // Create a vector of reduction dimensions (in reversed order) with the given reduction axes.
58 //
59 // Adjust the given reduction axes according to the original rank of the input tensor (before ACL applied any
60 // dimension correction).
61 // For example, if the input tensor originally had 4 dimensions, and one of the reduction axes was 2, then the
62 // new value for that reduction axis should be 1.
63 //
64 // Example:
65 // ArmNN input shape = { 1, 1, 3, 2 } -> ACL input shape = { 2, 3 }
66 // ArmNN reduction axis = { 2 } -> ACL reduction axis = { 1 }
67 // ArmNN reduction axis = { 3 } -> ACL reduction axis = { 0 }
68 //
69 // The transformation: ACL reduction axis index = original rank - ArmNN reduction axis index - 1
70 //
71 outAclCoords.set_num_dimensions(armnnAxes.size());
72 std::transform(armnnAxes.begin(), armnnAxes.end(),
73 outAclCoords.begin(),
74 [originalInputRank](unsigned int i){ return originalInputRank - i - 1; });
75 }
76
77 return outAclCoords;
78}
79
telsoa014fcda012018-03-09 14:13:49 +000080arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
81{
82 arm_compute::TensorShape shape;
83
telsoa01c577f2c2018-08-31 09:22:23 +010084 // armnn tensors are (batch, channels, height, width).
85 // arm_compute tensors are (width, height, channels, batch).
telsoa014fcda012018-03-09 14:13:49 +000086 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
87 {
telsoa01c577f2c2018-08-31 09:22:23 +010088 // Note that our dimensions are stored in the opposite order to ACL's.
Matthew Bentham89105282018-11-20 14:33:33 +000089 shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i], false);
telsoa014fcda012018-03-09 14:13:49 +000090
91 // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
telsoa01c577f2c2018-08-31 09:22:23 +010092 // arm_compute tensors expect this.
telsoa014fcda012018-03-09 14:13:49 +000093 }
94
95 // prevent arm_compute issue where tensor is flattened to nothing
96 if (shape.num_dimensions() == 0)
97 {
98 shape.set_num_dimensions(1);
99 }
100
101 return shape;
102}
103
104// Utility function used to build a TensorInfo object, that can be used to initialise
105// ARM Compute Tensor and CLTensor allocators.
106arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
107{
108 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
109 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
110 const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
111 tensorInfo.GetQuantizationOffset());
112
113 return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
114}
115
Francis Murtagh351d13d2018-09-24 15:01:18 +0100116arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
117 armnn::DataLayout dataLayout)
118{
119 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
120 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
121 const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
122 tensorInfo.GetQuantizationOffset());
123
124 arm_compute::TensorInfo clTensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
125 clTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
126
127 return clTensorInfo;
128}
129
Matteo Martincigh747ef822018-12-18 09:26:39 +0000130arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
131{
132 switch(dataLayout)
133 {
134 case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
135
136 case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
137
138 default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
139 std::to_string(static_cast<int>(dataLayout)) + "]");
140 }
141}
142
Sadik Armagana3600ba2019-10-10 10:43:20 +0100143arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor,
144 bool fpMixedPrecision)
telsoa014fcda012018-03-09 14:13:49 +0000145{
146 using arm_compute::PoolingType;
147 using arm_compute::DimensionRoundingType;
148 using arm_compute::PadStrideInfo;
149 using arm_compute::PoolingLayerInfo;
surmeh01bceff2f2018-03-29 16:29:27 +0100150 using arm_compute::Size2D;
telsoa014fcda012018-03-09 14:13:49 +0000151
telsoa01c577f2c2018-08-31 09:22:23 +0100152 // Resolve ARM Compute layer parameters.
telsoa014fcda012018-03-09 14:13:49 +0000153 const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
telsoa01c577f2c2018-08-31 09:22:23 +0100154
155 bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
156 //use specific constructor if global pooling
157 if(isGlobalPooling)
158 {
159 return arm_compute::PoolingLayerInfo(poolingType);
160 }
161
telsoa014fcda012018-03-09 14:13:49 +0000162 const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
163 descriptor.m_OutputShapeRounding);
telsoa014fcda012018-03-09 14:13:49 +0000164 const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
165 descriptor.m_StrideY,
166 descriptor.m_PadLeft,
167 descriptor.m_PadRight,
168 descriptor.m_PadTop,
169 descriptor.m_PadBottom,
170 rounding);
171
172 const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
173
surmeh01bceff2f2018-03-29 16:29:27 +0100174 const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
175
Sadik Armagana3600ba2019-10-10 10:43:20 +0100176 return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding, fpMixedPrecision);
telsoa014fcda012018-03-09 14:13:49 +0000177}
178
179arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
180{
181 const arm_compute::NormType normType =
182 ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
183 return arm_compute::NormalizationLayerInfo(normType,
184 descriptor.m_NormSize,
185 descriptor.m_Alpha,
186 descriptor.m_Beta,
187 descriptor.m_K,
188 false);
189}
190
191arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
192{
193 arm_compute::PermutationVector aclPerm;
194
195 unsigned int start = 0;
surmeh01bceff2f2018-03-29 16:29:27 +0100196 while ((start < perm.GetSize()) && (start == perm[start]))
telsoa014fcda012018-03-09 14:13:49 +0000197 {
198 ++start;
199 }
200
201 for (unsigned int i = start; i < perm.GetSize(); ++i)
202 {
203 aclPerm.set(i - start, perm[i] - start);
204 }
205
206 return aclPerm;
207}
208
Sadik Armaganf4464322018-12-20 16:19:12 +0000209arm_compute::Size2D BuildArmComputeSize2D(const unsigned int width, const unsigned int height)
210{
211 return arm_compute::Size2D(width, height);
212}
213
Mike Kelly0a08ec62019-07-25 08:39:31 +0100214arm_compute::PixelValue GetPixelValue(arm_compute::ITensor& input, float pixelValue)
215{
216 switch (input.info()->data_type())
217 {
218 case arm_compute::DataType::QASYMM8:
219 return arm_compute::PixelValue(static_cast<uint8_t>(pixelValue));
220 case arm_compute::DataType::QSYMM16:
221 return arm_compute::PixelValue(static_cast<int16_t>(pixelValue));
222 case arm_compute::DataType::F16:
223 return arm_compute::PixelValue(static_cast<Half>(pixelValue));
224 case arm_compute::DataType::F32:
225 return arm_compute::PixelValue(pixelValue);
226 default:
227 throw InvalidArgumentException("Unsupported DataType: [" +
228 std::to_string(static_cast<int>(input.info()->data_type())) + "]");
229 }
230}
231
telsoa014fcda012018-03-09 14:13:49 +0000232} // namespace armcomputetensorutils
233} // namespace armnn