blob: 7a9e46ce7d85b870e4e4c942cdce0809689ef9f1 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000010#include <armnn/backends/IBackendInternal.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Colm Donelan0c479742021-12-10 12:43:54 +000017#include <armnn/backends/WorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
David Beck111b5d92018-11-12 14:59:37 +000019#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000020
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
telsoa01c577f2c2018-08-31 09:22:23 +010024namespace
25{
Finn Williams3e54d032020-10-22 16:53:35 +010026using LayerList = std::list<Layer*>;
27using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010028
David Beck29c75de2018-10-23 13:35:58 +010029const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
30{
31 if (!type)
32 {
33 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010034 }
35
Matthew Sloyan81beae32021-07-13 19:46:11 +010036 return TensorInfo(info.GetShape(),
37 type.value(),
38 info.GetQuantizationScale(),
39 info.GetQuantizationOffset(),
40 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010041}
42
David Beck29c75de2018-10-23 13:35:58 +010043} // anonymous namespace
44
Sadik Armagana097d2a2021-11-24 15:47:28 +000045inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
46{
47 if (!weightsType)
48 {
49 return weightsType;
50 }
51
52 switch(weightsType.value())
53 {
54 case armnn::DataType::BFloat16:
55 case armnn::DataType::Float16:
56 case armnn::DataType::Float32:
57 return weightsType;
58 case armnn::DataType::QAsymmS8:
59 case armnn::DataType::QAsymmU8:
60 case armnn::DataType::QSymmS8:
61 case armnn::DataType::QSymmS16:
62 return armnn::DataType::Signed32;
63 default:
64 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
65 }
66 return armnn::EmptyOptional();
67}
68
69
Sadik Armagan045f6be2020-09-10 13:37:32 +010070bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
71 const IConnectableLayer& connectableLayer,
72 Optional<DataType> dataType,
73 std::string& outReasonIfUnsupported,
74 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000075{
David Beck33f0ae02018-10-18 15:13:56 +010076 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000077 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010078 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010079
David Beck111b5d92018-11-12 14:59:37 +000080 auto const& backendRegistry = BackendRegistryInstance();
81 if (!backendRegistry.IsBackendRegistered(backendId))
82 {
83 std::stringstream ss;
84 ss << connectableLayer.GetName() << " is not supported on " << backendId
85 << " because this backend is not registered.";
86
87 outReasonIfUnsupported = ss.str();
88 return false;
89 }
90
91 auto backendFactory = backendRegistry.GetFactory(backendId);
92 auto backendObject = backendFactory();
Mike Kelly3ec30772023-03-08 13:47:17 +000093 auto layerSupport = backendObject->GetLayerSupport(modelOptions);
94 auto layerSupportObject = LayerSupportHandle(layerSupport, backendId);
David Beck33f0ae02018-10-18 15:13:56 +010095
telsoa014fcda012018-03-09 14:13:49 +000096 switch(layer.GetType())
97 {
98 case LayerType::Activation:
99 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100100 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100101 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100102 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000103 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100104 OverrideDataType(input, dataType),
105 OverrideDataType(output, dataType),
106 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100107 reason);
telsoa014fcda012018-03-09 14:13:49 +0000108 break;
109 }
110 case LayerType::Addition:
111 {
Mike Kelly3ec30772023-03-08 13:47:17 +0000112 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100113 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
114 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000115 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000116 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100117 OverrideDataType(input0, dataType),
118 OverrideDataType(input1, dataType),
119 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100120 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +0000121 ARMNN_NO_DEPRECATE_WARN_END
telsoa014fcda012018-03-09 14:13:49 +0000122 break;
123 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100124 case LayerType::ArgMinMax:
125 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100126 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100127 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
128
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100129 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nikhil Rajee391d52019-09-05 17:50:44 +0100130 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000131 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100132 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000133 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100134 descriptor,
135 reason);
136 break;
137 }
Samuel Yap6b478092022-07-06 15:36:03 +0100138 case LayerType::BatchMatMul:
139 {
140 auto cLayer = PolymorphicDowncast<const BatchMatMulLayer*>(&layer);
141 const BatchMatMulDescriptor& descriptor = cLayer->GetParameters();
142
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100143 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
144 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Samuel Yap6b478092022-07-06 15:36:03 +0100145 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
146 result = layerSupportObject.IsBatchMatMulSupported(
147 OverrideDataType(input0, dataType),
148 OverrideDataType(input1, dataType),
149 OverrideDataType(output, dataType),
150 descriptor,
151 reason);
152 break;
153 }
telsoa014fcda012018-03-09 14:13:49 +0000154 case LayerType::BatchNormalization:
155 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100156 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100157 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
159 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
160 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
161 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
162 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000163 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100164 OverrideDataType(input, dataType),
165 OverrideDataType(output, dataType),
166 OverrideDataType(mean, dataType),
167 OverrideDataType(var, dataType),
168 OverrideDataType(beta, dataType),
169 OverrideDataType(gamma, dataType),
170 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100171 reason);
telsoa014fcda012018-03-09 14:13:49 +0000172 break;
173 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000174 case LayerType::BatchToSpaceNd:
175 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100176 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000177 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100178 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000179
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000180 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
181 OverrideDataType(output, dataType),
182 cLayer->GetParameters(),
183 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000184 break;
185 }
mathad01b392e982021-04-07 12:07:30 +0100186 case LayerType::Cast:
187 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100188 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
mathad01b392e982021-04-07 12:07:30 +0100189 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
190
191 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
192 OverrideDataType(output, dataType),
193 reason);
194 break;
195 }
Simon Obute51f67772021-09-03 15:50:13 +0100196 case LayerType::ChannelShuffle:
197 {
198 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
199
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100200 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
201 const TensorInfo& output = layer.GetInputSlot(0).GetTensorInfo();
Simon Obute51f67772021-09-03 15:50:13 +0100202
203 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
204
205 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
206 OverrideDataType(output, dataType),
207 descriptor,
208 reason);
209 break;
210 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100211 case LayerType::Comparison:
212 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100213 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100214
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100215 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
216 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100217 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
218
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000219 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
220 OverrideDataType(input1, dataType),
221 OverrideDataType(output, DataType::Boolean),
222 cLayer->GetParameters(),
223 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100224 break;
225 }
telsoa014fcda012018-03-09 14:13:49 +0000226 case LayerType::Constant:
227 {
228 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000229 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100230 break;
231 }
232 case LayerType::ConvertFp16ToFp32:
233 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100234 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100235 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000236 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100237 break;
238 }
239 case LayerType::ConvertFp32ToFp16:
240 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100241 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100242 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000243 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000244 break;
245 }
246 case LayerType::Convolution2d:
247 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100248 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100249
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100250 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
arovir01a6824102018-08-28 17:40:45 +0100251 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100252 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100253 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
254 "Convolution2dLayer: Weights should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100255 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100256 dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100257
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100258 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100259
arovir01a6824102018-08-28 17:40:45 +0100260 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100261 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100262 if (descriptor.m_BiasEnabled)
263 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100264 ARMNN_ASSERT_MSG(layer.GetInputSlot(2).GetConnection(),
265 "Convolution2dLayer: Bias should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100266 biases = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100267 GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100268 }
269
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000270 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100271 input,
272 output,
273 descriptor,
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100274 weights,
arovir01a6824102018-08-28 17:40:45 +0100275 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100276 reason);
telsoa014fcda012018-03-09 14:13:49 +0000277 break;
278 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100279 case LayerType::Convolution3d:
280 {
281 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
282
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100283 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100284 dataType);
285 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100286
287 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
288 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100289 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100290 dataType);
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100291
292 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
293
294 // Construct optional biases object based on the value of m_BiasEnabled
295 Optional<TensorInfo> biases;
296 if (descriptor.m_BiasEnabled)
297 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100298 biases = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100299 GetBiasTypeFromWeightsType(dataType));
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100300 }
301
302 result = layerSupportObject.IsConvolution3dSupported(
303 input,
304 output,
305 descriptor,
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100306 weights,
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100307 biases,
308 reason);
309 break;
310 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000311 case LayerType::Debug:
312 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100313 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000314 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
315
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000316 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000317 OverrideDataType(output, dataType),
318 reason);
319 break;
320 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100321 case LayerType::DepthToSpace:
322 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100323 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100324
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100325 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100326 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
327
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000328 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100329 OverrideDataType(output, dataType),
330 cLayer->GetParameters(),
331 reason);
332 break;
333 }
telsoa014fcda012018-03-09 14:13:49 +0000334 case LayerType::DepthwiseConvolution2d:
335 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100336 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100337 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100338 dataType);
339 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100340 const TensorInfo& weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100341 dataType);
342
343 ARMNN_ASSERT(cLayer->GetInputSlot(1).GetConnection() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100344
telsoa01c577f2c2018-08-31 09:22:23 +0100345 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100346
347 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100348 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100349 if (descriptor.m_BiasEnabled)
350 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100351 biases = OverrideDataType(cLayer->GetInputSlot(2).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100352 GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100353 }
telsoa01c577f2c2018-08-31 09:22:23 +0100354
Cathal Corbett06902652022-04-14 17:55:11 +0100355 result = layerSupportObject.IsDepthwiseConvolutionSupported(input,
356 output,
357 descriptor,
358 weights,
359 biases,
360 reason);
telsoa014fcda012018-03-09 14:13:49 +0000361 break;
362 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000363 case LayerType::Dequantize:
364 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100365 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000366 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
367
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000368 result = layerSupportObject.IsDequantizeSupported(input,
369 OverrideDataType(output, dataType),
370 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000371 break;
372 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000373 case LayerType::DetectionPostProcess:
374 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100375 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100376 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetTensorInfo();
377 const TensorInfo& scores = layer.GetInputSlot(1).GetTensorInfo();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000378 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
379
380 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
381 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
382 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
383 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
384
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000385 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000386 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
387 scores,
388 anchors,
389 detectionBoxes,
390 detectionClasses,
391 detectionScores,
392 numDetections,
393 descriptor,
394 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000395 break;
396 }
Mike Kelly3ec30772023-03-08 13:47:17 +0000397 case LayerType::ElementwiseBinary:
398 {
399 auto cLayer = PolymorphicDowncast<const ElementwiseBinaryLayer*>(&layer);
400
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100401 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
402 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Mike Kelly3ec30772023-03-08 13:47:17 +0000403 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
404 std::vector<TensorInfo> infos = { OverrideDataType(input0, dataType),
405 OverrideDataType(input1, dataType),
406 OverrideDataType(output, dataType) };
407 result = layerSupport->IsLayerSupported(LayerType::ElementwiseBinary,
408 infos,
409 cLayer->GetParameters(),
410 EmptyOptional(),
411 EmptyOptional(),
412 reason);
413 break;
414 }
josh minor4a3c6102020-01-06 16:40:46 -0600415 case LayerType::ElementwiseUnary:
416 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100417 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600418
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100419 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
josh minor4a3c6102020-01-06 16:40:46 -0600420 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
421
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000422 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
423 OverrideDataType(output, dataType),
424 cLayer->GetParameters(),
425 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600426 break;
427 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100428 case LayerType::Fill:
429 {
430 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100431 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Ryan OSheaec6c6802020-06-05 17:17:06 +0100432 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
433 const FillDescriptor& descriptor = cLayer->GetParameters();
434
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000435 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100436 OverrideDataType(input, dataType),
437 OverrideDataType(output, dataType),
438 descriptor,
439 reason);
440 break;
441 }
telsoa014fcda012018-03-09 14:13:49 +0000442 case LayerType::FakeQuantization:
443 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100444 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100445 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000446 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
447 cLayer->GetParameters(),
448 reason);
telsoa014fcda012018-03-09 14:13:49 +0000449 break;
450 }
451 case LayerType::Floor:
452 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100453 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000454 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000455 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
456 OverrideDataType(output, dataType),
457 reason);
telsoa014fcda012018-03-09 14:13:49 +0000458 break;
459 }
460 case LayerType::FullyConnected:
461 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100462 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100463 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100464 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000465
466 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
467 TensorInfo weightsInfo;
468 const TensorInfo* weightsInfoPtr = nullptr;
469
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100470 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000471 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100472
473 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000474 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000475 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100476 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
477 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
478 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
479
telsoa01c577f2c2018-08-31 09:22:23 +0100480 if (descriptor.m_BiasEnabled)
481 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100482 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(), dataType);
Matthew Sloyan81beae32021-07-13 19:46:11 +0100483 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100484 }
485 else
486 {
487 // If biases are not enabled pass a dummy tensorinfo for the validation
488 switch(input.GetDataType())
489 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000490 case DataType::BFloat16:
491 {
492 biasInfoPtr = &dummyBFloat16Bias;
493 break;
494 }
telsoa01c577f2c2018-08-31 09:22:23 +0100495 case DataType::Float16:
496 {
497 biasInfoPtr = &dummyFloat16Bias;
498 break;
499 }
500 case DataType::Float32:
501 {
502 biasInfoPtr = &dummyFloat32Bias;
503 break;
504 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000505 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000506 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000507 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000508 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100509 {
510 biasInfoPtr = &dummyQA8Bias;
511 break;
512 }
513 default:
514 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100515 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100516 }
517 }
518 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000519 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100520 OverrideDataType(input, dataType),
521 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000522 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100523 *biasInfoPtr,
524 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100525 reason);
telsoa014fcda012018-03-09 14:13:49 +0000526 break;
527 }
narpra01b89b05f2019-01-16 09:53:09 +0000528 case LayerType::Gather:
529 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100530 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
531 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
narpra01b89b05f2019-01-16 09:53:09 +0000532 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100533 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
534 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000535 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
536 input1,
537 OverrideDataType(output, dataType),
538 descriptor,
539 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000540 break;
541 }
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100542 case LayerType::GatherNd:
543 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100544 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
545 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100546 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
547 result = layerSupportObject.IsGatherNdSupported(OverrideDataType(input0, dataType),
548 input1,
549 OverrideDataType(output, dataType),
550 reason);
551 break;
552 }
telsoa014fcda012018-03-09 14:13:49 +0000553 case LayerType::Input:
554 {
555 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000556 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000557 break;
558 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100559 case LayerType::InstanceNormalization:
560 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100561 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100562 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
563
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100564 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Kevin Mayce5045a2019-10-02 14:07:47 +0100565 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
566
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000567 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100568 OverrideDataType(input, dataType),
569 OverrideDataType(output, dataType),
570 descriptor,
571 reason);
572 break;
573 }
        // L2Normalization: unary op with a descriptor; query backend support with the
        // requested data-type override applied to input and output.
        case LayerType::L2Normalization:
        {
            auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsL2NormalizationSupported(
                    OverrideDataType(input, dataType),
                    OverrideDataType(output, dataType),
                    descriptor,
                    reason);
            break;
        }
        // LogicalBinary: two inputs, one output. Note the tensor infos are passed through
        // WITHOUT OverrideDataType — logical ops operate on Boolean tensors, so the
        // queried data type is not applied here.
        case LayerType::LogicalBinary:
        {
            auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);

            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsLogicalBinarySupported(input0,
                                                                 input1,
                                                                 output,
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        // LogSoftmax: unary op with a descriptor, data-type override applied.
        case LayerType::LogSoftmax:
        {
            auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        // Lstm: gather all input/output/state tensor infos plus the (partly optional)
        // weight/bias parameter infos into an LstmInputParamsInfo, then query support.
        // Note: OverrideDataType returns a TensorInfo by value; binding it to a const
        // reference extends the temporary's lifetime for the duration of this case.
        case LayerType::Lstm:
        {
            auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters (always present).
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            LstmInputParamsInfo paramsInfo;

            paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
            paramsInfo.m_InputToCellWeights = &inputToCellWeights;
            paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
            paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
            paramsInfo.m_ForgetGateBias = &forgetGateBias;
            paramsInfo.m_CellBias = &cellBias;
            paramsInfo.m_OutputGateBias = &outputGateBias;


            // Optional parameters. These locals must outlive the IsLstmSupported call
            // below, since paramsInfo only stores pointers to them.
            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;
            TensorInfo optInputLayerNormWeights;
            TensorInfo optForgetLayerNormWeights;
            TensorInfo optCellLayerNormWeights;
            TensorInfo optOutputLayerNormWeights;

            // CIFG (Coupled Input-Forget Gate) disabled => the input gate parameters exist.
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                paramsInfo.m_InputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
                // Projection bias is optional even when projection is enabled.
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    paramsInfo.m_ProjectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights only exist when the input gate exists (no CIFG).
                if(!descriptor.m_CifgEnabled)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
                                         dataType);
                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
                }
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input-layer normalisation weights only exist when the input gate exists (no CIFG).
                if (!descriptor.m_CifgEnabled)
                {
                    optInputLayerNormWeights = OverrideDataType(
                            cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
                    paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
                }

                optForgetLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;

                optCellLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;

                optOutputLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
            }

            result = layerSupportObject.IsLstmSupported(
                    input,
                    outputStateIn,
                    cellStateIn,
                    scratchBuffer,
                    outputStateOut,
                    cellStateOut,
                    output,
                    descriptor,
                    paramsInfo,
                    reason);
            break;
        }
        // Maximum: two inputs, one output. The IsMaximumSupported query is deprecated
        // (hence the warning-suppression macros around the call).
        case LayerType::Maximum:
        {
            ARMNN_NO_DEPRECATE_WARN_BEGIN
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
                                                           OverrideDataType(input1, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            ARMNN_NO_DEPRECATE_WARN_END
            break;
        }
        // MemCopy: simple unary pass-through query.
        case LayerType::MemCopy:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        // MemImport: simple unary pass-through query.
        case LayerType::MemImport:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             reason);
            break;
        }
        // Merge: two inputs, one output; no descriptor.
        case LayerType::Merge:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
                                                         OverrideDataType(input1, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
            break;
        }
        // Concat: variadic input count, so the input TensorInfos are collected via
        // transform iterators into a vector, then a parallel vector of pointers is
        // built because IsConcatSupported takes std::vector<const TensorInfo*>.
        case LayerType::Concat:
        {
            auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);

            // Get vector of all inputs.
            // NOTE(review): this reads the connected OutputSlot's info directly rather
            // than InputSlot::GetTensorInfo() used elsewhere in this switch — confirm
            // whether overridden per-slot tensor infos should apply here too.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };

            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };

            // Pointers reference elements of 'inputs' above; 'inputs' must stay alive
            // until the support query returns.
            auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);


            break;
        }
        // Multiplication: deprecated binary-op query (warning suppression as for Maximum).
        case LayerType::Multiplication:
        {
            ARMNN_NO_DEPRECATE_WARN_BEGIN
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsMultiplicationSupported(
                    OverrideDataType(input0, dataType),
                    OverrideDataType(input1, dataType),
                    OverrideDataType(output, dataType),
                    reason);
            ARMNN_NO_DEPRECATE_WARN_END
            break;
        }
        // Normalization: unary op with descriptor.
        case LayerType::Normalization:
        {
            auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        // Output: the network's output layer only has an input slot; query on that.
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        // Permute: unary op with descriptor.
        case LayerType::Permute:
        {
            auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        // Pad: unary op with descriptor.
        case LayerType::Pad:
        {
            auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPadSupported(
                    OverrideDataType(input, dataType),
                    OverrideDataType(output, dataType),
                    cLayer->GetParameters(),
                    reason);
            break;
        }
        // Pooling2d: unary op with descriptor.
        case LayerType::Pooling2d:
        {
            auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        // Pooling3d: unary op with descriptor (3D variant of the above).
        case LayerType::Pooling3d:
        {
            auto cLayer = PolymorphicDowncast<const Pooling3dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPooling3dSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        // PreCompiled: only the input info and the layer parameters are queried;
        // no output info is passed to the support check.
        case LayerType::PreCompiled:
        {
            auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
                                                               cLayer->GetParameters(),
                                                               reason);
            break;
        }
        // Quantize: tensor infos are passed WITHOUT OverrideDataType — the layer's
        // own input/output types define the quantization, not the queried type.
        case LayerType::Quantize:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsQuantizeSupported(input, output, reason);
            break;
        }
        // QLstm: like Lstm, but parameter infos are taken directly from the layer's
        // ConstTensorHandles (no OverrideDataType applied anywhere in this case).
        // paramsInfo stores raw pointers into cLayer's parameter tensors, which
        // outlive this query.
        case LayerType::QLstm:
        {
            auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
            const QLstmDescriptor& descriptor = cLayer->GetParameters();

            // Inputs
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetTensorInfo();

            // Outputs
            const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();

            // Lstm parameters
            LstmInputParamsInfo paramsInfo;

            // Basic parameters (must always be present).
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
            paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToForgetWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();

            // Input-gate parameters only exist when CIFG is disabled.
            if(!descriptor.m_CifgEnabled)
            {
                paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
                paramsInfo.m_RecurrentToInputWeights =
                        &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
                paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
            }

            if(descriptor.m_ProjectionEnabled)
            {
                paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();

                // Projection bias is optional even if projection is enabled
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights only exist when the input gate exists (no CIFG).
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_CellToInputWeights =
                            &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
                }

                paramsInfo.m_CellToForgetWeights =
                        &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
                paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input-layer normalisation weights only exist when the input gate exists (no CIFG).
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_InputLayerNormWeights =
                            &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
                }

                paramsInfo.m_ForgetLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
                paramsInfo.m_CellLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
                paramsInfo.m_OutputLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
            }

            result = layerSupportObject.IsQLstmSupported(input,
                                                         previousOutputIn,
                                                         previousCellStateIn,
                                                         outputStateOut,
                                                         cellStateOut,
                                                         output,
                                                         descriptor,
                                                         paramsInfo,
                                                         reason);
            break;
        }
