blob: 1f8d4dae1d59aa41a3ec88796c5b30c1b76ede5d [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000010#include <armnn/backends/IBackendInternal.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Colm Donelan0c479742021-12-10 12:43:54 +000017#include <armnn/backends/WorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
David Beck111b5d92018-11-12 14:59:37 +000019#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000020
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
telsoa01c577f2c2018-08-31 09:22:23 +010024namespace
25{
Finn Williams3e54d032020-10-22 16:53:35 +010026using LayerList = std::list<Layer*>;
27using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010028
David Beck29c75de2018-10-23 13:35:58 +010029const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
30{
31 if (!type)
32 {
33 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010034 }
35
Matthew Sloyan81beae32021-07-13 19:46:11 +010036 return TensorInfo(info.GetShape(),
37 type.value(),
38 info.GetQuantizationScale(),
39 info.GetQuantizationOffset(),
40 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010041}
42
David Beck29c75de2018-10-23 13:35:58 +010043} // anonymous namespace
44
Sadik Armagana097d2a2021-11-24 15:47:28 +000045inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
46{
47 if (!weightsType)
48 {
49 return weightsType;
50 }
51
52 switch(weightsType.value())
53 {
54 case armnn::DataType::BFloat16:
55 case armnn::DataType::Float16:
56 case armnn::DataType::Float32:
57 return weightsType;
58 case armnn::DataType::QAsymmS8:
59 case armnn::DataType::QAsymmU8:
60 case armnn::DataType::QSymmS8:
61 case armnn::DataType::QSymmS16:
62 return armnn::DataType::Signed32;
63 default:
Colm Donelanb4ef1632024-02-01 15:00:43 +000064 throw InvalidArgumentException("GetBiasTypeFromWeightsType(): Unsupported data type.");
Sadik Armagana097d2a2021-11-24 15:47:28 +000065 }
66 return armnn::EmptyOptional();
67}
68
69
Sadik Armagan045f6be2020-09-10 13:37:32 +010070bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
71 const IConnectableLayer& connectableLayer,
72 Optional<DataType> dataType,
73 std::string& outReasonIfUnsupported,
74 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000075{
David Beck33f0ae02018-10-18 15:13:56 +010076 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000077 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010078 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010079
David Beck111b5d92018-11-12 14:59:37 +000080 auto const& backendRegistry = BackendRegistryInstance();
81 if (!backendRegistry.IsBackendRegistered(backendId))
82 {
83 std::stringstream ss;
84 ss << connectableLayer.GetName() << " is not supported on " << backendId
85 << " because this backend is not registered.";
86
87 outReasonIfUnsupported = ss.str();
88 return false;
89 }
90
91 auto backendFactory = backendRegistry.GetFactory(backendId);
92 auto backendObject = backendFactory();
Mike Kelly3ec30772023-03-08 13:47:17 +000093 auto layerSupport = backendObject->GetLayerSupport(modelOptions);
94 auto layerSupportObject = LayerSupportHandle(layerSupport, backendId);
David Beck33f0ae02018-10-18 15:13:56 +010095
telsoa014fcda012018-03-09 14:13:49 +000096 switch(layer.GetType())
97 {
98 case LayerType::Activation:
99 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100100 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100101 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100102 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000103 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100104 OverrideDataType(input, dataType),
105 OverrideDataType(output, dataType),
106 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100107 reason);
telsoa014fcda012018-03-09 14:13:49 +0000108 break;
109 }
110 case LayerType::Addition:
111 {
Mike Kelly3ec30772023-03-08 13:47:17 +0000112 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100113 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
114 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000115 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000116 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100117 OverrideDataType(input0, dataType),
118 OverrideDataType(input1, dataType),
119 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100120 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +0000121 ARMNN_NO_DEPRECATE_WARN_END
telsoa014fcda012018-03-09 14:13:49 +0000122 break;
123 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100124 case LayerType::ArgMinMax:
125 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100126 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100127 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
128
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100129 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nikhil Rajee391d52019-09-05 17:50:44 +0100130 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000131 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100132 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000133 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100134 descriptor,
135 reason);
136 break;
137 }
Samuel Yap6b478092022-07-06 15:36:03 +0100138 case LayerType::BatchMatMul:
139 {
140 auto cLayer = PolymorphicDowncast<const BatchMatMulLayer*>(&layer);
141 const BatchMatMulDescriptor& descriptor = cLayer->GetParameters();
142
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100143 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
144 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Samuel Yap6b478092022-07-06 15:36:03 +0100145 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
146 result = layerSupportObject.IsBatchMatMulSupported(
147 OverrideDataType(input0, dataType),
148 OverrideDataType(input1, dataType),
149 OverrideDataType(output, dataType),
150 descriptor,
151 reason);
152 break;
153 }
telsoa014fcda012018-03-09 14:13:49 +0000154 case LayerType::BatchNormalization:
155 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100156 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100157 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
159 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
160 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
161 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
162 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000163 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100164 OverrideDataType(input, dataType),
165 OverrideDataType(output, dataType),
166 OverrideDataType(mean, dataType),
167 OverrideDataType(var, dataType),
168 OverrideDataType(beta, dataType),
169 OverrideDataType(gamma, dataType),
170 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100171 reason);
telsoa014fcda012018-03-09 14:13:49 +0000172 break;
173 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000174 case LayerType::BatchToSpaceNd:
175 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100176 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000177 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100178 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000179
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000180 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
181 OverrideDataType(output, dataType),
182 cLayer->GetParameters(),
183 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000184 break;
185 }
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100186 case LayerType::BroadcastTo:
187 {
188 auto cLayer = PolymorphicDowncast<const BroadcastToLayer*>(&layer);
189 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
190 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
191
192 result = layerSupportObject.IsBroadcastToSupported(OverrideDataType(input, dataType),
193 OverrideDataType(output, dataType),
194 cLayer->GetParameters(),
195 reason);
196 break;
197 }
mathad01b392e982021-04-07 12:07:30 +0100198 case LayerType::Cast:
199 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100200 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
mathad01b392e982021-04-07 12:07:30 +0100201 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
202
203 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
204 OverrideDataType(output, dataType),
205 reason);
206 break;
207 }
Simon Obute51f67772021-09-03 15:50:13 +0100208 case LayerType::ChannelShuffle:
209 {
210 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
211
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100212 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
213 const TensorInfo& output = layer.GetInputSlot(0).GetTensorInfo();
Simon Obute51f67772021-09-03 15:50:13 +0100214
215 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
216
217 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
218 OverrideDataType(output, dataType),
219 descriptor,
220 reason);
221 break;
222 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100223 case LayerType::Comparison:
224 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100225 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100226
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100227 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
228 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100229 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
230
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000231 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
232 OverrideDataType(input1, dataType),
233 OverrideDataType(output, DataType::Boolean),
234 cLayer->GetParameters(),
235 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100236 break;
237 }
telsoa014fcda012018-03-09 14:13:49 +0000238 case LayerType::Constant:
239 {
240 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000241 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100242 break;
243 }
244 case LayerType::ConvertFp16ToFp32:
245 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100246 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100247 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000248 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100249 break;
250 }
251 case LayerType::ConvertFp32ToFp16:
252 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100253 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100254 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000255 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000256 break;
257 }
258 case LayerType::Convolution2d:
259 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100260 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100261
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100262 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
arovir01a6824102018-08-28 17:40:45 +0100263 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100264 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Colm Donelanb4ef1632024-02-01 15:00:43 +0000265
266 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(layer.GetInputSlot(1).GetConnection(),
267 "Convolution2dLayer: Weights should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100268 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100269 dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100270
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100271 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100272
arovir01a6824102018-08-28 17:40:45 +0100273 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100274 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100275 if (descriptor.m_BiasEnabled)
276 {
Colm Donelanb4ef1632024-02-01 15:00:43 +0000277 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(layer.GetInputSlot(2).GetConnection(),
278 "Convolution2dLayer:Bias should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100279 biases = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100280 GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100281 }
282
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000283 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100284 input,
285 output,
286 descriptor,
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100287 weights,
arovir01a6824102018-08-28 17:40:45 +0100288 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100289 reason);
telsoa014fcda012018-03-09 14:13:49 +0000290 break;
291 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100292 case LayerType::Convolution3d:
293 {
294 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
295
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100296 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100297 dataType);
298 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100299
Colm Donelanb4ef1632024-02-01 15:00:43 +0000300 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(layer.GetInputSlot(1).GetConnection(),
301 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100302 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100303 dataType);
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100304
305 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
306
307 // Construct optional biases object based on the value of m_BiasEnabled
308 Optional<TensorInfo> biases;
309 if (descriptor.m_BiasEnabled)
310 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100311 biases = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100312 GetBiasTypeFromWeightsType(dataType));
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100313 }
314
315 result = layerSupportObject.IsConvolution3dSupported(
316 input,
317 output,
318 descriptor,
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100319 weights,
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100320 biases,
321 reason);
322 break;
323 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000324 case LayerType::Debug:
325 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100326 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000327 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
328
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000329 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000330 OverrideDataType(output, dataType),
331 reason);
332 break;
333 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100334 case LayerType::DepthToSpace:
335 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100336 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100337
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100338 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100339 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
340
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000341 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100342 OverrideDataType(output, dataType),
343 cLayer->GetParameters(),
344 reason);
345 break;
346 }
telsoa014fcda012018-03-09 14:13:49 +0000347 case LayerType::DepthwiseConvolution2d:
348 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100349 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100350 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100351 dataType);
352 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100353 const TensorInfo& weights = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100354 dataType);
355
telsoa01c577f2c2018-08-31 09:22:23 +0100356 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100357
358 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100359 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100360 if (descriptor.m_BiasEnabled)
361 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100362 biases = OverrideDataType(cLayer->GetInputSlot(2).GetTensorInfo(),
Cathal Corbett06902652022-04-14 17:55:11 +0100363 GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100364 }
telsoa01c577f2c2018-08-31 09:22:23 +0100365
Cathal Corbett06902652022-04-14 17:55:11 +0100366 result = layerSupportObject.IsDepthwiseConvolutionSupported(input,
367 output,
368 descriptor,
369 weights,
370 biases,
371 reason);
telsoa014fcda012018-03-09 14:13:49 +0000372 break;
373 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000374 case LayerType::Dequantize:
375 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100376 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000377 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
378
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000379 result = layerSupportObject.IsDequantizeSupported(input,
380 OverrideDataType(output, dataType),
381 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000382 break;
383 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000384 case LayerType::DetectionPostProcess:
385 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100386 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100387 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetTensorInfo();
388 const TensorInfo& scores = layer.GetInputSlot(1).GetTensorInfo();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000389 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
390
391 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
392 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
393 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
394 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
395
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000396 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000397 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
398 scores,
399 anchors,
400 detectionBoxes,
401 detectionClasses,
402 detectionScores,
403 numDetections,
404 descriptor,
405 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000406 break;
407 }
Mike Kelly3ec30772023-03-08 13:47:17 +0000408 case LayerType::ElementwiseBinary:
409 {
410 auto cLayer = PolymorphicDowncast<const ElementwiseBinaryLayer*>(&layer);
411
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100412 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
413 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Mike Kelly3ec30772023-03-08 13:47:17 +0000414 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
415 std::vector<TensorInfo> infos = { OverrideDataType(input0, dataType),
416 OverrideDataType(input1, dataType),
417 OverrideDataType(output, dataType) };
418 result = layerSupport->IsLayerSupported(LayerType::ElementwiseBinary,
419 infos,
420 cLayer->GetParameters(),
421 EmptyOptional(),
422 EmptyOptional(),
423 reason);
424 break;
425 }
josh minor4a3c6102020-01-06 16:40:46 -0600426 case LayerType::ElementwiseUnary:
427 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100428 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600429
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100430 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
josh minor4a3c6102020-01-06 16:40:46 -0600431 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
432
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000433 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
434 OverrideDataType(output, dataType),
435 cLayer->GetParameters(),
436 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600437 break;
438 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100439 case LayerType::Fill:
440 {
441 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100442 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Ryan OSheaec6c6802020-06-05 17:17:06 +0100443 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
444 const FillDescriptor& descriptor = cLayer->GetParameters();
445
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000446 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100447 OverrideDataType(input, dataType),
448 OverrideDataType(output, dataType),
449 descriptor,
450 reason);
451 break;
452 }
telsoa014fcda012018-03-09 14:13:49 +0000453 case LayerType::FakeQuantization:
454 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100455 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100456 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000457 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
458 cLayer->GetParameters(),
459 reason);
telsoa014fcda012018-03-09 14:13:49 +0000460 break;
461 }
462 case LayerType::Floor:
463 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100464 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000465 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000466 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
467 OverrideDataType(output, dataType),
468 reason);
telsoa014fcda012018-03-09 14:13:49 +0000469 break;
470 }
471 case LayerType::FullyConnected:
472 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100473 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100474 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100475 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000476
477 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
478 TensorInfo weightsInfo;
479 const TensorInfo* weightsInfoPtr = nullptr;
480
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100481 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000482 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100483
484 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000485 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000486 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100487 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
488 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
489 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
490
telsoa01c577f2c2018-08-31 09:22:23 +0100491 if (descriptor.m_BiasEnabled)
492 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100493 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(), dataType);
Matthew Sloyan81beae32021-07-13 19:46:11 +0100494 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100495 }
496 else
497 {
498 // If biases are not enabled pass a dummy tensorinfo for the validation
499 switch(input.GetDataType())
500 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000501 case DataType::BFloat16:
502 {
503 biasInfoPtr = &dummyBFloat16Bias;
504 break;
505 }
telsoa01c577f2c2018-08-31 09:22:23 +0100506 case DataType::Float16:
507 {
508 biasInfoPtr = &dummyFloat16Bias;
509 break;
510 }
511 case DataType::Float32:
512 {
513 biasInfoPtr = &dummyFloat32Bias;
514 break;
515 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000516 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000517 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000518 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000519 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100520 {
521 biasInfoPtr = &dummyQA8Bias;
522 break;
523 }
524 default:
525 {
Colm Donelanb4ef1632024-02-01 15:00:43 +0000526 throw InvalidArgumentException("Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100527 }
528 }
529 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000530 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100531 OverrideDataType(input, dataType),
532 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000533 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100534 *biasInfoPtr,
535 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100536 reason);
telsoa014fcda012018-03-09 14:13:49 +0000537 break;
538 }
Teresa Charlin9145e382023-08-17 18:44:58 +0100539 case LayerType::Fused:
540 {
541 auto cLayer = PolymorphicDowncast<const FusedLayer*>(&layer);
542
543 // Get vector of all outputs.
544 auto getOutTensorInfo = [&dataType](const OutputSlot& slot)
545 {
546 return OverrideDataType(slot.GetTensorInfo(), dataType);
547 };
548 auto beginOutputs = MakeTransformIterator(layer.GetOutputSlots().begin(), getOutTensorInfo);
549 auto endOutputs = MakeTransformIterator(layer.GetOutputSlots().end(), getOutTensorInfo);
550 std::vector<TensorInfo> outputs(beginOutputs, endOutputs);
551 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
552
553 // Get vector of all inputs.
554 auto getInputTensorInfo = [&dataType](const InputSlot& slot)
555 {
556 return OverrideDataType(slot.GetTensorInfo(), dataType);
557 };
558 auto beginInputs = MakeTransformIterator(layer.GetInputSlots().begin(), getInputTensorInfo);
559 auto endInputs = MakeTransformIterator(layer.GetInputSlots().end(), getInputTensorInfo);
560 std::vector<TensorInfo> inputs(beginInputs, endInputs);
561 const std::vector<std::reference_wrapper<TensorInfo>> inputPtrs(inputs.begin(), inputs.end());
562
563 result = layerSupportObject.IsFusedSupported(inputPtrs,
564 outputPtrs,
565 cLayer->GetParameters(),
566 reason);
567 break;
568 }
narpra01b89b05f2019-01-16 09:53:09 +0000569 case LayerType::Gather:
570 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100571 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
572 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
narpra01b89b05f2019-01-16 09:53:09 +0000573 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100574 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
575 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000576 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
577 input1,
578 OverrideDataType(output, dataType),
579 descriptor,
580 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000581 break;
582 }
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100583 case LayerType::GatherNd:
584 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +0100585 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
586 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100587 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
588 result = layerSupportObject.IsGatherNdSupported(OverrideDataType(input0, dataType),
589 input1,
590 OverrideDataType(output, dataType),
591 reason);
592 break;
593 }
telsoa014fcda012018-03-09 14:13:49 +0000594 case LayerType::Input:
595 {
596 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000597 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000598 break;
599 }
        case LayerType::InstanceNormalization:
        {
            auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
            const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsInstanceNormalizationSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                descriptor,
                reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsL2NormalizationSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                descriptor,
                reason);
            break;
        }
        case LayerType::LogicalBinary:
        {
            auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);

            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            // Note: no OverrideDataType() here — the logical tensors are passed through
            // with their original data type.
            result = layerSupportObject.IsLogicalBinarySupported(input0,
                                                                 input1,
                                                                 output,
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::LogSoftmax:
        {
            auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // paramsInfo stores raw pointers into the local TensorInfo objects above/below,
            // so all of them must stay alive until IsLstmSupported() returns.
            LstmInputParamsInfo paramsInfo;

            paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
            paramsInfo.m_InputToCellWeights = &inputToCellWeights;
            paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
            paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
            paramsInfo.m_ForgetGateBias = &forgetGateBias;
            paramsInfo.m_CellBias = &cellBias;
            paramsInfo.m_OutputGateBias = &outputGateBias;


            // Optional parameters
            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;
            TensorInfo optInputLayerNormWeights;
            TensorInfo optForgetLayerNormWeights;
            TensorInfo optCellLayerNormWeights;
            TensorInfo optOutputLayerNormWeights;

            // CIFG disabled => the input-gate parameters are present and must be checked.
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                paramsInfo.m_InputGateBias = &optInputGateBias;
            }

            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
                // Projection bias is optional even when projection itself is enabled.
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    paramsInfo.m_ProjectionBias = &optProjectionBias;
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights only exist when CIFG is disabled.
                if(!descriptor.m_CifgEnabled)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
                                         dataType);
                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
                }
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input-gate layer-norm weights likewise only exist without CIFG.
                if (!descriptor.m_CifgEnabled)
                {
                    optInputLayerNormWeights = OverrideDataType(
                            cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
                    paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
                }

                optForgetLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;

                optCellLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;

                optOutputLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
            }

            result = layerSupportObject.IsLstmSupported(
                    input,
                    outputStateIn,
                    cellStateIn,
                    scratchBuffer,
                    outputStateOut,
                    cellStateOut,
                    output,
                    descriptor,
                    paramsInfo,
                    reason);
            break;
        }
        case LayerType::Maximum:
        {
            // IsMaximumSupported is deprecated (hence the suppression brackets);
            // presumably superseded by the ElementwiseBinary path — see ILayerSupport.
            ARMNN_NO_DEPRECATE_WARN_BEGIN
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
                                                           OverrideDataType(input1, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            ARMNN_NO_DEPRECATE_WARN_END
            break;
        }
        case LayerType::MemCopy:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        case LayerType::MemImport:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             reason);
            break;
        }
        case LayerType::Merge:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
                                                         OverrideDataType(input1, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
            break;
        }
        case LayerType::Concat:
        {
            auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };

            // Materialise one TensorInfo per input slot (Concat has a variable input count).
            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };

            // inputPtrs holds pointers into 'inputs'; both vectors must stay alive
            // until IsConcatSupported() returns.
            auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);


            break;
        }
        case LayerType::Multiplication:
        {
            // IsMultiplicationSupported is deprecated (hence the suppression brackets);
            // presumably superseded by the ElementwiseBinary path — see ILayerSupport.
            ARMNN_NO_DEPRECATE_WARN_BEGIN
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsMultiplicationSupported(
                           OverrideDataType(input0, dataType),
                           OverrideDataType(input1, dataType),
                           OverrideDataType(output, dataType),
                           reason);
            ARMNN_NO_DEPRECATE_WARN_END
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Output:
        {
            // An Output layer has no output slots; its tensor info comes from its input slot.
            const TensorInfo& output = layer.GetInputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPadSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                cLayer->GetParameters(),
                reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::Pooling3d:
        {
            auto cLayer = PolymorphicDowncast<const Pooling3dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPooling3dSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::PreCompiled:
        {
            auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
            // Note: only the input tensor is validated here; no output info is queried.
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
                                                               cLayer->GetParameters(),
                                                               reason);
            break;
        }
        case LayerType::Quantize:
        {
            // No OverrideDataType() here: the quantize check uses the tensors' own types.
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsQuantizeSupported(input, output, reason);
            break;
        }
        case LayerType::QLstm:
        {
            auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
            const QLstmDescriptor& descriptor = cLayer->GetParameters();

            // Inputs. Note: unlike most cases in this switch, tensor infos are used
            // as-is here (no OverrideDataType()).
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetTensorInfo();

            // Outputs
            const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();

            // Lstm parameters (pointers into the layer's own parameter tensors).
            LstmInputParamsInfo paramsInfo;

            // Basic parameters
            paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToForgetWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                    &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();

            // CIFG disabled => input-gate parameters are present.
            if(!descriptor.m_CifgEnabled)
            {
                paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
                paramsInfo.m_RecurrentToInputWeights =
                        &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
                paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
            }

            if(descriptor.m_ProjectionEnabled)
            {
                paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();

                // Projection bias is optional even if projection is enabled
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights only exist when CIFG is disabled.
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_CellToInputWeights =
                            &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
                }

                paramsInfo.m_CellToForgetWeights =
                        &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
                paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input layer-norm weights likewise only exist without CIFG.
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_InputLayerNormWeights =
                            &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
                }

                paramsInfo.m_ForgetLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
                paramsInfo.m_CellLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
                paramsInfo.m_OutputLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
            }

            result = layerSupportObject.IsQLstmSupported(input,
                                                         previousOutputIn,
                                                         previousCellStateIn,
                                                         outputStateOut,
                                                         cellStateOut,
                                                         output,
                                                         descriptor,
                                                         paramsInfo,
                                                         reason);
            break;
        }
