blob: cfd2e0e110c32d7a3d149a76924e34fd2e334b4c [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlina52bca22024-02-01 17:36:48 +00002// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Colm Donelanb4ef1632024-02-01 15:00:43 +00005
#include <armnn/Exceptions.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <aclCommon/ArmComputeUtils.hpp>

#include "ArmComputeUtils.hpp"
#include <armnn/Descriptors.hpp>

#include <fmt/format.h>

#include <limits>
14
telsoa014fcda012018-03-09 14:13:49 +000015namespace armnn
16{
17namespace armcomputetensorutils
18{
19
Derek Lambertid466a542020-01-22 15:37:29 +000020arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales)
telsoa014fcda012018-03-09 14:13:49 +000021{
22 switch(dataType)
23 {
Narumol Prangnawarat250d3922020-03-30 16:11:04 +010024 case armnn::DataType::BFloat16:
25 return arm_compute::DataType::BFLOAT16;
Mike Kelly130ec602019-11-08 12:08:35 +000026 case armnn::DataType::Boolean:
27 return arm_compute::DataType::U8;
telsoa01c577f2c2018-08-31 09:22:23 +010028 case armnn::DataType::Float16:
29 return arm_compute::DataType::F16;
telsoa014fcda012018-03-09 14:13:49 +000030 case armnn::DataType::Float32:
telsoa014fcda012018-03-09 14:13:49 +000031 return arm_compute::DataType::F32;
Ryan OShea9add1202020-02-07 10:06:33 +000032 case armnn::DataType::QAsymmS8:
33 return arm_compute::DataType::QASYMM8_SIGNED;
Derek Lambertif90c56d2020-01-10 17:14:08 +000034 case armnn::DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000035 return arm_compute::DataType::QASYMM8;
Derek Lambertif90c56d2020-01-10 17:14:08 +000036 case armnn::DataType::QSymmS16:
Aron Virginas-Tar7a3e2fe2019-06-27 18:54:47 +010037 return arm_compute::DataType::QSYMM16;
Inki Daed4619e22020-09-10 15:33:54 +090038 case armnn::DataType::Signed64:
39 return arm_compute::DataType::S64;
Finn Williamsfd271062019-12-04 14:27:27 +000040 case armnn::DataType::QSymmS8:
Derek Lambertid466a542020-01-22 15:37:29 +000041 {
42 return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
43 }
telsoa014fcda012018-03-09 14:13:49 +000044 case armnn::DataType::Signed32:
telsoa014fcda012018-03-09 14:13:49 +000045 return arm_compute::DataType::S32;
telsoa014fcda012018-03-09 14:13:49 +000046 default:
telsoa014fcda012018-03-09 14:13:49 +000047 return arm_compute::DataType::UNKNOWN;
telsoa014fcda012018-03-09 14:13:49 +000048 }
49}
50
Cathal Corbettfd5bec42022-03-03 15:13:23 +000051armnn::DataType GetArmNNDataType(arm_compute::DataType dataType)
52{
53 switch(dataType)
54 {
55 case arm_compute::DataType::BFLOAT16:
56 return armnn::DataType::BFloat16;
57 case arm_compute::DataType::U8:
58 return armnn::DataType::Boolean;
59 case arm_compute::DataType::F16:
60 return armnn::DataType::Float16;
61 case arm_compute::DataType::F32:
62 return armnn::DataType::Float32;
63 case arm_compute::DataType::QASYMM8_SIGNED:
64 return armnn::DataType::QAsymmS8;
65 case arm_compute::DataType::QASYMM8:
66 return armnn::DataType::QAsymmU8;
67 case arm_compute::DataType::QSYMM16:
68 return armnn::DataType::QSymmS16;
69 case arm_compute::DataType::S64:
70 return armnn::DataType::Signed64;
71 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
72 return armnn::DataType::QSymmS8;
73 case arm_compute::DataType::QSYMM8:
74 return armnn::DataType::QSymmS8;
75 case arm_compute::DataType::S32:
76 return armnn::DataType::Signed32;
77 default:
Colm Donelanb4ef1632024-02-01 15:00:43 +000078 throw InvalidArgumentException("Unknown arm_compute::DataType data type");
Cathal Corbettfd5bec42022-03-03 15:13:23 +000079 }
80}
81
Matthew Benthamfd899962018-12-31 15:49:42 +000082arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions,
83 unsigned int originalInputRank,
84 const std::vector<unsigned int>& armnnAxes)
85{
86 arm_compute::Coordinates outAclCoords;
87
88 if (armnnAxes.empty())
89 {
90 // If no reduction axes were provided, then the input must be reduced along all dimensions.
91 // Since Compute Library does not accept an empty vector as the reduction dimensions, we then
92 // manually create a vector including all the input dimensions (in reversed order) as:
93 //
94 // { inputDimensions - 1, inputDimensions - 2, ..., 1, 0 }
95 //
96 outAclCoords.set_num_dimensions(inputDimensions);
97 std::generate(outAclCoords.begin(), outAclCoords.end(), [d = inputDimensions - 1] () mutable { return d--; });
98 }
99 else
100 {
101 // Create a vector of reduction dimensions (in reversed order) with the given reduction axes.
102 //
103 // Adjust the given reduction axes according to the original rank of the input tensor (before ACL applied any
104 // dimension correction).
105 // For example, if the input tensor originally had 4 dimensions, and one of the reduction axes was 2, then the
106 // new value for that reduction axis should be 1.
107 //
108 // Example:
109 // ArmNN input shape = { 1, 1, 3, 2 } -> ACL input shape = { 2, 3 }
110 // ArmNN reduction axis = { 2 } -> ACL reduction axis = { 1 }
111 // ArmNN reduction axis = { 3 } -> ACL reduction axis = { 0 }
112 //
113 // The transformation: ACL reduction axis index = original rank - ArmNN reduction axis index - 1
114 //
115 outAclCoords.set_num_dimensions(armnnAxes.size());
116 std::transform(armnnAxes.begin(), armnnAxes.end(),
117 outAclCoords.begin(),
118 [originalInputRank](unsigned int i){ return originalInputRank - i - 1; });
119 }
120
121 return outAclCoords;
122}
123
telsoa014fcda012018-03-09 14:13:49 +0000124arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
125{
126 arm_compute::TensorShape shape;
127
telsoa01c577f2c2018-08-31 09:22:23 +0100128 // armnn tensors are (batch, channels, height, width).
129 // arm_compute tensors are (width, height, channels, batch).
telsoa014fcda012018-03-09 14:13:49 +0000130 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
131 {
telsoa01c577f2c2018-08-31 09:22:23 +0100132 // Note that our dimensions are stored in the opposite order to ACL's.
Matthew Bentham89105282018-11-20 14:33:33 +0000133 shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i], false);
telsoa014fcda012018-03-09 14:13:49 +0000134
135 // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
telsoa01c577f2c2018-08-31 09:22:23 +0100136 // arm_compute tensors expect this.
telsoa014fcda012018-03-09 14:13:49 +0000137 }
138
139 // prevent arm_compute issue where tensor is flattened to nothing
140 if (shape.num_dimensions() == 0)
141 {
142 shape.set_num_dimensions(1);
143 }
144
145 return shape;
146}
147
Mike Kelly0e3fe102023-01-23 19:32:06 +0000148std::vector<unsigned int> ReduceDimsForACL(const armnn::TensorShape tensorShape, unsigned int dimensions)
149{
150 std::vector<unsigned int> newShape;
151
152 unsigned int dimsToSkip = 0;
153
154 if (tensorShape.GetNumDimensions() > dimensions)
155 {
156 dimsToSkip = tensorShape.GetNumDimensions() - dimensions;
157 }
158 unsigned int dimsSkipped = 0;
159 bool insertRemainder = false;
160
161 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
162 {
163 if (tensorShape[i] == 1 && dimsSkipped < dimsToSkip && !insertRemainder)
164 {
165 ++dimsSkipped;
166 continue;
167 }
168 newShape.insert(newShape.begin(), tensorShape[i]);
169 // Once we insert the first dimension we can't skip any more
170 insertRemainder = true;
171 }
172 return newShape;
173}
174
175arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape, unsigned int dimensions)
176{
177 arm_compute::TensorShape shape;
178 std::vector<unsigned int> strippedShape = ReduceDimsForACL(tensorShape, dimensions);
179
180 for (unsigned int i = 0; i < strippedShape.size(); i++)
181 {
182 shape.set(i, strippedShape[i], false);
183 }
184
185 // prevent arm_compute issue where tensor is flattened to nothing
186 if (shape.num_dimensions() == 0)
187 {
188 shape.set_num_dimensions(1);
189 }
190 return shape;
191}
192
telsoa014fcda012018-03-09 14:13:49 +0000193// Utility function used to build a TensorInfo object, that can be used to initialise
194// ARM Compute Tensor and CLTensor allocators.
Cathal Corbett4452baf2022-05-13 09:55:59 +0100195// Note: this utility ignores the value of armnn::TensorInfo.IsConstant(). ACL tensors
196// default to constant but Arm NN ones default to non constant. In the cases where
197// we expect ACL to treat a tensor as constant that value must be set after this
198// utility has been called.
telsoa014fcda012018-03-09 14:13:49 +0000199arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
200{
Derek Lambertid466a542020-01-22 15:37:29 +0000201 bool multiScales = tensorInfo.HasMultipleQuantizationScales();
telsoa014fcda012018-03-09 14:13:49 +0000202 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
Derek Lambertid466a542020-01-22 15:37:29 +0000203 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +0000204
Derek Lambertid466a542020-01-22 15:37:29 +0000205 const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +0000206 arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
207 arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
telsoa014fcda012018-03-09 14:13:49 +0000208
209 return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
210}
211
Francis Murtagh351d13d2018-09-24 15:01:18 +0100212arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
213 armnn::DataLayout dataLayout)
214{
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +0000215 arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo);
216 aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
Francis Murtagh351d13d2018-09-24 15:01:18 +0100217
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +0000218 return aclTensorInfo;
Francis Murtagh351d13d2018-09-24 15:01:18 +0100219}
220
Mike Kelly0e3fe102023-01-23 19:32:06 +0000221arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, unsigned int dimensions)
222{
223 bool multiScales = tensorInfo.HasMultipleQuantizationScales();
224 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape(), dimensions);
225 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
226
227 const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
228 arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
229 arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
230
231 return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
232}
233arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
234 armnn::DataLayout dataLayout, unsigned int dimensions)
235{
236 arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo, dimensions);
237 aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
238
239 return aclTensorInfo;
240}
241
242
Matteo Martincigh747ef822018-12-18 09:26:39 +0000243arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
244{
245 switch(dataLayout)
246 {
247 case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
248
249 case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
250
Teresa Charlinec5f7d12021-10-22 17:15:00 +0100251 case armnn::DataLayout::NDHWC : return arm_compute::DataLayout::NDHWC;
252
253 case armnn::DataLayout::NCDHW : return arm_compute::DataLayout::NCDHW;
254
Matteo Martincigh747ef822018-12-18 09:26:39 +0000255 default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
256 std::to_string(static_cast<int>(dataLayout)) + "]");
257 }
258}
259
Sadik Armagana3600ba2019-10-10 10:43:20 +0100260arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor,
261 bool fpMixedPrecision)
telsoa014fcda012018-03-09 14:13:49 +0000262{
telsoa01c577f2c2018-08-31 09:22:23 +0100263 // Resolve ARM Compute layer parameters.
Ryan OSheabab8fa92022-03-09 10:29:02 +0000264 const arm_compute::PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
telsoa01c577f2c2018-08-31 09:22:23 +0100265
Ryan OSheabab8fa92022-03-09 10:29:02 +0000266 const arm_compute::DataLayout dataLayout = ConvertDataLayout(descriptor.m_DataLayout);
Teresa Charlinc809a292020-01-31 10:21:44 +0000267
telsoa01c577f2c2018-08-31 09:22:23 +0100268 bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
269 //use specific constructor if global pooling
270 if(isGlobalPooling)
271 {
Teresa Charlinc809a292020-01-31 10:21:44 +0000272 return arm_compute::PoolingLayerInfo(poolingType, dataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +0100273 }
274
Ryan OSheabab8fa92022-03-09 10:29:02 +0000275 const arm_compute::DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
telsoa014fcda012018-03-09 14:13:49 +0000276 descriptor.m_OutputShapeRounding);
Ryan OSheabab8fa92022-03-09 10:29:02 +0000277 const arm_compute::PadStrideInfo padStrideInfo(descriptor.m_StrideX,
Teresa Charlina52bca22024-02-01 17:36:48 +0000278 descriptor.m_StrideY,
279 descriptor.m_PadLeft,
280 descriptor.m_PadRight,
281 descriptor.m_PadTop,
282 descriptor.m_PadBottom,
283 rounding);
telsoa014fcda012018-03-09 14:13:49 +0000284
285 const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
286
Ryan OSheabab8fa92022-03-09 10:29:02 +0000287 const arm_compute::Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
surmeh01bceff2f2018-03-29 16:29:27 +0100288
Teresa Charlinc809a292020-01-31 10:21:44 +0000289 return arm_compute::PoolingLayerInfo(poolingType, poolSize, dataLayout, padStrideInfo, excludePadding,
290 fpMixedPrecision);
telsoa014fcda012018-03-09 14:13:49 +0000291}
292
Ryan OSheabab8fa92022-03-09 10:29:02 +0000293arm_compute::Pooling3dLayerInfo BuildArmComputePooling3dLayerInfo(const Pooling3dDescriptor& descriptor,
294 bool fpMixedPrecision)
295{
296 const arm_compute::PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
297
298 bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0 && descriptor.m_StrideZ==0);
299 //use specific constructor if global pooling
300 if(isGlobalPooling)
301 {
302 return arm_compute::Pooling3dLayerInfo(poolingType);
303 }
304
305 const arm_compute::Size3D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight, descriptor.m_PoolDepth);
306
307 const arm_compute::Size3D stride(descriptor.m_StrideX,
308 descriptor.m_StrideY,
309 descriptor.m_StrideZ);
310
311 const arm_compute::Padding3D padding(descriptor.m_PadLeft,
312 descriptor.m_PadRight,
313 descriptor.m_PadTop,
314 descriptor.m_PadBottom,
315 descriptor.m_PadFront,
316 descriptor.m_PadBack);
317
318 const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
319
320 const arm_compute::DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
321 descriptor.m_OutputShapeRounding);
322
323 return arm_compute::Pooling3dLayerInfo(poolingType,
324 poolSize,
325 stride,
326 padding,
327 excludePadding,
328 fpMixedPrecision,
329 rounding);
330}
331
telsoa014fcda012018-03-09 14:13:49 +0000332arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
333{
334 const arm_compute::NormType normType =
335 ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
336 return arm_compute::NormalizationLayerInfo(normType,
337 descriptor.m_NormSize,
338 descriptor.m_Alpha,
339 descriptor.m_Beta,
340 descriptor.m_K,
341 false);
342}
343
344arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
345{
346 arm_compute::PermutationVector aclPerm;
347
348 unsigned int start = 0;
surmeh01bceff2f2018-03-29 16:29:27 +0100349 while ((start < perm.GetSize()) && (start == perm[start]))
telsoa014fcda012018-03-09 14:13:49 +0000350 {
351 ++start;
352 }
353
354 for (unsigned int i = start; i < perm.GetSize(); ++i)
355 {
356 aclPerm.set(i - start, perm[i] - start);
357 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +0000358 return aclPerm;
359}
telsoa014fcda012018-03-09 14:13:49 +0000360
Mike Kellyc9ea45a2020-02-28 18:11:58 +0000361arm_compute::PermutationVector BuildArmComputeTransposeVector(const armnn::PermutationVector& perm)
362{
Teresa Charlin6bc85252022-12-06 20:43:06 +0000363 // As ArmNN indexes are left to right and ACL indexes are right to left,
364 // the permutation vector has to be reversed and then translated into ACL axis.
365 // i.e. {1, 0, 2, 3} --> {3, 2, 0, 1} --> {0, 1, 3, 2}
366
367 // Below an example of how the ArmNN and ACL index format work:
368 // ArmNN Format:
369 // Input Shape {1, 10, 20, 30}
370 // Permutation Vector {1, 0, 2, 3}
371 // Output Shape {10, 1, 20, 30}
372 // dim "1" of input goes into index 0 of the output ([ 10, X, X, X])
373 // dim "0" of input goes into index 1 of the output ([ 10, 1, X, X ])
374 // dim "2" of input goes into index 2 of the output ([ 10, 1, 20, X ])
375 // dim "3" of input goes into index 3 of the output ([ 10, 1, 20, 30 ])
376 // ACL Format:
377 // Input Shape {30, 20, 10, 1}
378 // Permutation Vector {0, 1, 3, 2}
379 // Output Shape {30, 20, 1, 10}
380 // dim "0" of input goes into index 0 of the output ([ 30, X, X, X])
381 // dim "1" of input goes into index 1 of the output ([ 30, 20, X, X ])
382 // dim "3" of input goes into index 2 of the output ([ 30, 20, 1, X ])
383 // dim "2" of input goes into index 3 of the output ([ 30, 20, 1, 10 ])
384
Mike Kellyc9ea45a2020-02-28 18:11:58 +0000385 arm_compute::PermutationVector aclPerm;
Teresa Charlin6bc85252022-12-06 20:43:06 +0000386 auto rank = perm.GetSize();
387
388 // Reverse the order. i.e. {1, 0, 2, 3} --> {3, 2, 0, 1}
389 std::vector<unsigned int> reversedPerm;
390 reversedPerm.reserve(rank);
391 for (unsigned int i = rank; i > 0; --i)
Mike Kellyc9ea45a2020-02-28 18:11:58 +0000392 {
Teresa Charlin6bc85252022-12-06 20:43:06 +0000393 reversedPerm.push_back(perm[i-1]);
Mike Kellyc9ea45a2020-02-28 18:11:58 +0000394 }
395
Teresa Charlin6bc85252022-12-06 20:43:06 +0000396 // Translate from Arm NN axis to ACL axis. i.e. {3, 2, 0, 1} --> {0, 1, 3, 2}
397 for (unsigned int i = 0; i < rank; ++i)
Mike Kellyc9ea45a2020-02-28 18:11:58 +0000398 {
Teresa Charlin6bc85252022-12-06 20:43:06 +0000399 auto aclAxis = rank - 1 - reversedPerm[i];
400 aclPerm.set(i, aclAxis);
Mike Kellyc9ea45a2020-02-28 18:11:58 +0000401 }
telsoa014fcda012018-03-09 14:13:49 +0000402 return aclPerm;
403}
404
Sadik Armaganf4464322018-12-20 16:19:12 +0000405arm_compute::Size2D BuildArmComputeSize2D(const unsigned int width, const unsigned int height)
406{
407 return arm_compute::Size2D(width, height);
408}
409
Kevin May263d7092022-11-29 14:34:48 +0000410arm_compute::PixelValue GetPixelValue(const arm_compute::ITensorInfo* tensorInfo, float value)
Mike Kelly0a08ec62019-07-25 08:39:31 +0100411{
Matthew Sloyan2e5d0b22021-10-21 14:05:31 +0100412 switch (tensorInfo->data_type())
Mike Kelly0a08ec62019-07-25 08:39:31 +0100413 {
Mike Kelly0a08ec62019-07-25 08:39:31 +0100414 case arm_compute::DataType::F16:
Kevin May263d7092022-11-29 14:34:48 +0000415 {
416 arm_compute::PixelValue pixelValue = arm_compute::PixelValue(static_cast<Half>(value));
417 if (isinf(pixelValue.get<Half>())) {
418 throw InvalidArgumentException("Under/Overflow converting float value [" + std::to_string(value) +
419 "] to fp16: [" + std::to_string(pixelValue.get<Half>()) + "]");
420 }
421 return pixelValue;
422 }
Mike Kelly0a08ec62019-07-25 08:39:31 +0100423 case arm_compute::DataType::F32:
Kevin May263d7092022-11-29 14:34:48 +0000424 return arm_compute::PixelValue(value);
Mike Kelly130ec602019-11-08 12:08:35 +0000425 case arm_compute::DataType::QASYMM8:
Kevin May263d7092022-11-29 14:34:48 +0000426 return arm_compute::PixelValue(static_cast<uint8_t>(value));
Mike Kelly130ec602019-11-08 12:08:35 +0000427 case arm_compute::DataType::QSYMM16:
Kevin May263d7092022-11-29 14:34:48 +0000428 return arm_compute::PixelValue(static_cast<int16_t>(value));
Tamas Nyirid3065d72021-11-12 11:22:50 +0000429 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100430 case arm_compute::DataType::QASYMM8_SIGNED:
Mike Kelly130ec602019-11-08 12:08:35 +0000431 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
Kevin May263d7092022-11-29 14:34:48 +0000432 return arm_compute::PixelValue(static_cast<int8_t>(value));
Sadik Armagana792a052020-06-23 16:22:23 +0100433 case arm_compute::DataType::S32:
Kevin May263d7092022-11-29 14:34:48 +0000434 return arm_compute::PixelValue(static_cast<int32_t>(value));
Mike Kelly0a08ec62019-07-25 08:39:31 +0100435 default:
436 throw InvalidArgumentException("Unsupported DataType: [" +
Matthew Sloyan2e5d0b22021-10-21 14:05:31 +0100437 std::to_string(static_cast<int>(tensorInfo->data_type())) + "]");
Mike Kelly0a08ec62019-07-25 08:39:31 +0100438 }
439}
440
Cathal Corbett4b19d222022-05-11 20:12:17 +0100441unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,
442 const arm_compute::TensorShape& weightsShape,
443 const arm_compute::TensorShape& inputShape)
444{
445 unsigned int depthMultiplier;
446 if (layout == armnn::DataLayout::NHWC)
447 {
448 depthMultiplier = static_cast<uint32_t>(weightsShape[0]) / static_cast<uint32_t>(inputShape[0]);
449 }
450 else if (layout == armnn::DataLayout::NCHW)
451 {
452 depthMultiplier = static_cast<uint32_t>(weightsShape[2]) / static_cast<uint32_t>(inputShape[2]);
453 }
454 else
455 {
456 throw InvalidArgumentException(fmt::format("Unknown data layout for tensor conversion: {}",
457 GetDataLayoutName(layout)));
458 }
459 return depthMultiplier;
460}
461
Teresa Charlin21bda142024-03-13 16:10:32 +0000462arm_compute::ScatterInfo BuildArmComputeScatterInfo(const ScatterNdDescriptor& descriptor)
463{
464 arm_compute::ScatterFunction scatterFunction;
465 switch(descriptor.m_Function)
466 {
467 case ScatterNdFunction::Update:
468 scatterFunction = arm_compute::ScatterFunction::Update;
469 break;
470 case ScatterNdFunction::Add:
471 scatterFunction = arm_compute::ScatterFunction::Add;
472 break;
473 case ScatterNdFunction::Sub:
474 scatterFunction = arm_compute::ScatterFunction::Sub;
475 break;
476 case ScatterNdFunction::Max:
477 scatterFunction = arm_compute::ScatterFunction::Max;
478 break;
479 case ScatterNdFunction::Min:
480 scatterFunction = arm_compute::ScatterFunction::Min;
481 break;
482 default: throw InvalidArgumentException("Unknown ArmNN::ScatterNd Function: [" +
483 std::to_string(static_cast<int>(descriptor.m_Function)) + "]");
484 }
485
486 return arm_compute::ScatterInfo(scatterFunction, !descriptor.m_InputEnabled);
487}
telsoa014fcda012018-03-09 14:13:49 +0000488} // namespace armcomputetensorutils
489} // namespace armnn