blob: 2538211a41e69f7779f116fb569789fa072a3192 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000010#include <armnn/backends/IBackendInternal.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Colm Donelan0c479742021-12-10 12:43:54 +000017#include <armnn/backends/WorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
David Beck111b5d92018-11-12 14:59:37 +000019#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000020
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
telsoa01c577f2c2018-08-31 09:22:23 +010024namespace
25{
Finn Williams3e54d032020-10-22 16:53:35 +010026using LayerList = std::list<Layer*>;
27using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010028
David Beck29c75de2018-10-23 13:35:58 +010029const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
30{
31 if (!type)
32 {
33 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010034 }
35
Matthew Sloyan81beae32021-07-13 19:46:11 +010036 return TensorInfo(info.GetShape(),
37 type.value(),
38 info.GetQuantizationScale(),
39 info.GetQuantizationOffset(),
40 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010041}
42
David Beck29c75de2018-10-23 13:35:58 +010043} // anonymous namespace
44
Sadik Armagana097d2a2021-11-24 15:47:28 +000045inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
46{
47 if (!weightsType)
48 {
49 return weightsType;
50 }
51
52 switch(weightsType.value())
53 {
54 case armnn::DataType::BFloat16:
55 case armnn::DataType::Float16:
56 case armnn::DataType::Float32:
57 return weightsType;
58 case armnn::DataType::QAsymmS8:
59 case armnn::DataType::QAsymmU8:
60 case armnn::DataType::QSymmS8:
61 case armnn::DataType::QSymmS16:
62 return armnn::DataType::Signed32;
63 default:
64 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
65 }
66 return armnn::EmptyOptional();
67}
68
69
Sadik Armagan045f6be2020-09-10 13:37:32 +010070bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
71 const IConnectableLayer& connectableLayer,
72 Optional<DataType> dataType,
73 std::string& outReasonIfUnsupported,
74 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000075{
David Beck33f0ae02018-10-18 15:13:56 +010076 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000077 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010078 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010079
David Beck111b5d92018-11-12 14:59:37 +000080 auto const& backendRegistry = BackendRegistryInstance();
81 if (!backendRegistry.IsBackendRegistered(backendId))
82 {
83 std::stringstream ss;
84 ss << connectableLayer.GetName() << " is not supported on " << backendId
85 << " because this backend is not registered.";
86
87 outReasonIfUnsupported = ss.str();
88 return false;
89 }
90
91 auto backendFactory = backendRegistry.GetFactory(backendId);
92 auto backendObject = backendFactory();
Mike Kelly3ec30772023-03-08 13:47:17 +000093 auto layerSupport = backendObject->GetLayerSupport(modelOptions);
94 auto layerSupportObject = LayerSupportHandle(layerSupport, backendId);
David Beck33f0ae02018-10-18 15:13:56 +010095
telsoa014fcda012018-03-09 14:13:49 +000096 switch(layer.GetType())
97 {
98 case LayerType::Activation:
99 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100100 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100101 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100102 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000103 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100104 OverrideDataType(input, dataType),
105 OverrideDataType(output, dataType),
106 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100107 reason);
telsoa014fcda012018-03-09 14:13:49 +0000108 break;
109 }
110 case LayerType::Addition:
111 {
Mike Kelly3ec30772023-03-08 13:47:17 +0000112 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100113 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
114 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000115 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000116 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100117 OverrideDataType(input0, dataType),
118 OverrideDataType(input1, dataType),
119 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100120 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +0000121 ARMNN_NO_DEPRECATE_WARN_END
telsoa014fcda012018-03-09 14:13:49 +0000122 break;
123 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100124 case LayerType::ArgMinMax:
125 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100126 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100127 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
128
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100129 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nikhil Rajee391d52019-09-05 17:50:44 +0100130 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000131 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100132 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000133 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100134 descriptor,
135 reason);
136 break;
137 }
Samuel Yap6b478092022-07-06 15:36:03 +0100138 case LayerType::BatchMatMul:
139 {
140 auto cLayer = PolymorphicDowncast<const BatchMatMulLayer*>(&layer);
141 const BatchMatMulDescriptor& descriptor = cLayer->GetParameters();
142
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100143 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
144 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Samuel Yap6b478092022-07-06 15:36:03 +0100145 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
146 result = layerSupportObject.IsBatchMatMulSupported(
147 OverrideDataType(input0, dataType),
148 OverrideDataType(input1, dataType),
149 OverrideDataType(output, dataType),
150 descriptor,
151 reason);
152 break;
153 }
telsoa014fcda012018-03-09 14:13:49 +0000154 case LayerType::BatchNormalization:
155 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100156 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100157 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
159 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
160 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
161 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
162 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000163 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100164 OverrideDataType(input, dataType),
165 OverrideDataType(output, dataType),
166 OverrideDataType(mean, dataType),
167 OverrideDataType(var, dataType),
168 OverrideDataType(beta, dataType),
169 OverrideDataType(gamma, dataType),
170 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100171 reason);
telsoa014fcda012018-03-09 14:13:49 +0000172 break;
173 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000174 case LayerType::BatchToSpaceNd:
175 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100176 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000177 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100178 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000179
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000180 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
181 OverrideDataType(output, dataType),
182 cLayer->GetParameters(),
183 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000184 break;
185 }
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100186 case LayerType::BroadcastTo:
187 {
188 auto cLayer = PolymorphicDowncast<const BroadcastToLayer*>(&layer);
189 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
190 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
191
192 result = layerSupportObject.IsBroadcastToSupported(OverrideDataType(input, dataType),
193 OverrideDataType(output, dataType),
194 cLayer->GetParameters(),
195 reason);
196 break;
197 }
mathad01b392e982021-04-07 12:07:30 +0100198 case LayerType::Cast:
199 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100200 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
mathad01b392e982021-04-07 12:07:30 +0100201 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
202
203 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
204 OverrideDataType(output, dataType),
205 reason);
206 break;
207 }
Simon Obute51f67772021-09-03 15:50:13 +0100208 case LayerType::ChannelShuffle:
209 {
210 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
211
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100212 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
213 const TensorInfo& output = layer.GetInputSlot(0).GetTensorInfo();
Simon Obute51f67772021-09-03 15:50:13 +0100214
215 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
216
217 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
218 OverrideDataType(output, dataType),
219 descriptor,
220 reason);
221 break;
222 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100223 case LayerType::Comparison:
224 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100225 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100226
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100227 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
228 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100229 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
230
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000231 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
232 OverrideDataType(input1, dataType),
233 OverrideDataType(output, DataType::Boolean),
234 cLayer->GetParameters(),
235 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100236 break;
237 }
telsoa014fcda012018-03-09 14:13:49 +0000238 case LayerType::Constant:
239 {
240 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000241 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100242 break;
243 }
244 case LayerType::ConvertFp16ToFp32:
245 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100246 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100247 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000248 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100249 break;
250 }
251 case LayerType::ConvertFp32ToFp16:
252 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100253 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100254 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000255 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000256 break;
257 }
258 case LayerType::Convolution2d:
259 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100260 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100261
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100262 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
arovir01a6824102018-08-28 17:40:45 +0100263 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100264 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100265 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
266 "Convolution2dLayer: Weights should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100267 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100268 dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100269
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100270 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100271
arovir01a6824102018-08-28 17:40:45 +0100272 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100273 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100274 if (descriptor.m_BiasEnabled)
275 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100276 ARMNN_ASSERT_MSG(layer.GetInputSlot(2).GetConnection(),
277 "Convolution2dLayer: Bias should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100278 biases = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100279 GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100280 }
281
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000282 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100283 input,
284 output,
285 descriptor,
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100286 weights,
arovir01a6824102018-08-28 17:40:45 +0100287 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100288 reason);
telsoa014fcda012018-03-09 14:13:49 +0000289 break;
290 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100291 case LayerType::Convolution3d:
292 {
293 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
294
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100295 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100296 dataType);
297 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100298
299 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
300 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100301 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100302 dataType);
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100303
304 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
305
306 // Construct optional biases object based on the value of m_BiasEnabled
307 Optional<TensorInfo> biases;
308 if (descriptor.m_BiasEnabled)
309 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100310 biases = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100311 GetBiasTypeFromWeightsType(dataType));
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100312 }
313
314 result = layerSupportObject.IsConvolution3dSupported(
315 input,
316 output,
317 descriptor,
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100318 weights,
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100319 biases,
320 reason);
321 break;
322 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000323 case LayerType::Debug:
324 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100325 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000326 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
327
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000328 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000329 OverrideDataType(output, dataType),
330 reason);
331 break;
332 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100333 case LayerType::DepthToSpace:
334 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100335 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100336
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100337 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100338 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
339
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000340 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100341 OverrideDataType(output, dataType),
342 cLayer->GetParameters(),
343 reason);
344 break;
345 }
telsoa014fcda012018-03-09 14:13:49 +0000346 case LayerType::DepthwiseConvolution2d:
347 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100348 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100349 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100350 dataType);
351 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100352 const TensorInfo& weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100353 dataType);
354
355 ARMNN_ASSERT(cLayer->GetInputSlot(1).GetConnection() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100356
telsoa01c577f2c2018-08-31 09:22:23 +0100357 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100358
359 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100360 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100361 if (descriptor.m_BiasEnabled)
362 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100363 biases = OverrideDataType(cLayer->GetInputSlot(2).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100364 GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100365 }
telsoa01c577f2c2018-08-31 09:22:23 +0100366
Cathal Corbett06902652022-04-14 17:55:11 +0100367 result = layerSupportObject.IsDepthwiseConvolutionSupported(input,
368 output,
369 descriptor,
370 weights,
371 biases,
372 reason);
telsoa014fcda012018-03-09 14:13:49 +0000373 break;
374 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000375 case LayerType::Dequantize:
376 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100377 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000378 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
379
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000380 result = layerSupportObject.IsDequantizeSupported(input,
381 OverrideDataType(output, dataType),
382 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000383 break;
384 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000385 case LayerType::DetectionPostProcess:
386 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100387 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100388 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetTensorInfo();
389 const TensorInfo& scores = layer.GetInputSlot(1).GetTensorInfo();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000390 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
391
392 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
393 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
394 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
395 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
396
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000397 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000398 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
399 scores,
400 anchors,
401 detectionBoxes,
402 detectionClasses,
403 detectionScores,
404 numDetections,
405 descriptor,
406 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000407 break;
408 }
Mike Kelly3ec30772023-03-08 13:47:17 +0000409 case LayerType::ElementwiseBinary:
410 {
411 auto cLayer = PolymorphicDowncast<const ElementwiseBinaryLayer*>(&layer);
412
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100413 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
414 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Mike Kelly3ec30772023-03-08 13:47:17 +0000415 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
416 std::vector<TensorInfo> infos = { OverrideDataType(input0, dataType),
417 OverrideDataType(input1, dataType),
418 OverrideDataType(output, dataType) };
419 result = layerSupport->IsLayerSupported(LayerType::ElementwiseBinary,
420 infos,
421 cLayer->GetParameters(),
422 EmptyOptional(),
423 EmptyOptional(),
424 reason);
425 break;
426 }
josh minor4a3c6102020-01-06 16:40:46 -0600427 case LayerType::ElementwiseUnary:
428 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100429 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600430
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100431 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
josh minor4a3c6102020-01-06 16:40:46 -0600432 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
433
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000434 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
435 OverrideDataType(output, dataType),
436 cLayer->GetParameters(),
437 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600438 break;
439 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100440 case LayerType::Fill:
441 {
442 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100443 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Ryan OSheaec6c6802020-06-05 17:17:06 +0100444 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
445 const FillDescriptor& descriptor = cLayer->GetParameters();
446
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000447 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100448 OverrideDataType(input, dataType),
449 OverrideDataType(output, dataType),
450 descriptor,
451 reason);
452 break;
453 }
telsoa014fcda012018-03-09 14:13:49 +0000454 case LayerType::FakeQuantization:
455 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100456 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100457 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000458 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
459 cLayer->GetParameters(),
460 reason);
telsoa014fcda012018-03-09 14:13:49 +0000461 break;
462 }
463 case LayerType::Floor:
464 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100465 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000466 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000467 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
468 OverrideDataType(output, dataType),
469 reason);
telsoa014fcda012018-03-09 14:13:49 +0000470 break;
471 }
472 case LayerType::FullyConnected:
473 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100474 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100475 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100476 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000477
478 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
479 TensorInfo weightsInfo;
480 const TensorInfo* weightsInfoPtr = nullptr;
481
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100482 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000483 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100484
485 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000486 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000487 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100488 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
489 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
490 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
491
telsoa01c577f2c2018-08-31 09:22:23 +0100492 if (descriptor.m_BiasEnabled)
493 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100494 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(), dataType);
Matthew Sloyan81beae32021-07-13 19:46:11 +0100495 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100496 }
497 else
498 {
499 // If biases are not enabled pass a dummy tensorinfo for the validation
500 switch(input.GetDataType())
501 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000502 case DataType::BFloat16:
503 {
504 biasInfoPtr = &dummyBFloat16Bias;
505 break;
506 }
telsoa01c577f2c2018-08-31 09:22:23 +0100507 case DataType::Float16:
508 {
509 biasInfoPtr = &dummyFloat16Bias;
510 break;
511 }
512 case DataType::Float32:
513 {
514 biasInfoPtr = &dummyFloat32Bias;
515 break;
516 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000517 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000518 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000519 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000520 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100521 {
522 biasInfoPtr = &dummyQA8Bias;
523 break;
524 }
525 default:
526 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100527 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100528 }
529 }
530 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000531 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100532 OverrideDataType(input, dataType),
533 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000534 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100535 *biasInfoPtr,
536 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100537 reason);
telsoa014fcda012018-03-09 14:13:49 +0000538 break;
539 }
Teresa Charlin9145e382023-08-17 18:44:58 +0100540 case LayerType::Fused:
541 {
542 auto cLayer = PolymorphicDowncast<const FusedLayer*>(&layer);
543
544 // Get vector of all outputs.
545 auto getOutTensorInfo = [&dataType](const OutputSlot& slot)
546 {
547 return OverrideDataType(slot.GetTensorInfo(), dataType);
548 };
549 auto beginOutputs = MakeTransformIterator(layer.GetOutputSlots().begin(), getOutTensorInfo);
550 auto endOutputs = MakeTransformIterator(layer.GetOutputSlots().end(), getOutTensorInfo);
551 std::vector<TensorInfo> outputs(beginOutputs, endOutputs);
552 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
553
554 // Get vector of all inputs.
555 auto getInputTensorInfo = [&dataType](const InputSlot& slot)
556 {
557 return OverrideDataType(slot.GetTensorInfo(), dataType);
558 };
559 auto beginInputs = MakeTransformIterator(layer.GetInputSlots().begin(), getInputTensorInfo);
560 auto endInputs = MakeTransformIterator(layer.GetInputSlots().end(), getInputTensorInfo);
561 std::vector<TensorInfo> inputs(beginInputs, endInputs);
562 const std::vector<std::reference_wrapper<TensorInfo>> inputPtrs(inputs.begin(), inputs.end());
563
564 result = layerSupportObject.IsFusedSupported(inputPtrs,
565 outputPtrs,
566 cLayer->GetParameters(),
567 reason);
568 break;
569 }
narpra01b89b05f2019-01-16 09:53:09 +0000570 case LayerType::Gather:
571 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100572 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
573 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
narpra01b89b05f2019-01-16 09:53:09 +0000574 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100575 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
576 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000577 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
578 input1,
579 OverrideDataType(output, dataType),
580 descriptor,
581 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000582 break;
583 }
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100584 case LayerType::GatherNd:
585 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100586 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
587 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100588 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
589 result = layerSupportObject.IsGatherNdSupported(OverrideDataType(input0, dataType),
590 input1,
591 OverrideDataType(output, dataType),
592 reason);
593 break;
594 }
telsoa014fcda012018-03-09 14:13:49 +0000595 case LayerType::Input:
596 {
597 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000598 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000599 break;
600 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100601 case LayerType::InstanceNormalization:
602 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100603 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100604 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
605
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100606 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Kevin Mayce5045a2019-10-02 14:07:47 +0100607 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
608
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000609 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100610 OverrideDataType(input, dataType),
611 OverrideDataType(output, dataType),
612 descriptor,
613 reason);
614 break;
615 }
telsoa014fcda012018-03-09 14:13:49 +0000616 case LayerType::L2Normalization:
617 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100618 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100619 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
620
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100621 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100622 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100623
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000624 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100625 OverrideDataType(input, dataType),
626 OverrideDataType(output, dataType),
627 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100628 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100629 break;
630 }
James Conroyaba90cd2020-11-06 16:28:18 +0000631 case LayerType::LogicalBinary:
632 {
633 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
634
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100635 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
636 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
James Conroyaba90cd2020-11-06 16:28:18 +0000637 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
638
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000639 result = layerSupportObject.IsLogicalBinarySupported(input0,
640 input1,
641 output,
642 cLayer->GetParameters(),
643 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000644 break;
645 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100646 case LayerType::LogSoftmax:
647 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100648 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100649
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100650 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100651 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
652
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000653 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
654 OverrideDataType(output, dataType),
655 cLayer->GetParameters(),
656 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100657 break;
658 }
telsoa01c577f2c2018-08-31 09:22:23 +0100659 case LayerType::Lstm:
660 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100661 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100662 const LstmDescriptor& descriptor = cLayer->GetParameters();
663
664 // All inputs.
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100665 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
telsoa01c577f2c2018-08-31 09:22:23 +0100666 dataType);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100667 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
telsoa01c577f2c2018-08-31 09:22:23 +0100668 dataType);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100669 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
telsoa01c577f2c2018-08-31 09:22:23 +0100670 dataType);
671 // All outputs
672 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
673 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
674 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
675 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
676
677 // Basic parameters
678 const TensorInfo& inputToForgetWeights
679 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
680 const TensorInfo& inputToCellWeights
681 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
682 const TensorInfo& inputToOutputWeights
683 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
684 const TensorInfo& recurrentToForgetWeights
685 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
686 const TensorInfo& recurrentToCellWeights
687 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
688 const TensorInfo& recurrentToOutputWeights
689 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
690 const TensorInfo& forgetGateBias
691 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
692 const TensorInfo& cellBias
693 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
694 const TensorInfo& outputGateBias
695 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
696
Jan Eilersd01a83c2019-07-03 18:20:40 +0100697 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100698
Jan Eilersd01a83c2019-07-03 18:20:40 +0100699 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
700 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
701 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
702 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
703 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
704 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
705 paramsInfo.m_ForgetGateBias = &forgetGateBias;
706 paramsInfo.m_CellBias = &cellBias;
707 paramsInfo.m_OutputGateBias = &outputGateBias;
708
709
710 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100711 TensorInfo optInputToInputWeights;
712 TensorInfo optRecurrentToInputWeights;
713 TensorInfo optCellToInputWeights;
714 TensorInfo optInputGateBias;
715 TensorInfo optProjectionWeights;
716 TensorInfo optProjectionBias;
717 TensorInfo optCellToForgetWeights;
718 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100719 TensorInfo optInputLayerNormWeights;
720 TensorInfo optForgetLayerNormWeights;
721 TensorInfo optCellLayerNormWeights;
722 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100723
724 if(!descriptor.m_CifgEnabled)
725 {
726 optInputToInputWeights =
727 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100728 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100729
730 optRecurrentToInputWeights =
731 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100732 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100733 optInputGateBias =
734 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100735 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100736 }
737
738 if(descriptor.m_ProjectionEnabled)
739 {
740 optProjectionWeights =
741 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100742 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100743 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
744 {
745 optProjectionBias =
746 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100747 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100748 }
749 }
750
751 if(descriptor.m_PeepholeEnabled)
752 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100753 if(!descriptor.m_CifgEnabled)
754 {
755 optCellToInputWeights =
756 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
757 dataType);
758 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
759 }
telsoa01c577f2c2018-08-31 09:22:23 +0100760 optCellToForgetWeights =
761 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100762 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100763 optCellToOutputWeights =
764 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100765 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100766 }
767
Jan Eilers38e05bd2019-06-26 13:10:09 +0100768 if(descriptor.m_LayerNormEnabled)
769 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100770 if (!descriptor.m_CifgEnabled)
771 {
772 optInputLayerNormWeights = OverrideDataType(
773 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
774 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
775 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100776
777 optForgetLayerNormWeights = OverrideDataType(
778 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100779 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100780
781 optCellLayerNormWeights = OverrideDataType(
782 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100783 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100784
785 optOutputLayerNormWeights = OverrideDataType(
786 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100787 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100788 }
789
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000790 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100791 input,
792 outputStateIn,
793 cellStateIn,
794 scratchBuffer,
795 outputStateOut,
796 cellStateOut,
797 output,
798 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100799 paramsInfo,
800 reason);
telsoa014fcda012018-03-09 14:13:49 +0000801 break;
802 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000803 case LayerType::Maximum:
804 {
Mike Kelly3ec30772023-03-08 13:47:17 +0000805 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100806 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
807 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000808 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
809
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000810 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
811 OverrideDataType(input1, dataType),
812 OverrideDataType(output, dataType),
813 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +0000814 ARMNN_NO_DEPRECATE_WARN_END
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000815 break;
816 }
narpra01b89b05f2019-01-16 09:53:09 +0000817 case LayerType::MemCopy:
818 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100819 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
narpra01b89b05f2019-01-16 09:53:09 +0000820 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000821
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000822 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
823 OverrideDataType(output, dataType),
824 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000825 break;
826 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100827 case LayerType::MemImport:
828 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100829 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Derek Lambertif674aa02019-08-01 15:56:25 +0100830 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
831
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000832 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
833 OverrideDataType(output, dataType),
834 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100835 break;
836 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100837 case LayerType::Merge:
838 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100839 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
840 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100841 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
842
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000843 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
844 OverrideDataType(input1, dataType),
845 OverrideDataType(output, dataType),
846 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100847 break;
848 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100849 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000850 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100851 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000852
telsoa01c577f2c2018-08-31 09:22:23 +0100853 // Get vector of all inputs.
854 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000855 {
telsoa01c577f2c2018-08-31 09:22:23 +0100856 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000857 };
Finn Williams3e54d032020-10-22 16:53:35 +0100858
859 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
860 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100861 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000862
telsoa01c577f2c2018-08-31 09:22:23 +0100863 auto getTensorInfoPtr = [](const TensorInfo& info)
864 {
865 return &info;
866 };
Finn Williams3e54d032020-10-22 16:53:35 +0100867
868 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
869 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100870 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000871
Nikhil Raj8599a412018-11-19 14:51:07 +0000872 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
873
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000874 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Jim Flynne242f2d2019-05-22 14:24:13 +0100875
876
telsoa014fcda012018-03-09 14:13:49 +0000877 break;
878 }
        case LayerType::Multiplication:
        {
            // The per-op support check is deprecated; suppress the warning while
            // still dispatching to it.
            ARMNN_NO_DEPRECATE_WARN_BEGIN
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsMultiplicationSupported(
                                               OverrideDataType(input0, dataType),
                                               OverrideDataType(input1, dataType),
                                               OverrideDataType(output, dataType),
                                               reason);
            ARMNN_NO_DEPRECATE_WARN_END
            break;
        }
        case LayerType::Normalization:
        {
            // Downcast to reach the layer-specific descriptor.
            auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Output:
        {
            // An Output layer has no output slots; validate the tensor it consumes.
            const TensorInfo& output = layer.GetInputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            // Downcast to reach the permutation descriptor.
            auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100921 case LayerType::Pad:
922 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100923 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100924 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100925 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000926 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100927 OverrideDataType(input, dataType),
928 OverrideDataType(output, dataType),
929 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100930 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100931 break;
932 }
telsoa014fcda012018-03-09 14:13:49 +0000933 case LayerType::Pooling2d:
934 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100935 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100936 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000937 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000938 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
939 OverrideDataType(output, dataType),
940 cLayer->GetParameters(),
941 reason);
telsoa014fcda012018-03-09 14:13:49 +0000942 break;
943 }
Tamás Nyíri7b885b32021-10-26 14:47:57 +0100944 case LayerType::Pooling3d:
945 {
946 auto cLayer = PolymorphicDowncast<const Pooling3dLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100947 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Tamás Nyíri7b885b32021-10-26 14:47:57 +0100948 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
949 result = layerSupportObject.IsPooling3dSupported(OverrideDataType(input, dataType),
950 OverrideDataType(output, dataType),
951 cLayer->GetParameters(),
952 reason);
953 break;
954 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000955 case LayerType::PreCompiled:
956 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100957 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100958 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000959 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
960 cLayer->GetParameters(),
961 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000962 break;
963 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000964 case LayerType::Quantize:
965 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100966 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000967 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000968 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000969 break;
970 }
James Conroy586a9aa2020-03-20 08:49:33 +0000971 case LayerType::QLstm:
972 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100973 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000974 const QLstmDescriptor& descriptor = cLayer->GetParameters();
975
976 // Inputs
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100977 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
978 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetTensorInfo();
979 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetTensorInfo();
James Conroy586a9aa2020-03-20 08:49:33 +0000980
981 // Outputs
982 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
983 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
984 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
985
986 // Lstm parameters
987 LstmInputParamsInfo paramsInfo;
988
989 // Basic parameters
Matthew Bentham6f24b1a2021-06-29 15:18:32 +0100990 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
991 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
992 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
James Conroy586a9aa2020-03-20 08:49:33 +0000993 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
994 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
995 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
996
997 paramsInfo.m_RecurrentToForgetWeights =
998 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
999 paramsInfo.m_RecurrentToCellWeights =
1000 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
1001 paramsInfo.m_RecurrentToOutputWeights =
1002 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
1003
1004 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
1005 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
1006 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
1007
1008 if(!descriptor.m_CifgEnabled)
1009 {
1010 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
1011 paramsInfo.m_RecurrentToInputWeights =
1012 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
1013 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
1014 }
1015
1016 if(descriptor.m_ProjectionEnabled)
1017 {
1018 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +01001019
1020 // Projection bias is optional even if projection is enabled
1021 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1022 {
1023 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
1024 }
James Conroy586a9aa2020-03-20 08:49:33 +00001025 }
1026
1027 if(descriptor.m_PeepholeEnabled)
1028 {
1029 if (!descriptor.m_CifgEnabled)
1030 {
1031 paramsInfo.m_CellToInputWeights =
1032 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
1033 }
1034
1035 paramsInfo.m_CellToForgetWeights =
1036 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
1037 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
1038 }
1039
1040 if(descriptor.m_LayerNormEnabled)
1041 {
1042 if (!descriptor.m_CifgEnabled)
1043 {
1044 paramsInfo.m_InputLayerNormWeights =
1045 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
1046 }
1047
1048 paramsInfo.m_ForgetLayerNormWeights =
1049 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
1050 paramsInfo.m_CellLayerNormWeights =
1051 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
1052 paramsInfo.m_OutputLayerNormWeights =
1053 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
1054 }
1055
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001056 result = layerSupportObject.IsQLstmSupported(input,
1057 previousOutputIn,
1058 previousCellStateIn,
1059 outputStateOut,
1060 cellStateOut,
1061 output,
1062 descriptor,
1063 paramsInfo,
1064 reason);
James Conroy586a9aa2020-03-20 08:49:33 +00001065 break;
1066 }
James Conroyee18dc82019-07-17 11:27:46 +01001067 case LayerType::QuantizedLstm:
1068 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001069 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +01001070
1071 // Inputs
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001072 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1073 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetTensorInfo();
1074 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001075
1076 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001077 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
1078 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001079
1080 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +01001081 QuantizedLstmInputParamsInfo paramsInfo;
1082
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001083 paramsInfo.m_InputToInputWeights =
1084 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
1085 paramsInfo.m_InputToForgetWeights =
1086 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
1087 paramsInfo.m_InputToCellWeights =
1088 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
1089 paramsInfo.m_InputToOutputWeights =
1090 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001091
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001092 paramsInfo.m_RecurrentToInputWeights =
1093 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
1094 paramsInfo.m_RecurrentToForgetWeights =
1095 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
1096 paramsInfo.m_RecurrentToCellWeights =
1097 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
1098 paramsInfo.m_RecurrentToOutputWeights =
1099 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001100
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001101 paramsInfo.m_InputGateBias =
1102 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
1103 paramsInfo.m_ForgetGateBias =
1104 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
1105 paramsInfo.m_CellBias =
1106 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
1107 paramsInfo.m_OutputGateBias =
1108 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +01001109
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001110 result = layerSupportObject.IsQuantizedLstmSupported(input,
1111 previousCellStateIn,
1112 previousOutputIn,
1113 cellStateOut,
1114 output,
1115 paramsInfo,
1116 reason);
James Conroyee18dc82019-07-17 11:27:46 +01001117 break;
1118 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001119 case LayerType::Division:
1120 {
Mike Kelly3ec30772023-03-08 13:47:17 +00001121 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001122 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1123 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001124 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001125 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001126 OverrideDataType(input0, dataType),
1127 OverrideDataType(input1, dataType),
1128 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001129 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +00001130 ARMNN_NO_DEPRECATE_WARN_END
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001131 break;
1132 }
Finn Williams2605b232020-06-10 15:53:46 +01001133 case LayerType::Rank:
1134 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001135 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Finn Williams2605b232020-06-10 15:53:46 +01001136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001137 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
1138 OverrideDataType(output, dataType),
1139 reason);
Finn Williams2605b232020-06-10 15:53:46 +01001140 break;
1141 }
telsoa014fcda012018-03-09 14:13:49 +00001142 case LayerType::Reshape:
1143 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001144 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001145 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +00001146 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001147 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
1148 OverrideDataType(output, dataType),
1149 cLayer->GetParameters(),
1150 reason);
telsoa014fcda012018-03-09 14:13:49 +00001151 break;
1152 }
Teresa Charlina9075df2019-06-27 15:41:57 +01001153 case LayerType::Resize:
1154 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001155 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001156 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +01001157 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001158 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1159 OverrideDataType(output, dataType),
1160 cLayer->GetParameters(),
1161 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +01001162 break;
1163 }
Tianle Cheng988354d2023-06-28 13:20:47 +01001164 case LayerType::ReverseV2:
1165 {
Tracy Narinebb8d7592023-07-13 16:50:54 +01001166 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1167 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
Tianle Cheng988354d2023-06-28 13:20:47 +01001168 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Tracy Narinebb8d7592023-07-13 16:50:54 +01001169 result = layerSupportObject.IsReverseV2Supported(OverrideDataType(input0, dataType),
1170 OverrideDataType(input1, armnn::DataType::Signed32),
Tianle Cheng988354d2023-06-28 13:20:47 +01001171 OverrideDataType(output, dataType),
Tianle Cheng988354d2023-06-28 13:20:47 +01001172 reason);
1173 break;
1174 }
Keith Davis3ae3f972021-05-21 16:33:48 +01001175 case LayerType::Shape:
1176 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001177 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Keith Davis3ae3f972021-05-21 16:33:48 +01001178 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1179
1180 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1181 OverrideDataType(output, dataType),
1182 reason);
1183 break;
1184 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001185 case LayerType::Slice:
1186 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001187 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001188
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001189 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001190 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1191
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001192 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1193 OverrideDataType(output, dataType),
1194 cLayer->GetParameters(),
1195 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001196 break;
1197 }
telsoa014fcda012018-03-09 14:13:49 +00001198 case LayerType::Softmax:
1199 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001200 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001201 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001202 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001203 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1204 OverrideDataType(output, dataType),
1205 cLayer->GetParameters(),
1206 reason);
telsoa014fcda012018-03-09 14:13:49 +00001207 break;
1208 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001209 case LayerType::SpaceToBatchNd:
1210 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001211 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001212 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001213 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001214 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1215 OverrideDataType(output, dataType),
1216 cLayer->GetParameters(),
1217 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001218 break;
1219 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001220 case LayerType::SpaceToDepth:
1221 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001222 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001223
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001224 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001225 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1226
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001227 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1228 OverrideDataType(output, dataType),
1229 cLayer->GetParameters(),
1230 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001231 break;
1232 }
telsoa014fcda012018-03-09 14:13:49 +00001233 case LayerType::Splitter:
1234 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001235 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001236 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001237
1238 // Get vector of all outputs.
1239 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1240 {
1241 return OverrideDataType(slot.GetTensorInfo(), dataType);
1242 };
Finn Williams3e54d032020-10-22 16:53:35 +01001243 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1244 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001245 std::vector<TensorInfo> outputs(beginI, endI);
1246
1247 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1248
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001249 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1250 outputPtrs,
1251 cLayer->GetParameters(),
1252 reason);
telsoa014fcda012018-03-09 14:13:49 +00001253 break;
1254 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001255 case LayerType::Stack:
1256 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001257 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001258
1259 // Get vector of all inputs.
1260 auto getTensorInfo = [&dataType](const InputSlot& slot)
1261 {
1262 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1263 };
Finn Williams3e54d032020-10-22 16:53:35 +01001264 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1265 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001266 std::vector<TensorInfo> inputs(beginI, endI);
1267
1268 auto getTensorInfoPtr = [](const TensorInfo& info)
1269 {
1270 return &info;
1271 };
Finn Williams3e54d032020-10-22 16:53:35 +01001272 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1273 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001274 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1275
1276 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1277
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001278 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001279
1280 break;
1281 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001282 case LayerType::StandIn:
1283 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001284 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001285
1286 // Get vector of all inputs.
1287 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1288 {
1289 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1290 };
1291 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1292 {
1293 return OverrideDataType(slot.GetTensorInfo(), dataType);
1294 };
Finn Williams3e54d032020-10-22 16:53:35 +01001295 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1296 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001297 std::vector<TensorInfo> inputs(beginI, endI);
1298
Finn Williams3e54d032020-10-22 16:53:35 +01001299 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1300 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001301 std::vector<TensorInfo> outputs(beginO, endO);
1302
1303
1304 auto getTensorInfoPtr = [](const TensorInfo& info)
1305 {
1306 return &info;
1307 };
Finn Williams3e54d032020-10-22 16:53:35 +01001308 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1309 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001310 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1311
Finn Williams3e54d032020-10-22 16:53:35 +01001312 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1313 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001314 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1315
1316
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001317 result = layerSupportObject.IsStandInSupported(inputPtrs,
1318 outputPtrs,
1319 cLayer->GetParameters(),
1320 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001321 break;
1322 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001323 case LayerType::StridedSlice:
1324 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001325 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001326 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Conor Kennedy430b5d82018-11-14 15:28:28 +00001327 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001328 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1329 OverrideDataType(output, dataType),
1330 cLayer->GetParameters(),
1331 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001332 break;
1333 }
David Beckc2044fe2018-09-05 15:00:38 +01001334 case LayerType::Subtraction:
1335 {
Mike Kelly3ec30772023-03-08 13:47:17 +00001336 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001337 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1338 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
David Beckc2044fe2018-09-05 15:00:38 +01001339 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001340 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001341 OverrideDataType(input0, dataType),
1342 OverrideDataType(input1, dataType),
1343 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001344 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +00001345 ARMNN_NO_DEPRECATE_WARN_END
David Beckc2044fe2018-09-05 15:00:38 +01001346 break;
1347 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001348 case LayerType::Switch:
1349 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001350 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1351 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Sadik Armaganeff363d2019-04-05 15:25:46 +01001352 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1353 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001354 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1355 OverrideDataType(input1, dataType),
1356 OverrideDataType(output0, dataType),
1357 OverrideDataType(output1, dataType),
1358 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001359 break;
1360 }
narpra0132b90462018-09-13 11:07:48 +01001361 case LayerType::Mean:
1362 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001363 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001364 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
narpra0132b90462018-09-13 11:07:48 +01001365 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001366 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001367 OverrideDataType(input, dataType),
1368 OverrideDataType(output, dataType),
1369 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001370 reason);
narpra0132b90462018-09-13 11:07:48 +01001371 break;
1372 }
kevmay0190539692018-11-29 08:40:19 +00001373 case LayerType::Minimum:
1374 {
Mike Kelly3ec30772023-03-08 13:47:17 +00001375 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001376 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1377 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
kevmay0190539692018-11-29 08:40:19 +00001378 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001379 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1380 OverrideDataType(input1, dataType),
1381 OverrideDataType(output, dataType),
1382 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +00001383 ARMNN_NO_DEPRECATE_WARN_END
kevmay0190539692018-11-29 08:40:19 +00001384 break;
1385 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001386 case LayerType::Prelu:
1387 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001388 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1389 const TensorInfo& alpha = layer.GetInputSlot(1).GetTensorInfo();
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001390 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001391 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1392 OverrideDataType(alpha, dataType),
1393 OverrideDataType(output, dataType),
1394 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001395 break;
1396 }
Teresa Charlin79a06a52023-07-13 17:16:45 +01001397 case LayerType::Tile:
1398 {
1399 auto cLayer = PolymorphicDowncast<const TileLayer*>(&layer);
1400 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1401 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1402
1403 result = layerSupportObject.IsTileSupported(OverrideDataType(input, dataType),
1404 OverrideDataType(output, dataType),
1405 cLayer->GetParameters(),
1406 reason);
1407
1408 break;
1409 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001410 case LayerType::Transpose:
1411 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001412 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001413 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001414 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001415 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1416 OverrideDataType(output, dataType),
1417 cLayer->GetParameters(),
1418 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001419 break;
1420 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001421 case LayerType::TransposeConvolution2d:
1422 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001423 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001424
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001425 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001426 dataType);
1427 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1428
1429 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1430
1431 Optional<TensorInfo> biases;
1432 if (descriptor.m_BiasEnabled)
1433 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001434 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001435 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1436 GetBiasTypeFromWeightsType(dataType));
1437 }
1438
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001439 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001440 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1441
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001442 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1443 output,
1444 descriptor,
1445 weights,
1446 biases,
1447 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001448
1449 break;
1450 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001451 case LayerType::Reduce:
1452 {
1453 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001454 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001455 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1456
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001457 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1458 OverrideDataType(output, dataType),
1459 cLayer->GetParameters(),
1460 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001461 break;
1462 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001463 case LayerType::UnidirectionalSequenceLstm:
1464 {
1465 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1466 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1467
1468 // All inputs.
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001469 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001470 dataType);
Mike Kelly4cc341c2023-07-07 15:43:06 +01001471 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001472 dataType);
Mike Kelly4cc341c2023-07-07 15:43:06 +01001473 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001474 dataType);
1475 // Outputs
Mike Kelly12994962022-04-21 11:57:09 +01001476 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1477 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
1478 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001479
1480 // Basic parameters
1481 const TensorInfo& inputToForgetWeights
1482 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1483 const TensorInfo& inputToCellWeights
1484 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1485 const TensorInfo& inputToOutputWeights
1486 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1487 const TensorInfo& recurrentToForgetWeights
1488 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1489 const TensorInfo& recurrentToCellWeights
1490 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1491 const TensorInfo& recurrentToOutputWeights
1492 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1493 const TensorInfo& forgetGateBias
1494 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1495 const TensorInfo& cellBias
1496 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1497 const TensorInfo& outputGateBias
1498 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1499
1500 LstmInputParamsInfo paramsInfo;
1501
1502 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1503 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1504 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1505 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1506 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1507 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1508 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1509 paramsInfo.m_CellBias = &cellBias;
1510 paramsInfo.m_OutputGateBias = &outputGateBias;
1511
1512 // Optional parameters
1513 TensorInfo optInputToInputWeights;
1514 TensorInfo optRecurrentToInputWeights;
1515 TensorInfo optCellToInputWeights;
1516 TensorInfo optInputGateBias;
1517 TensorInfo optProjectionWeights;
1518 TensorInfo optProjectionBias;
1519 TensorInfo optCellToForgetWeights;
1520 TensorInfo optCellToOutputWeights;
1521 TensorInfo optInputLayerNormWeights;
1522 TensorInfo optForgetLayerNormWeights;
1523 TensorInfo optCellLayerNormWeights;
1524 TensorInfo optOutputLayerNormWeights;
1525
1526 if(!descriptor.m_CifgEnabled)
1527 {
1528 optInputToInputWeights =
1529 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1530 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1531
1532 optRecurrentToInputWeights =
1533 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1534 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1535 optInputGateBias =
1536 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1537 paramsInfo.m_InputGateBias = &optInputGateBias;
1538 }
1539
1540 if(descriptor.m_ProjectionEnabled)
1541 {
1542 optProjectionWeights =
1543 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1544 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1545 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1546 {
1547 optProjectionBias =
1548 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1549 paramsInfo.m_ProjectionBias = &optProjectionBias;
1550 }
1551 }
1552
1553 if(descriptor.m_PeepholeEnabled)
1554 {
1555 if(!descriptor.m_CifgEnabled)
1556 {
1557 optCellToInputWeights =
1558 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1559 dataType);
1560 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1561 }
1562 optCellToForgetWeights =
1563 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1564 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1565 optCellToOutputWeights =
1566 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1567 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1568 }
1569
1570 if(descriptor.m_LayerNormEnabled)
1571 {
1572 if (!descriptor.m_CifgEnabled)
1573 {
1574 optInputLayerNormWeights = OverrideDataType(
1575 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1576 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1577 }
1578
1579 optForgetLayerNormWeights = OverrideDataType(
1580 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1581 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1582
1583 optCellLayerNormWeights = OverrideDataType(
1584 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1585 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1586
1587 optOutputLayerNormWeights = OverrideDataType(
1588 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1589 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1590 }
1591
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001592 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1593 outputStateIn,
1594 cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01001595 outputStateOut,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001596 cellStateOut,
Mike Kelly12994962022-04-21 11:57:09 +01001597 output,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001598 descriptor,
1599 paramsInfo,
1600 reason);
1601 break;
1602 }
telsoa014fcda012018-03-09 14:13:49 +00001603 default:
1604 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001605 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001606 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001607 result = false;
1608 break;
1609 }
1610 }
telsoa014fcda012018-03-09 14:13:49 +00001611 return result;
1612}
1613
Sadik Armagan045f6be2020-09-10 13:37:32 +01001614bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1615 const IConnectableLayer& connectableLayer,
1616 Optional<DataType> dataType,
1617 std::string& outReasonIfUnsupported)
1618{
1619 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1620}
1621
David Beckdcb751f2018-10-03 11:42:42 +01001622bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001623 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001624 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001625{
Jan Eilersbb446e52020-04-02 13:56:54 +01001626 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001627 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1628}
1629
Sadik Armagan045f6be2020-09-10 13:37:32 +01001630bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1631 Optional<DataType> dataType,
1632 std::string& outReasonIfUnsupported,
1633 const ModelOptions& modelOptions)
1634{
1635 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1636 return IsLayerConfigurationSupported(layer->GetBackendId(),
1637 connectableLayer,
1638 dataType,
1639 outReasonIfUnsupported,
1640 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001641}
1642
Sadik Armagan04a72972020-09-14 15:44:18 +01001643bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1644 const IConnectableLayer& connectableLayer,
1645 Optional<DataType> dataType,
1646 std::string& outReasonIfUnsupported,
1647 const ModelOptions& modelOptions)
1648{
1649 return IsLayerConfigurationSupported(backendId,
1650 connectableLayer,
1651 dataType,
1652 outReasonIfUnsupported,
1653 modelOptions);
1654}
Cian McGriskin7894ef92023-08-01 14:04:09 +01001655
} // namespace armnn