James Conroyee18dc82019-07-17 11:27:46 +01001025 case LayerType::QuantizedLstm:
1026 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001027 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +01001028
1029 // Inputs
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001030 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1031 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetTensorInfo();
1032 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001033
1034 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001035 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
1036 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001037
1038 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +01001039 QuantizedLstmInputParamsInfo paramsInfo;
1040
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001041 paramsInfo.m_InputToInputWeights =
1042 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
1043 paramsInfo.m_InputToForgetWeights =
1044 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
1045 paramsInfo.m_InputToCellWeights =
1046 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
1047 paramsInfo.m_InputToOutputWeights =
1048 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001049
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001050 paramsInfo.m_RecurrentToInputWeights =
1051 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
1052 paramsInfo.m_RecurrentToForgetWeights =
1053 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
1054 paramsInfo.m_RecurrentToCellWeights =
1055 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
1056 paramsInfo.m_RecurrentToOutputWeights =
1057 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001058
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001059 paramsInfo.m_InputGateBias =
1060 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
1061 paramsInfo.m_ForgetGateBias =
1062 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
1063 paramsInfo.m_CellBias =
1064 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
1065 paramsInfo.m_OutputGateBias =
1066 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +01001067
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001068 result = layerSupportObject.IsQuantizedLstmSupported(input,
1069 previousCellStateIn,
1070 previousOutputIn,
1071 cellStateOut,
1072 output,
1073 paramsInfo,
1074 reason);
James Conroyee18dc82019-07-17 11:27:46 +01001075 break;
1076 }
        // Division: deprecated binary-op query (warning suppression as for Maximum).
        case LayerType::Division:
        {
            ARMNN_NO_DEPRECATE_WARN_BEGIN
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsDivisionSupported(
                    OverrideDataType(input0, dataType),
                    OverrideDataType(input1, dataType),
                    OverrideDataType(output, dataType),
                    reason);
            ARMNN_NO_DEPRECATE_WARN_END
            break;
        }
        // Rank: simple unary query.
        case LayerType::Rank:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
                                                        OverrideDataType(output, dataType),
                                                        reason);
            break;
        }
        // Reshape: unary op with descriptor.
        case LayerType::Reshape:
        {
            auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        // Resize: unary op with descriptor.
        case LayerType::Resize:
        {
            auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          cLayer->GetParameters(),
                                                          reason);
            break;
        }