James Conroyee18dc82019-07-17 11:27:46 +01001063 case LayerType::QuantizedLstm:
1064 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001065 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +01001066
1067 // Inputs
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001068 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1069 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetTensorInfo();
1070 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001071
1072 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001073 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
1074 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001075
1076 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +01001077 QuantizedLstmInputParamsInfo paramsInfo;
1078
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001079 paramsInfo.m_InputToInputWeights =
1080 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
1081 paramsInfo.m_InputToForgetWeights =
1082 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
1083 paramsInfo.m_InputToCellWeights =
1084 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
1085 paramsInfo.m_InputToOutputWeights =
1086 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001087
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001088 paramsInfo.m_RecurrentToInputWeights =
1089 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
1090 paramsInfo.m_RecurrentToForgetWeights =
1091 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
1092 paramsInfo.m_RecurrentToCellWeights =
1093 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
1094 paramsInfo.m_RecurrentToOutputWeights =
1095 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001096
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001097 paramsInfo.m_InputGateBias =
1098 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
1099 paramsInfo.m_ForgetGateBias =
1100 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
1101 paramsInfo.m_CellBias =
1102 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
1103 paramsInfo.m_OutputGateBias =
1104 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +01001105
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001106 result = layerSupportObject.IsQuantizedLstmSupported(input,
1107 previousCellStateIn,
1108 previousOutputIn,
1109 cellStateOut,
1110 output,
1111 paramsInfo,
1112 reason);
James Conroyee18dc82019-07-17 11:27:46 +01001113 break;
1114 }
        case LayerType::Division:
        {
            // IsDivisionSupported is deprecated (hence the suppression brackets);
            // presumably superseded by the ElementwiseBinary path — see ILayerSupport.
            ARMNN_NO_DEPRECATE_WARN_BEGIN
            const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsDivisionSupported(
                OverrideDataType(input0, dataType),
                OverrideDataType(input1, dataType),
                OverrideDataType(output, dataType),
                reason);
            ARMNN_NO_DEPRECATE_WARN_END
            break;
        }
        case LayerType::Rank:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
                                                        OverrideDataType(output, dataType),
                                                        reason);
            break;
        }
        case LayerType::Reshape:
        {
            auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        case LayerType::Resize:
        {
            auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          cLayer->GetParameters(),
                                                          reason);
            break;
        }
