blob: 9e3d7192112d6b8f7f0fa75e92e8a80ec75b524c [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Matteo Martincighe011d202019-11-28 11:35:47 +00006#include <armnnUtils/TensorUtils.hpp>
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +00007
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00008#include <armnn/backends/ITensorHandle.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01009#include <armnn/utility/Assert.hpp>
Matthew Sloyan0663d662020-09-14 11:47:26 +010010#include <armnn/utility/NumericCast.hpp>
Nina Drozdd41b2592018-11-19 13:03:36 +000011
Colm Donelan5b5c2222020-09-09 12:48:16 +010012#include <fmt/format.h>
Narumol Prangnawarat02807852019-09-11 16:43:09 +010013
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000014using namespace armnn;
15
Nina Drozdd41b2592018-11-19 13:03:36 +000016namespace armnnUtils
17{
18
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000019TensorShape GetTensorShape(unsigned int numberOfBatches,
Nina Drozdd41b2592018-11-19 13:03:36 +000020 unsigned int numberOfChannels,
21 unsigned int height,
22 unsigned int width,
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000023 const DataLayout dataLayout)
Nina Drozdd41b2592018-11-19 13:03:36 +000024{
25 switch (dataLayout)
26 {
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000027 case DataLayout::NCHW:
28 return TensorShape({numberOfBatches, numberOfChannels, height, width});
29 case DataLayout::NHWC:
30 return TensorShape({numberOfBatches, height, width, numberOfChannels});
Nina Drozdd41b2592018-11-19 13:03:36 +000031 default:
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000032 throw InvalidArgumentException("Unknown data layout ["
Nina Drozdd41b2592018-11-19 13:03:36 +000033 + std::to_string(static_cast<int>(dataLayout)) +
34 "]", CHECK_LOCATION());
35 }
36}
37
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000038TensorInfo GetTensorInfo(unsigned int numberOfBatches,
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000039 unsigned int numberOfChannels,
40 unsigned int height,
41 unsigned int width,
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000042 const DataLayout dataLayout,
43 const DataType dataType)
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000044{
45 switch (dataLayout)
46 {
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000047 case DataLayout::NCHW:
48 return TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
49 case DataLayout::NHWC:
50 return TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000051 default:
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000052 throw InvalidArgumentException("Unknown data layout ["
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000053 + std::to_string(static_cast<int>(dataLayout)) +
54 "]", CHECK_LOCATION());
55 }
Nina Drozdd41b2592018-11-19 13:03:36 +000056}
57
Tamás Nyíri7b885b32021-10-26 14:47:57 +010058TensorInfo GetTensorInfo(unsigned int numberOfBatches,
59 unsigned int numberOfChannels,
60 unsigned int depth,
61 unsigned int height,
62 unsigned int width,
63 const DataLayout dataLayout,
64 const DataType dataType)
65{
66 switch (dataLayout)
67 {
68 case DataLayout::NDHWC:
69 return TensorInfo({numberOfBatches, depth, height, width, numberOfChannels}, dataType);
70 case DataLayout::NCDHW:
71 return TensorInfo({numberOfBatches, numberOfChannels, depth, height, width}, dataType);
72 default:
73 throw InvalidArgumentException("Unknown data layout ["
74 + std::to_string(static_cast<int>(dataLayout)) +
75 "]", CHECK_LOCATION());
76 }
77}
78
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000079std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle)
Jim Flynnf92dfce2019-05-02 11:33:25 +010080{
81 auto tensor_data = static_cast<const float *>(tensorHandle->Map(true));
82 auto tensor_size = tensorHandle->GetShape().GetNumElements();
83
84 // Set min/max initially to first value in tensor
85 float min = tensor_data[0];
86 float max = tensor_data[0];
87
88 // Loop over rest of tensor and update min/max if necessary
89 for (unsigned int val = 1; val < tensor_size; val++)
90 {
91 if (tensor_data[val] < min)
92 {
93 min = tensor_data[val];
94 }
95 else if (tensor_data[val] > max)
96 {
97 max = tensor_data[val];
98 }
99 }
100
101 tensorHandle->Unmap();
102
103 return std::make_pair(min, max);
104}
105
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +0000106TensorShape ExpandDims(const TensorShape& tensorShape, int axis)
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100107{
108 unsigned int outputDim = tensorShape.GetNumDimensions() + 1;
109
Matthew Sloyan0663d662020-09-14 11:47:26 +0100110 if (axis < -armnn::numeric_cast<int>(outputDim) || axis > armnn::numeric_cast<int>(tensorShape.GetNumDimensions()))
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100111 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100112 throw InvalidArgumentException(fmt::format("Invalid expansion axis {} for {}D input tensor. {}",
113 axis,
114 tensorShape.GetNumDimensions(),
115 CHECK_LOCATION().AsString()));
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100116 }
117
118 if (axis < 0)
119 {
Matthew Sloyan0663d662020-09-14 11:47:26 +0100120 axis = armnn::numeric_cast<int>(outputDim) + axis;
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100121 }
122
123 std::vector<unsigned int> outputShape;
Colm Donelan5b5c2222020-09-09 12:48:16 +0100124 outputShape.reserve(tensorShape.GetNumDimensions());
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100125 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
126 {
127 outputShape.push_back(tensorShape[i]);
128 }
129 outputShape.insert(outputShape.begin() + axis, 1);
130
Mike Kelly0506ef02023-01-03 16:29:44 +0000131 return { outputDim, outputShape.data() };
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100132}
133
Mike Kelly80512b02022-05-16 23:10:42 +0100134std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape)
135{
Mike Kelly80512b02022-05-16 23:10:42 +0100136 std::vector<unsigned int> squeezedDims;
137
138 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
139 {
140 if (tensorShape[i] != 1)
141 {
142 squeezedDims.push_back(tensorShape[i]);
Mike Kelly80512b02022-05-16 23:10:42 +0100143 }
144 }
145 return squeezedDims;
146}
147
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +0000148unsigned int GetNumElementsBetween(const TensorShape& shape,
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +0100149 const unsigned int firstAxisInclusive,
150 const unsigned int lastAxisExclusive)
151{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100152 ARMNN_ASSERT(firstAxisInclusive <= lastAxisExclusive);
153 ARMNN_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +0100154 unsigned int count = 1;
155 for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
156 {
157 count *= shape[i];
158 }
159 return count;
160}
161
162unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis)
163{
Matthew Sloyan0663d662020-09-14 11:47:26 +0100164 ARMNN_ASSERT_MSG(axis < armnn::numeric_cast<int>(inputDimension),
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +0100165 "Required axis index greater than number of dimensions.");
Matthew Sloyan0663d662020-09-14 11:47:26 +0100166 ARMNN_ASSERT_MSG(axis >= -armnn::numeric_cast<int>(inputDimension),
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +0100167 "Required axis index lower than negative of the number of dimensions");
168
169 unsigned int uAxis = axis < 0 ?
Matthew Sloyan0663d662020-09-14 11:47:26 +0100170 inputDimension - armnn::numeric_cast<unsigned int>(abs(axis))
171 : armnn::numeric_cast<unsigned int>(axis);
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +0100172 return uAxis;
173}
174
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000175unsigned int GetNumElementsAfter(const armnn::TensorShape& shape, unsigned int axis)
176{
177 unsigned int numDim = shape.GetNumDimensions();
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100178 ARMNN_ASSERT(axis <= numDim - 1);
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000179 unsigned int count = 1;
Jan Eilers53ef7952021-06-02 12:01:25 +0100180 for (unsigned int i = axis+1; i < numDim; i++)
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000181 {
182 count *= shape[i];
183 }
184 return count;
185}
186
187std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::TensorInfo& info)
188{
189 const std::vector<float>& scales = info.GetQuantizationScales();
190 armnn::Optional<unsigned int> quantizationDim = info.GetQuantizationDim();
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000191 if (!info.HasPerAxisQuantization())
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000192 {
193 throw armnn::InvalidArgumentException(
194 std::string("Per-axis quantization params not set for tensor of type ") +
195 armnn::GetDataTypeName(info.GetDataType()), CHECK_LOCATION());
196 }
Jan Eilers53ef7952021-06-02 12:01:25 +0100197 unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value()) ;
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000198
199 return { axisFactor, scales };
200}
201
Mike Kelly0506ef02023-01-03 16:29:44 +0000202template<typename PrimitiveType>
203void CheckSizes(const std::vector<PrimitiveType>& data, const armnn::TensorInfo& tensorInfo, unsigned int size = 1)
204{
205 if (data.size() / size != tensorInfo.GetNumElements())
206 {
207 throw InvalidArgumentException(
208 fmt::format("The data does not contain the expected number of elements {} != {}. {}",
209 data.size(), tensorInfo.GetNumElements(), CHECK_LOCATION().AsString()));
210 }
211}
212
213template<typename PrimitiveType>
214std::unique_ptr<float[]> ToFloatArray(const std::vector<PrimitiveType>& data, const armnn::TensorInfo& tensorInfo)
215{
216 CheckSizes(data, tensorInfo);
217
218 std::unique_ptr<float[]> returnBuffer(new float[tensorInfo.GetNumElements()]);
219
220 if (tensorInfo.HasPerAxisQuantization())
221 {
222 unsigned int axis = tensorInfo.GetQuantizationDim().value();
223 auto axisDimensionality = tensorInfo.GetShape()[axis];
224 auto axisFactor = armnnUtils::GetNumElementsAfter(tensorInfo.GetShape(), axis);
225
226 for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
227 {
228 unsigned int axisIndex;
229
230 if (i < axisFactor)
231 {
232 axisIndex = 0;
233 }
234 else
235 {
236 axisIndex = (i / axisFactor) % axisDimensionality;
237 }
238 returnBuffer[i] = Dequantize<PrimitiveType>(data[i],
239 tensorInfo.GetQuantizationScales()[axisIndex],
240 tensorInfo.GetQuantizationOffset());
241 }
242 }
243 else
244 {
245 for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
246 {
247 returnBuffer[i] = Dequantize<PrimitiveType>(data[i],
248 tensorInfo.GetQuantizationScale(),
249 tensorInfo.GetQuantizationOffset());
250 }
251 }
252 return returnBuffer;
253}
254
255std::unique_ptr<float[]> ToFloatArray(const std::vector<uint8_t>& data, const armnn::TensorInfo& tensorInfo)
256{
257 if (tensorInfo.GetDataType() == DataType::QAsymmS8 || tensorInfo.GetDataType() == DataType::QSymmS8)
258 {
259 CheckSizes(data, tensorInfo);
260 std::vector<int8_t> buffer(tensorInfo.GetNumElements());
261 ::memcpy(buffer.data(), data.data(), data.size());
262 return ToFloatArray<int8_t>(buffer, tensorInfo);
263 }
264 else if (tensorInfo.GetDataType() == DataType::QAsymmU8)
265 {
266 CheckSizes(data, tensorInfo);
267 return ToFloatArray<uint8_t>(data, tensorInfo);
268 }
269 else if (tensorInfo.GetDataType() == DataType::Signed32)
270 {
271 CheckSizes(data, tensorInfo, 4);
272 std::vector<int32_t> buffer(tensorInfo.GetNumElements());
273 ::memcpy(buffer.data(), data.data(), data.size());
274 return ToFloatArray<int32_t>(buffer, tensorInfo);
275 }
276 else if (tensorInfo.GetDataType() == DataType::Signed64)
277 {
278 CheckSizes(data, tensorInfo, 8);
279 std::vector<int64_t> buffer(tensorInfo.GetNumElements());
280 ::memcpy(buffer.data(), data.data(), data.size());
281 return ToFloatArray<int64_t>(buffer, tensorInfo);
282 }
283 throw InvalidArgumentException(
284 fmt::format("Unsupported datatype {}. {}",
285 GetDataTypeName(tensorInfo.GetDataType()),
286 CHECK_LOCATION().AsString()));
287}
288
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +0000289} // namespace armnnUtils