Tianle Cheng988354d2023-06-28 13:20:47 +01001122 case LayerType::ReverseV2:
1123 {
Tracy Narinebb8d7592023-07-13 16:50:54 +01001124 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1125 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
Tianle Cheng988354d2023-06-28 13:20:47 +01001126 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Tracy Narinebb8d7592023-07-13 16:50:54 +01001127 result = layerSupportObject.IsReverseV2Supported(OverrideDataType(input0, dataType),
1128 OverrideDataType(input1, armnn::DataType::Signed32),
Tianle Cheng988354d2023-06-28 13:20:47 +01001129 OverrideDataType(output, dataType),
Tianle Cheng988354d2023-06-28 13:20:47 +01001130 reason);
1131 break;
1132 }
        // Shape: simple unary query.
        case LayerType::Shape:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
            break;
        }
        // Slice: unary op with descriptor.
        case LayerType::Slice:
        {
            auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
                                                         OverrideDataType(output, dataType),
                                                         cLayer->GetParameters(),
                                                         reason);
            break;
        }
        // Softmax: unary op with descriptor.
        case LayerType::Softmax:
        {
            auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        // SpaceToBatchNd: unary op with descriptor.
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        // SpaceToDepth: unary op with descriptor.
        case LayerType::SpaceToDepth:
        {
            auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
                                                                OverrideDataType(output, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
telsoa014fcda012018-03-09 14:13:49 +00001191 case LayerType::Splitter:
1192 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001193 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001194 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001195
1196 // Get vector of all outputs.
1197 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1198 {
1199 return OverrideDataType(slot.GetTensorInfo(), dataType);
1200 };
Finn Williams3e54d032020-10-22 16:53:35 +01001201 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1202 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001203 std::vector<TensorInfo> outputs(beginI, endI);
1204
1205 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1206
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001207 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1208 outputPtrs,
1209 cLayer->GetParameters(),
1210 reason);
telsoa014fcda012018-03-09 14:13:49 +00001211 break;
1212 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001213 case LayerType::Stack:
1214 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001215 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001216
1217 // Get vector of all inputs.
1218 auto getTensorInfo = [&dataType](const InputSlot& slot)
1219 {
1220 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1221 };
Finn Williams3e54d032020-10-22 16:53:35 +01001222 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1223 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001224 std::vector<TensorInfo> inputs(beginI, endI);
1225
1226 auto getTensorInfoPtr = [](const TensorInfo& info)
1227 {
1228 return &info;
1229 };
Finn Williams3e54d032020-10-22 16:53:35 +01001230 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1231 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001232 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1233
1234 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1235
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001236 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001237
1238 break;
1239 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001240 case LayerType::StandIn:
1241 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001242 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001243
1244 // Get vector of all inputs.
1245 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1246 {
1247 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1248 };
1249 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1250 {
1251 return OverrideDataType(slot.GetTensorInfo(), dataType);
1252 };
Finn Williams3e54d032020-10-22 16:53:35 +01001253 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1254 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001255 std::vector<TensorInfo> inputs(beginI, endI);
1256
Finn Williams3e54d032020-10-22 16:53:35 +01001257 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1258 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001259 std::vector<TensorInfo> outputs(beginO, endO);
1260
1261
1262 auto getTensorInfoPtr = [](const TensorInfo& info)
1263 {
1264 return &info;
1265 };
Finn Williams3e54d032020-10-22 16:53:35 +01001266 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1267 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001268 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1269
Finn Williams3e54d032020-10-22 16:53:35 +01001270 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1271 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001272 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1273
1274
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001275 result = layerSupportObject.IsStandInSupported(inputPtrs,
1276 outputPtrs,
1277 cLayer->GetParameters(),
1278 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001279 break;
1280 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001281 case LayerType::StridedSlice:
1282 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001283 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001284 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Conor Kennedy430b5d82018-11-14 15:28:28 +00001285 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001286 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1287 OverrideDataType(output, dataType),
1288 cLayer->GetParameters(),
1289 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001290 break;
1291 }
David Beckc2044fe2018-09-05 15:00:38 +01001292 case LayerType::Subtraction:
1293 {
Mike Kelly3ec30772023-03-08 13:47:17 +00001294 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001295 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1296 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
David Beckc2044fe2018-09-05 15:00:38 +01001297 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001298 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001299 OverrideDataType(input0, dataType),
1300 OverrideDataType(input1, dataType),
1301 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001302 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +00001303 ARMNN_NO_DEPRECATE_WARN_END
David Beckc2044fe2018-09-05 15:00:38 +01001304 break;
1305 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001306 case LayerType::Switch:
1307 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001308 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1309 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Sadik Armaganeff363d2019-04-05 15:25:46 +01001310 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1311 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001312 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1313 OverrideDataType(input1, dataType),
1314 OverrideDataType(output0, dataType),
1315 OverrideDataType(output1, dataType),
1316 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001317 break;
1318 }
narpra0132b90462018-09-13 11:07:48 +01001319 case LayerType::Mean:
1320 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001321 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001322 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
narpra0132b90462018-09-13 11:07:48 +01001323 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001324 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001325 OverrideDataType(input, dataType),
1326 OverrideDataType(output, dataType),
1327 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001328 reason);
narpra0132b90462018-09-13 11:07:48 +01001329 break;
1330 }
kevmay0190539692018-11-29 08:40:19 +00001331 case LayerType::Minimum:
1332 {
Mike Kelly3ec30772023-03-08 13:47:17 +00001333 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001334 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1335 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
kevmay0190539692018-11-29 08:40:19 +00001336 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001337 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1338 OverrideDataType(input1, dataType),
1339 OverrideDataType(output, dataType),
1340 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +00001341 ARMNN_NO_DEPRECATE_WARN_END
kevmay0190539692018-11-29 08:40:19 +00001342 break;
1343 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001344 case LayerType::Prelu:
1345 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001346 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1347 const TensorInfo& alpha = layer.GetInputSlot(1).GetTensorInfo();
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001348 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001349 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1350 OverrideDataType(alpha, dataType),
1351 OverrideDataType(output, dataType),
1352 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001353 break;
1354 }
Teresa Charlin79a06a52023-07-13 17:16:45 +01001355 case LayerType::Tile:
1356 {
1357 auto cLayer = PolymorphicDowncast<const TileLayer*>(&layer);
1358 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1359 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1360
1361 result = layerSupportObject.IsTileSupported(OverrideDataType(input, dataType),
1362 OverrideDataType(output, dataType),
1363 cLayer->GetParameters(),
1364 reason);
1365
1366 break;
1367 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001368 case LayerType::Transpose:
1369 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001370 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001371 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001372 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001373 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1374 OverrideDataType(output, dataType),
1375 cLayer->GetParameters(),
1376 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001377 break;
1378 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001379 case LayerType::TransposeConvolution2d:
1380 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001381 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001382
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001383 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001384 dataType);
1385 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1386
1387 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1388
1389 Optional<TensorInfo> biases;
1390 if (descriptor.m_BiasEnabled)
1391 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001392 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001393 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1394 GetBiasTypeFromWeightsType(dataType));
1395 }
1396
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001397 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001398 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1399
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001400 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1401 output,
1402 descriptor,
1403 weights,
1404 biases,
1405 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001406
1407 break;
1408 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001409 case LayerType::Reduce:
1410 {
1411 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001412 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001413 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1414
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001415 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1416 OverrideDataType(output, dataType),
1417 cLayer->GetParameters(),
1418 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001419 break;
1420 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001421 case LayerType::UnidirectionalSequenceLstm:
1422 {
1423 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1424 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1425
1426 // All inputs.
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001427 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001428 dataType);
Mike Kelly4cc341c2023-07-07 15:43:06 +01001429 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001430 dataType);
Mike Kelly4cc341c2023-07-07 15:43:06 +01001431 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001432 dataType);
1433 // Outputs
Mike Kelly12994962022-04-21 11:57:09 +01001434 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1435 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
1436 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001437
1438 // Basic parameters
1439 const TensorInfo& inputToForgetWeights
1440 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1441 const TensorInfo& inputToCellWeights
1442 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1443 const TensorInfo& inputToOutputWeights
1444 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1445 const TensorInfo& recurrentToForgetWeights
1446 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1447 const TensorInfo& recurrentToCellWeights
1448 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1449 const TensorInfo& recurrentToOutputWeights
1450 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1451 const TensorInfo& forgetGateBias
1452 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1453 const TensorInfo& cellBias
1454 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1455 const TensorInfo& outputGateBias
1456 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1457
1458 LstmInputParamsInfo paramsInfo;
1459
1460 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1461 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1462 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1463 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1464 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1465 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1466 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1467 paramsInfo.m_CellBias = &cellBias;
1468 paramsInfo.m_OutputGateBias = &outputGateBias;
1469
1470 // Optional parameters
1471 TensorInfo optInputToInputWeights;
1472 TensorInfo optRecurrentToInputWeights;
1473 TensorInfo optCellToInputWeights;
1474 TensorInfo optInputGateBias;
1475 TensorInfo optProjectionWeights;
1476 TensorInfo optProjectionBias;
1477 TensorInfo optCellToForgetWeights;
1478 TensorInfo optCellToOutputWeights;
1479 TensorInfo optInputLayerNormWeights;
1480 TensorInfo optForgetLayerNormWeights;
1481 TensorInfo optCellLayerNormWeights;
1482 TensorInfo optOutputLayerNormWeights;
1483
1484 if(!descriptor.m_CifgEnabled)
1485 {
1486 optInputToInputWeights =
1487 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1488 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1489
1490 optRecurrentToInputWeights =
1491 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1492 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1493 optInputGateBias =
1494 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1495 paramsInfo.m_InputGateBias = &optInputGateBias;
1496 }
1497
1498 if(descriptor.m_ProjectionEnabled)
1499 {
1500 optProjectionWeights =
1501 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1502 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1503 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1504 {
1505 optProjectionBias =
1506 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1507 paramsInfo.m_ProjectionBias = &optProjectionBias;
1508 }
1509 }
1510
1511 if(descriptor.m_PeepholeEnabled)
1512 {
1513 if(!descriptor.m_CifgEnabled)
1514 {
1515 optCellToInputWeights =
1516 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1517 dataType);
1518 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1519 }
1520 optCellToForgetWeights =
1521 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1522 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1523 optCellToOutputWeights =
1524 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1525 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1526 }
1527
1528 if(descriptor.m_LayerNormEnabled)
1529 {
1530 if (!descriptor.m_CifgEnabled)
1531 {
1532 optInputLayerNormWeights = OverrideDataType(
1533 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1534 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1535 }
1536
1537 optForgetLayerNormWeights = OverrideDataType(
1538 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1539 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1540
1541 optCellLayerNormWeights = OverrideDataType(
1542 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1543 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1544
1545 optOutputLayerNormWeights = OverrideDataType(
1546 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1547 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1548 }
1549
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001550 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1551 outputStateIn,
1552 cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01001553 outputStateOut,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001554 cellStateOut,
Mike Kelly12994962022-04-21 11:57:09 +01001555 output,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001556 descriptor,
1557 paramsInfo,
1558 reason);
1559 break;
1560 }
telsoa014fcda012018-03-09 14:13:49 +00001561 default:
1562 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001563 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001564 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001565 result = false;
1566 break;
1567 }
1568 }
telsoa014fcda012018-03-09 14:13:49 +00001569 return result;
1570}
1571
Sadik Armagan045f6be2020-09-10 13:37:32 +01001572bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1573 const IConnectableLayer& connectableLayer,
1574 Optional<DataType> dataType,
1575 std::string& outReasonIfUnsupported)
1576{
1577 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1578}
1579
David Beckdcb751f2018-10-03 11:42:42 +01001580bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001581 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001582 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001583{
Jan Eilersbb446e52020-04-02 13:56:54 +01001584 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001585 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1586}
1587
Sadik Armagan045f6be2020-09-10 13:37:32 +01001588bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1589 Optional<DataType> dataType,
1590 std::string& outReasonIfUnsupported,
1591 const ModelOptions& modelOptions)
1592{
1593 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1594 return IsLayerConfigurationSupported(layer->GetBackendId(),
1595 connectableLayer,
1596 dataType,
1597 outReasonIfUnsupported,
1598 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001599}
1600
Sadik Armagan04a72972020-09-14 15:44:18 +01001601bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1602 const IConnectableLayer& connectableLayer,
1603 Optional<DataType> dataType,
1604 std::string& outReasonIfUnsupported,
1605 const ModelOptions& modelOptions)
1606{
1607 return IsLayerConfigurationSupported(backendId,
1608 connectableLayer,
1609 dataType,
1610 outReasonIfUnsupported,
1611 modelOptions);
1612}
Cian McGriskin7894ef92023-08-01 14:04:09 +01001613
1614/// Backends should implement their own CreateWorkload function with a switch statement.
1615/// The case for the switch should be the LayerType and based on that they will call their
1616/// specific workload creation functionality.
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001617std::unique_ptr<IWorkload> IWorkloadFactory::CreateWorkload(LayerType type,
1618 const QueueDescriptor& descriptor,
1619 const WorkloadInfo& info) const
1620{
Cian McGriskin7894ef92023-08-01 14:04:09 +01001621 IgnoreUnused(descriptor);
1622 IgnoreUnused(info);
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001623 switch(type)
1624 {
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001625 default:
Cian McGriskin7894ef92023-08-01 14:04:09 +01001626 return std::unique_ptr<IWorkload>();
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001627 }
1628}

} // namespace armnn