blob: e65c4ad35f2bdffcb94df0e05cb5d8d78f0dac56 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "ArmComputeTensorUtils.hpp"
6#include "ArmComputeUtils.hpp"
7
Francis Murtagh351d13d2018-09-24 15:01:18 +01008#include "armnn/Exceptions.hpp"
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Descriptors.hpp>
10
11namespace armnn
12{
13namespace armcomputetensorutils
14{
15
16arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
17{
18 switch(dataType)
19 {
telsoa01c577f2c2018-08-31 09:22:23 +010020 case armnn::DataType::Float16:
21 return arm_compute::DataType::F16;
telsoa014fcda012018-03-09 14:13:49 +000022 case armnn::DataType::Float32:
telsoa014fcda012018-03-09 14:13:49 +000023 return arm_compute::DataType::F32;
telsoa014fcda012018-03-09 14:13:49 +000024 case armnn::DataType::QuantisedAsymm8:
telsoa014fcda012018-03-09 14:13:49 +000025 return arm_compute::DataType::QASYMM8;
telsoa014fcda012018-03-09 14:13:49 +000026 case armnn::DataType::Signed32:
telsoa014fcda012018-03-09 14:13:49 +000027 return arm_compute::DataType::S32;
telsoa014fcda012018-03-09 14:13:49 +000028 default:
telsoa014fcda012018-03-09 14:13:49 +000029 BOOST_ASSERT_MSG(false, "Unknown data type");
30 return arm_compute::DataType::UNKNOWN;
telsoa014fcda012018-03-09 14:13:49 +000031 }
32}
33
34arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
35{
36 arm_compute::TensorShape shape;
37
telsoa01c577f2c2018-08-31 09:22:23 +010038 // armnn tensors are (batch, channels, height, width).
39 // arm_compute tensors are (width, height, channels, batch).
telsoa014fcda012018-03-09 14:13:49 +000040 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
41 {
telsoa01c577f2c2018-08-31 09:22:23 +010042 // Note that our dimensions are stored in the opposite order to ACL's.
telsoa014fcda012018-03-09 14:13:49 +000043 shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
44
45 // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
telsoa01c577f2c2018-08-31 09:22:23 +010046 // arm_compute tensors expect this.
telsoa014fcda012018-03-09 14:13:49 +000047 }
48
49 // prevent arm_compute issue where tensor is flattened to nothing
50 if (shape.num_dimensions() == 0)
51 {
52 shape.set_num_dimensions(1);
53 }
54
55 return shape;
56}
57
58// Utility function used to build a TensorInfo object, that can be used to initialise
59// ARM Compute Tensor and CLTensor allocators.
60arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
61{
62 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
63 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
64 const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
65 tensorInfo.GetQuantizationOffset());
66
67 return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
68}
69
Francis Murtagh351d13d2018-09-24 15:01:18 +010070arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
71{
72 switch(dataLayout)
73 {
74 case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
75
76 case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
77
78 default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
79 std::to_string(static_cast<int>(dataLayout)) + "]");
80 }
81}
82
83arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
84 armnn::DataLayout dataLayout)
85{
86 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
87 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
88 const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
89 tensorInfo.GetQuantizationOffset());
90
91 arm_compute::TensorInfo clTensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
92 clTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
93
94 return clTensorInfo;
95}
96
telsoa014fcda012018-03-09 14:13:49 +000097arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
98{
99 using arm_compute::PoolingType;
100 using arm_compute::DimensionRoundingType;
101 using arm_compute::PadStrideInfo;
102 using arm_compute::PoolingLayerInfo;
surmeh01bceff2f2018-03-29 16:29:27 +0100103 using arm_compute::Size2D;
telsoa014fcda012018-03-09 14:13:49 +0000104
telsoa01c577f2c2018-08-31 09:22:23 +0100105 // Resolve ARM Compute layer parameters.
telsoa014fcda012018-03-09 14:13:49 +0000106 const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
telsoa01c577f2c2018-08-31 09:22:23 +0100107
108 bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
109 //use specific constructor if global pooling
110 if(isGlobalPooling)
111 {
112 return arm_compute::PoolingLayerInfo(poolingType);
113 }
114
telsoa014fcda012018-03-09 14:13:49 +0000115 const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
116 descriptor.m_OutputShapeRounding);
telsoa014fcda012018-03-09 14:13:49 +0000117 const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
118 descriptor.m_StrideY,
119 descriptor.m_PadLeft,
120 descriptor.m_PadRight,
121 descriptor.m_PadTop,
122 descriptor.m_PadBottom,
123 rounding);
124
125 const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
126
surmeh01bceff2f2018-03-29 16:29:27 +0100127 const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
128
129 return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding);
telsoa014fcda012018-03-09 14:13:49 +0000130}
131
132arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
133{
134 const arm_compute::NormType normType =
135 ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
136 return arm_compute::NormalizationLayerInfo(normType,
137 descriptor.m_NormSize,
138 descriptor.m_Alpha,
139 descriptor.m_Beta,
140 descriptor.m_K,
141 false);
142}
143
144arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
145{
146 arm_compute::PermutationVector aclPerm;
147
148 unsigned int start = 0;
surmeh01bceff2f2018-03-29 16:29:27 +0100149 while ((start < perm.GetSize()) && (start == perm[start]))
telsoa014fcda012018-03-09 14:13:49 +0000150 {
151 ++start;
152 }
153
154 for (unsigned int i = start; i < perm.GetSize(); ++i)
155 {
156 aclPerm.set(i - start, perm[i] - start);
157 }
158
159 return aclPerm;
160}
161
162} // namespace armcomputetensorutils
163} // namespace armnn