//
// Copyright © 2017-2023 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <armnnUtils/TensorUtils.hpp>

#include <armnn/backends/ITensorHandle.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <fmt/format.h>

using namespace armnn;

namespace armnnUtils
{

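// Builds a 4D TensorShape from batch, channel, height and width values, ordered according to the given data layout.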
TensorShape GetTensorShape(unsigned int numberOfBatches,
                           unsigned int numberOfChannels,
                           unsigned int height,
                           unsigned int width,
                           const DataLayout dataLayout)
{
    switch (dataLayout)
    {
        case DataLayout::NCHW:
            return TensorShape({numberOfBatches, numberOfChannels, height, width});
        case DataLayout::NHWC:
            return TensorShape({numberOfBatches, height, width, numberOfChannels});
        default:
            throw InvalidArgumentException("Unknown data layout ["
                                           + std::to_string(static_cast<int>(dataLayout)) +
                                           "]", CHECK_LOCATION());
    }
}

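// Builds a 4D TensorInfo of the given data type, ordering the dimensions according to the given data layout.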
TensorInfo GetTensorInfo(unsigned int numberOfBatches,
                         unsigned int numberOfChannels,
                         unsigned int height,
                         unsigned int width,
                         const DataLayout dataLayout,
                         const DataType dataType)
{
    switch (dataLayout)
    {
        case DataLayout::NCHW:
            return TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
        case DataLayout::NHWC:
            return TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
        default:
            throw InvalidArgumentException("Unknown data layout ["
                                           + std::to_string(static_cast<int>(dataLayout)) +
                                           "]", CHECK_LOCATION());
    }
}

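// 5D overload for volumetric tensors, supporting the NDHWC and NCDHW data layouts.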
TensorInfo GetTensorInfo(unsigned int numberOfBatches,
                         unsigned int numberOfChannels,
                         unsigned int depth,
                         unsigned int height,
                         unsigned int width,
                         const DataLayout dataLayout,
                         const DataType dataType)
{
    switch (dataLayout)
    {
        case DataLayout::NDHWC:
            return TensorInfo({numberOfBatches, depth, height, width, numberOfChannels}, dataType);
        case DataLayout::NCDHW:
            return TensorInfo({numberOfBatches, numberOfChannels, depth, height, width}, dataType);
        default:
            throw InvalidArgumentException("Unknown data layout ["
                                           + std::to_string(static_cast<int>(dataLayout)) +
                                           "]", CHECK_LOCATION());
    }
}

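// Maps a float tensor and returns the minimum and maximum values found in it.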
std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle)
{
    auto tensor_data = static_cast<const float*>(tensorHandle->Map(true));
    auto tensor_size = tensorHandle->GetShape().GetNumElements();

    // Set min/max initially to first value in tensor
    float min = tensor_data[0];
    float max = tensor_data[0];

    // Loop over rest of tensor and update min/max if necessary
    for (unsigned int val = 1; val < tensor_size; val++)
    {
        if (tensor_data[val] < min)
        {
            min = tensor_data[val];
        }
        else if (tensor_data[val] > max)
        {
            max = tensor_data[val];
        }
    }

    tensorHandle->Unmap();

    return std::make_pair(min, max);
}

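// Reduces a shape to at most 'dimensions' dimensions by dropping leading dimensions of size 1,
// e.g. ReduceDims({1, 1, 2, 3}, 3) yields {1, 2, 3}.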
TensorShape ReduceDims(const TensorShape& tensorShape, unsigned int dimensions)
{
    if (tensorShape.GetNumDimensions() <= dimensions)
    {
        return tensorShape;
    }
    std::vector<unsigned int> newShape;

    unsigned int dimsToSkip = tensorShape.GetNumDimensions() - dimensions;
    unsigned int dimsSkipped = 0;
    bool insertRemainder = false;

    for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
    {
        if (tensorShape[i] == 1 && dimsSkipped < dimsToSkip && !insertRemainder)
        {
            ++dimsSkipped;
            continue;
        }
        newShape.push_back(tensorShape[i]);
        // Once we insert the first dimension we can't skip any more
        insertRemainder = true;
    }
    return TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data());
}

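// TensorInfo overload: reduces the shape while leaving the rest of the tensor metadata unchanged.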
TensorInfo ReduceDims(const TensorInfo& tensorInfo, unsigned int dimensions)
{
    TensorInfo strippedTensor(tensorInfo);
    TensorShape strippedShape = ReduceDims(tensorInfo.GetShape(), dimensions);
    strippedTensor.SetShape(strippedShape);
    return strippedTensor;
}

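// Inserts a dimension of size 1 at the given axis; negative axes count back from the new rank,
// e.g. ExpandDims({2, 3}, -1) yields {2, 3, 1}.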
TensorShape ExpandDims(const TensorShape& tensorShape, int axis)
{
    unsigned int outputDim = tensorShape.GetNumDimensions() + 1;

    if (axis < -armnn::numeric_cast<int>(outputDim) || axis > armnn::numeric_cast<int>(tensorShape.GetNumDimensions()))
    {
        throw InvalidArgumentException(fmt::format("Invalid expansion axis {} for {}D input tensor. {}",
                                                   axis,
                                                   tensorShape.GetNumDimensions(),
                                                   CHECK_LOCATION().AsString()));
    }

    if (axis < 0)
    {
        axis = armnn::numeric_cast<int>(outputDim) + axis;
    }

    std::vector<unsigned int> outputShape;
    outputShape.reserve(tensorShape.GetNumDimensions());
    for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
    {
        outputShape.push_back(tensorShape[i]);
    }
    outputShape.insert(outputShape.begin() + axis, 1);

    return { outputDim, outputShape.data() };
}

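// Prepends dimensions of size 1 until the shape reaches the requested rank,
// e.g. ExpandDimsToRank({2, 2}, 4) yields {1, 1, 2, 2}.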
TensorShape ExpandDimsToRank(const TensorShape& tensorShape, unsigned int rank)
{
    // Can't expand if rank is smaller than current shape
    if (tensorShape.GetNumDimensions() >= rank)
    {
        return tensorShape;
    }

    std::vector<unsigned int> newShape;

    // First add 1s to the beginning of the shape to fill in the space
    for (unsigned int i = 0; i < rank - tensorShape.GetNumDimensions(); ++i)
    {
        newShape.push_back(1);
    }

    // Then iterate through the original shape and append it to the new shape after the added 1s
    for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
    {
        newShape.push_back(tensorShape[i]);
    }

    return TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data());
}

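// Returns the shape's dimensions with every dimension of size 1 removed,
// e.g. SqueezeDims({1, 2, 1, 3}) yields {2, 3}.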
std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape)
{
    std::vector<unsigned int> squeezedDims;

    for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
    {
        if (tensorShape[i] != 1)
        {
            squeezedDims.push_back(tensorShape[i]);
        }
    }
    return squeezedDims;
}

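// Multiplies the dimensions in the half-open axis range [firstAxisInclusive, lastAxisExclusive).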
unsigned int GetNumElementsBetween(const TensorShape& shape,
                                   const unsigned int firstAxisInclusive,
                                   const unsigned int lastAxisExclusive)
{
    ARMNN_ASSERT(firstAxisInclusive <= lastAxisExclusive);
    ARMNN_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
    unsigned int count = 1;
    for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
    {
        count *= shape[i];
    }
    return count;
}

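// Converts a possibly negative axis index into the equivalent non-negative index,
// e.g. GetUnsignedAxis(4, -1) yields 3.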
unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis)
{
    ARMNN_ASSERT_MSG(axis < armnn::numeric_cast<int>(inputDimension),
                     "Required axis index greater than number of dimensions.");
    ARMNN_ASSERT_MSG(axis >= -armnn::numeric_cast<int>(inputDimension),
                     "Required axis index lower than negative of the number of dimensions");

    unsigned int uAxis = axis < 0 ?
                         inputDimension - armnn::numeric_cast<unsigned int>(abs(axis))
                         : armnn::numeric_cast<unsigned int>(axis);
    return uAxis;
}

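// Multiplies the dimensions that follow the given axis, i.e. the number of elements per slice along that axis.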
unsigned int GetNumElementsAfter(const armnn::TensorShape& shape, unsigned int axis)
{
    unsigned int numDim = shape.GetNumDimensions();
    ARMNN_ASSERT(axis <= numDim - 1);
    unsigned int count = 1;
    for (unsigned int i = axis + 1; i < numDim; i++)
    {
        count *= shape[i];
    }
    return count;
}

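// Returns the per-axis quantization scales together with the axis factor, i.e. how many consecutive
// elements share each scale along the quantization dimension.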
std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::TensorInfo& info)
{
    const std::vector<float>& scales = info.GetQuantizationScales();
    armnn::Optional<unsigned int> quantizationDim = info.GetQuantizationDim();
    if (!info.HasPerAxisQuantization())
    {
        throw armnn::InvalidArgumentException(
            std::string("Per-axis quantization params not set for tensor of type ") +
            armnn::GetDataTypeName(info.GetDataType()), CHECK_LOCATION());
    }
    unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value());

    return { axisFactor, scales };
}

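// Verifies that 'data' holds the expected number of entries; 'size' gives the number of buffer entries
// per tensor element (e.g. bytes per element when the buffer holds raw bytes).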
template<typename PrimitiveType>
void CheckSizes(const std::vector<PrimitiveType>& data, const armnn::TensorInfo& tensorInfo, unsigned int size = 1)
{
    if (data.size() / size != tensorInfo.GetNumElements())
    {
        throw InvalidArgumentException(
            fmt::format("The data does not contain the expected number of elements {} != {}. {}",
                        data.size(), tensorInfo.GetNumElements(), CHECK_LOCATION().AsString()));
    }
}

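// Dequantizes a typed buffer into a newly allocated float array, using per-axis scales when the
// tensor has per-axis quantization.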
template<typename PrimitiveType>
std::unique_ptr<float[]> ToFloatArray(const std::vector<PrimitiveType>& data, const armnn::TensorInfo& tensorInfo)
{
    CheckSizes(data, tensorInfo);

    std::unique_ptr<float[]> returnBuffer(new float[tensorInfo.GetNumElements()]);

    if (tensorInfo.HasPerAxisQuantization())
    {
        unsigned int axis = tensorInfo.GetQuantizationDim().value();
        auto axisDimensionality = tensorInfo.GetShape()[axis];
        auto axisFactor = armnnUtils::GetNumElementsAfter(tensorInfo.GetShape(), axis);

        for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
        {
            unsigned int axisIndex;

            if (i < axisFactor)
            {
                axisIndex = 0;
            }
            else
            {
                axisIndex = (i / axisFactor) % axisDimensionality;
            }
            returnBuffer[i] = Dequantize<PrimitiveType>(data[i],
                                                        tensorInfo.GetQuantizationScales()[axisIndex],
                                                        tensorInfo.GetQuantizationOffset());
        }
    }
    else
    {
        for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
        {
            returnBuffer[i] = Dequantize<PrimitiveType>(data[i],
                                                        tensorInfo.GetQuantizationScale(),
                                                        tensorInfo.GetQuantizationOffset());
        }
    }
    return returnBuffer;
}

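// Raw-byte overload: reinterprets the buffer according to the tensor's data type, then dequantizes to float.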
std::unique_ptr<float[]> ToFloatArray(const std::vector<uint8_t>& data, const armnn::TensorInfo& tensorInfo)
{
    if (tensorInfo.GetDataType() == DataType::QAsymmS8 || tensorInfo.GetDataType() == DataType::QSymmS8)
    {
        CheckSizes(data, tensorInfo);
        std::vector<int8_t> buffer(tensorInfo.GetNumElements());
        ::memcpy(buffer.data(), data.data(), data.size());
        return ToFloatArray<int8_t>(buffer, tensorInfo);
    }
    else if (tensorInfo.GetDataType() == DataType::QAsymmU8)
    {
        CheckSizes(data, tensorInfo);
        return ToFloatArray<uint8_t>(data, tensorInfo);
    }
    else if (tensorInfo.GetDataType() == DataType::Signed32)
    {
        CheckSizes(data, tensorInfo, 4);
        std::vector<int32_t> buffer(tensorInfo.GetNumElements());
        ::memcpy(buffer.data(), data.data(), data.size());
        return ToFloatArray<int32_t>(buffer, tensorInfo);
    }
    else if (tensorInfo.GetDataType() == DataType::Signed64)
    {
        CheckSizes(data, tensorInfo, 8);
        std::vector<int64_t> buffer(tensorInfo.GetNumElements());
        ::memcpy(buffer.data(), data.data(), data.size());
        return ToFloatArray<int64_t>(buffer, tensorInfo);
    }
    throw InvalidArgumentException(
        fmt::format("Unsupported datatype {}. {}",
                    GetDataTypeName(tensorInfo.GetDataType()),
                    CHECK_LOCATION().AsString()));
}

} // namespace armnnUtils