Tianle Cheng988354d2023-06-28 13:20:47 +01001160 case LayerType::ReverseV2:
1161 {
Tracy Narinebb8d7592023-07-13 16:50:54 +01001162 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1163 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
Tianle Cheng988354d2023-06-28 13:20:47 +01001164 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Tracy Narinebb8d7592023-07-13 16:50:54 +01001165 result = layerSupportObject.IsReverseV2Supported(OverrideDataType(input0, dataType),
1166 OverrideDataType(input1, armnn::DataType::Signed32),
Tianle Cheng988354d2023-06-28 13:20:47 +01001167 OverrideDataType(output, dataType),
Tianle Cheng988354d2023-06-28 13:20:47 +01001168 reason);
1169 break;
1170 }
Keith Davis3ae3f972021-05-21 16:33:48 +01001171 case LayerType::Shape:
1172 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001173 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Keith Davis3ae3f972021-05-21 16:33:48 +01001174 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1175
1176 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1177 OverrideDataType(output, dataType),
1178 reason);
1179 break;
1180 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001181 case LayerType::Slice:
1182 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001183 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001184
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001185 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001186 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1187
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001188 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1189 OverrideDataType(output, dataType),
1190 cLayer->GetParameters(),
1191 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001192 break;
1193 }
telsoa014fcda012018-03-09 14:13:49 +00001194 case LayerType::Softmax:
1195 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001196 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001197 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001198 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001199 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1200 OverrideDataType(output, dataType),
1201 cLayer->GetParameters(),
1202 reason);
telsoa014fcda012018-03-09 14:13:49 +00001203 break;
1204 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001205 case LayerType::SpaceToBatchNd:
1206 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001207 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001208 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001209 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001210 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1211 OverrideDataType(output, dataType),
1212 cLayer->GetParameters(),
1213 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001214 break;
1215 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001216 case LayerType::SpaceToDepth:
1217 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001218 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001219
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001220 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001221 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1222
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001223 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1224 OverrideDataType(output, dataType),
1225 cLayer->GetParameters(),
1226 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001227 break;
1228 }
telsoa014fcda012018-03-09 14:13:49 +00001229 case LayerType::Splitter:
1230 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001231 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001232 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001233
1234 // Get vector of all outputs.
1235 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1236 {
1237 return OverrideDataType(slot.GetTensorInfo(), dataType);
1238 };
Finn Williams3e54d032020-10-22 16:53:35 +01001239 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1240 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001241 std::vector<TensorInfo> outputs(beginI, endI);
1242
1243 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1244
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001245 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1246 outputPtrs,
1247 cLayer->GetParameters(),
1248 reason);
telsoa014fcda012018-03-09 14:13:49 +00001249 break;
1250 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001251 case LayerType::Stack:
1252 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001253 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001254
1255 // Get vector of all inputs.
1256 auto getTensorInfo = [&dataType](const InputSlot& slot)
1257 {
1258 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1259 };
Finn Williams3e54d032020-10-22 16:53:35 +01001260 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1261 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001262 std::vector<TensorInfo> inputs(beginI, endI);
1263
1264 auto getTensorInfoPtr = [](const TensorInfo& info)
1265 {
1266 return &info;
1267 };
Finn Williams3e54d032020-10-22 16:53:35 +01001268 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1269 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001270 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1271
1272 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1273
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001274 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001275
1276 break;
1277 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001278 case LayerType::StandIn:
1279 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001280 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001281
1282 // Get vector of all inputs.
1283 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1284 {
1285 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1286 };
1287 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1288 {
1289 return OverrideDataType(slot.GetTensorInfo(), dataType);
1290 };
Finn Williams3e54d032020-10-22 16:53:35 +01001291 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1292 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001293 std::vector<TensorInfo> inputs(beginI, endI);
1294
Finn Williams3e54d032020-10-22 16:53:35 +01001295 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1296 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001297 std::vector<TensorInfo> outputs(beginO, endO);
1298
1299
1300 auto getTensorInfoPtr = [](const TensorInfo& info)
1301 {
1302 return &info;
1303 };
Finn Williams3e54d032020-10-22 16:53:35 +01001304 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1305 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001306 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1307
Finn Williams3e54d032020-10-22 16:53:35 +01001308 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1309 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001310 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1311
1312
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001313 result = layerSupportObject.IsStandInSupported(inputPtrs,
1314 outputPtrs,
1315 cLayer->GetParameters(),
1316 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001317 break;
1318 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001319 case LayerType::StridedSlice:
1320 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001321 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001322 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Conor Kennedy430b5d82018-11-14 15:28:28 +00001323 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001324 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1325 OverrideDataType(output, dataType),
1326 cLayer->GetParameters(),
1327 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001328 break;
1329 }
David Beckc2044fe2018-09-05 15:00:38 +01001330 case LayerType::Subtraction:
1331 {
Mike Kelly3ec30772023-03-08 13:47:17 +00001332 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001333 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1334 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
David Beckc2044fe2018-09-05 15:00:38 +01001335 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001336 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001337 OverrideDataType(input0, dataType),
1338 OverrideDataType(input1, dataType),
1339 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001340 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +00001341 ARMNN_NO_DEPRECATE_WARN_END
David Beckc2044fe2018-09-05 15:00:38 +01001342 break;
1343 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001344 case LayerType::Switch:
1345 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001346 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1347 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
Sadik Armaganeff363d2019-04-05 15:25:46 +01001348 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1349 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001350 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1351 OverrideDataType(input1, dataType),
1352 OverrideDataType(output0, dataType),
1353 OverrideDataType(output1, dataType),
1354 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001355 break;
1356 }
narpra0132b90462018-09-13 11:07:48 +01001357 case LayerType::Mean:
1358 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001359 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001360 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
narpra0132b90462018-09-13 11:07:48 +01001361 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001362 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001363 OverrideDataType(input, dataType),
1364 OverrideDataType(output, dataType),
1365 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001366 reason);
narpra0132b90462018-09-13 11:07:48 +01001367 break;
1368 }
kevmay0190539692018-11-29 08:40:19 +00001369 case LayerType::Minimum:
1370 {
Mike Kelly3ec30772023-03-08 13:47:17 +00001371 ARMNN_NO_DEPRECATE_WARN_BEGIN
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001372 const TensorInfo& input0 = layer.GetInputSlot(0).GetTensorInfo();
1373 const TensorInfo& input1 = layer.GetInputSlot(1).GetTensorInfo();
kevmay0190539692018-11-29 08:40:19 +00001374 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001375 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1376 OverrideDataType(input1, dataType),
1377 OverrideDataType(output, dataType),
1378 reason);
Mike Kelly3ec30772023-03-08 13:47:17 +00001379 ARMNN_NO_DEPRECATE_WARN_END
kevmay0190539692018-11-29 08:40:19 +00001380 break;
1381 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001382 case LayerType::Prelu:
1383 {
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001384 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1385 const TensorInfo& alpha = layer.GetInputSlot(1).GetTensorInfo();
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001386 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001387 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1388 OverrideDataType(alpha, dataType),
1389 OverrideDataType(output, dataType),
1390 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001391 break;
1392 }
Teresa Charlin79a06a52023-07-13 17:16:45 +01001393 case LayerType::Tile:
1394 {
1395 auto cLayer = PolymorphicDowncast<const TileLayer*>(&layer);
1396 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
1397 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1398
1399 result = layerSupportObject.IsTileSupported(OverrideDataType(input, dataType),
1400 OverrideDataType(output, dataType),
1401 cLayer->GetParameters(),
1402 reason);
1403
1404 break;
1405 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001406 case LayerType::Transpose:
1407 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001408 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001409 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001410 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001411 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1412 OverrideDataType(output, dataType),
1413 cLayer->GetParameters(),
1414 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001415 break;
1416 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001417 case LayerType::TransposeConvolution2d:
1418 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001419 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001420
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001421 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001422 dataType);
1423 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1424
1425 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1426
1427 Optional<TensorInfo> biases;
1428 if (descriptor.m_BiasEnabled)
1429 {
Colm Donelanb4ef1632024-02-01 15:00:43 +00001430 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(
1431 cLayer->m_Bias.get() != nullptr,
1432 "TransposeConvolution2d: Bias was enabled in the descriptor but no value was supplied.");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001433 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1434 GetBiasTypeFromWeightsType(dataType));
1435 }
1436
Colm Donelanb4ef1632024-02-01 15:00:43 +00001437 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(cLayer->m_Weight.get() != nullptr,
1438 "TransposeConvolution2d: Weights cannot be null.");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001439 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1440
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001441 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1442 output,
1443 descriptor,
1444 weights,
1445 biases,
1446 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001447
1448 break;
1449 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001450 case LayerType::Reduce:
1451 {
1452 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001453 const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001454 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1455
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001456 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1457 OverrideDataType(output, dataType),
1458 cLayer->GetParameters(),
1459 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001460 break;
1461 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001462 case LayerType::UnidirectionalSequenceLstm:
1463 {
1464 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1465 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1466
1467 // All inputs.
Mike Kellya9ac6ba2023-06-30 15:18:26 +01001468 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001469 dataType);
Mike Kelly4cc341c2023-07-07 15:43:06 +01001470 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001471 dataType);
Mike Kelly4cc341c2023-07-07 15:43:06 +01001472 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001473 dataType);
1474 // Outputs
Mike Kelly12994962022-04-21 11:57:09 +01001475 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1476 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
1477 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001478
1479 // Basic parameters
1480 const TensorInfo& inputToForgetWeights
1481 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1482 const TensorInfo& inputToCellWeights
1483 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1484 const TensorInfo& inputToOutputWeights
1485 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1486 const TensorInfo& recurrentToForgetWeights
1487 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1488 const TensorInfo& recurrentToCellWeights
1489 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1490 const TensorInfo& recurrentToOutputWeights
1491 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1492 const TensorInfo& forgetGateBias
1493 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1494 const TensorInfo& cellBias
1495 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1496 const TensorInfo& outputGateBias
1497 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1498
1499 LstmInputParamsInfo paramsInfo;
1500
1501 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1502 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1503 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1504 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1505 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1506 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1507 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1508 paramsInfo.m_CellBias = &cellBias;
1509 paramsInfo.m_OutputGateBias = &outputGateBias;
1510
1511 // Optional parameters
1512 TensorInfo optInputToInputWeights;
1513 TensorInfo optRecurrentToInputWeights;
1514 TensorInfo optCellToInputWeights;
1515 TensorInfo optInputGateBias;
1516 TensorInfo optProjectionWeights;
1517 TensorInfo optProjectionBias;
1518 TensorInfo optCellToForgetWeights;
1519 TensorInfo optCellToOutputWeights;
1520 TensorInfo optInputLayerNormWeights;
1521 TensorInfo optForgetLayerNormWeights;
1522 TensorInfo optCellLayerNormWeights;
1523 TensorInfo optOutputLayerNormWeights;
1524
1525 if(!descriptor.m_CifgEnabled)
1526 {
1527 optInputToInputWeights =
1528 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1529 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1530
1531 optRecurrentToInputWeights =
1532 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1533 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1534 optInputGateBias =
1535 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1536 paramsInfo.m_InputGateBias = &optInputGateBias;
1537 }
1538
1539 if(descriptor.m_ProjectionEnabled)
1540 {
1541 optProjectionWeights =
1542 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1543 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1544 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1545 {
1546 optProjectionBias =
1547 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1548 paramsInfo.m_ProjectionBias = &optProjectionBias;
1549 }
1550 }
1551
1552 if(descriptor.m_PeepholeEnabled)
1553 {
1554 if(!descriptor.m_CifgEnabled)
1555 {
1556 optCellToInputWeights =
1557 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1558 dataType);
1559 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1560 }
1561 optCellToForgetWeights =
1562 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1563 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1564 optCellToOutputWeights =
1565 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1566 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1567 }
1568
1569 if(descriptor.m_LayerNormEnabled)
1570 {
1571 if (!descriptor.m_CifgEnabled)
1572 {
1573 optInputLayerNormWeights = OverrideDataType(
1574 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1575 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1576 }
1577
1578 optForgetLayerNormWeights = OverrideDataType(
1579 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1580 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1581
1582 optCellLayerNormWeights = OverrideDataType(
1583 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1584 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1585
1586 optOutputLayerNormWeights = OverrideDataType(
1587 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1588 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1589 }
1590
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001591 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1592 outputStateIn,
1593 cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01001594 outputStateOut,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001595 cellStateOut,
Mike Kelly12994962022-04-21 11:57:09 +01001596 output,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001597 descriptor,
1598 paramsInfo,
1599 reason);
1600 break;
1601 }
telsoa014fcda012018-03-09 14:13:49 +00001602 default:
1603 {
David Beck33f0ae02018-10-18 15:13:56 +01001604 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001605 result = false;
1606 break;
1607 }
1608 }
telsoa014fcda012018-03-09 14:13:49 +00001609 return result;
1610}
1611
Sadik Armagan045f6be2020-09-10 13:37:32 +01001612bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1613 const IConnectableLayer& connectableLayer,
1614 Optional<DataType> dataType,
1615 std::string& outReasonIfUnsupported)
1616{
1617 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1618}
1619
David Beckdcb751f2018-10-03 11:42:42 +01001620bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001621 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001622 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001623{
Jan Eilersbb446e52020-04-02 13:56:54 +01001624 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001625 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1626}
1627
Sadik Armagan045f6be2020-09-10 13:37:32 +01001628bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1629 Optional<DataType> dataType,
1630 std::string& outReasonIfUnsupported,
1631 const ModelOptions& modelOptions)
1632{
1633 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1634 return IsLayerConfigurationSupported(layer->GetBackendId(),
1635 connectableLayer,
1636 dataType,
1637 outReasonIfUnsupported,
1638 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001639}
1640
Sadik Armagan04a72972020-09-14 15:44:18 +01001641bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1642 const IConnectableLayer& connectableLayer,
1643 Optional<DataType> dataType,
1644 std::string& outReasonIfUnsupported,
1645 const ModelOptions& modelOptions)
1646{
1647 return IsLayerConfigurationSupported(backendId,
1648 connectableLayer,
1649 dataType,
1650 outReasonIfUnsupported,
1651 modelOptions);
1652}
Cian McGriskin7894ef92023-08-01 14:04:09 +01001653
} // namespace armnn