blob: f624ee60218c3fa4e579641d20309bab661b65a9 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000010#include <armnn/backends/IBackendInternal.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Colm Donelan0c479742021-12-10 12:43:54 +000017#include <armnn/backends/WorkloadFactory.hpp>
18#include <armnn/backends/TensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
David Beck111b5d92018-11-12 14:59:37 +000020#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000021
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
telsoa01c577f2c2018-08-31 09:22:23 +010025namespace
26{
Finn Williams3e54d032020-10-22 16:53:35 +010027using LayerList = std::list<Layer*>;
28using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010029
David Beck29c75de2018-10-23 13:35:58 +010030const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
31{
32 if (!type)
33 {
34 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010035 }
36
Matthew Sloyan81beae32021-07-13 19:46:11 +010037 return TensorInfo(info.GetShape(),
38 type.value(),
39 info.GetQuantizationScale(),
40 info.GetQuantizationOffset(),
41 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
Sadik Armagana097d2a2021-11-24 15:47:28 +000046inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
47{
48 if (!weightsType)
49 {
50 return weightsType;
51 }
52
53 switch(weightsType.value())
54 {
55 case armnn::DataType::BFloat16:
56 case armnn::DataType::Float16:
57 case armnn::DataType::Float32:
58 return weightsType;
59 case armnn::DataType::QAsymmS8:
60 case armnn::DataType::QAsymmU8:
61 case armnn::DataType::QSymmS8:
62 case armnn::DataType::QSymmS16:
63 return armnn::DataType::Signed32;
64 default:
65 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
66 }
67 return armnn::EmptyOptional();
68}
69
70
Sadik Armagan045f6be2020-09-10 13:37:32 +010071bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
72 const IConnectableLayer& connectableLayer,
73 Optional<DataType> dataType,
74 std::string& outReasonIfUnsupported,
75 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000076{
David Beck33f0ae02018-10-18 15:13:56 +010077 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000078 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010079 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010080
David Beck111b5d92018-11-12 14:59:37 +000081 auto const& backendRegistry = BackendRegistryInstance();
82 if (!backendRegistry.IsBackendRegistered(backendId))
83 {
84 std::stringstream ss;
85 ss << connectableLayer.GetName() << " is not supported on " << backendId
86 << " because this backend is not registered.";
87
88 outReasonIfUnsupported = ss.str();
89 return false;
90 }
91
92 auto backendFactory = backendRegistry.GetFactory(backendId);
93 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000094 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010095
telsoa014fcda012018-03-09 14:13:49 +000096 switch(layer.GetType())
97 {
98 case LayerType::Activation:
99 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100100 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000101 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100102 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000103 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100104 OverrideDataType(input, dataType),
105 OverrideDataType(output, dataType),
106 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100107 reason);
telsoa014fcda012018-03-09 14:13:49 +0000108 break;
109 }
110 case LayerType::Addition:
111 {
112 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
113 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
114 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000115 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100116 OverrideDataType(input0, dataType),
117 OverrideDataType(input1, dataType),
118 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100119 reason);
telsoa014fcda012018-03-09 14:13:49 +0000120 break;
121 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100122 case LayerType::ArgMinMax:
123 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100124 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100125 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
126
127 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
128 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000129 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100130 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000131 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100132 descriptor,
133 reason);
134 break;
135 }
telsoa014fcda012018-03-09 14:13:49 +0000136 case LayerType::BatchNormalization:
137 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100138 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000139 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100140 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
141 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
142 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
143 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
144 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000145 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100146 OverrideDataType(input, dataType),
147 OverrideDataType(output, dataType),
148 OverrideDataType(mean, dataType),
149 OverrideDataType(var, dataType),
150 OverrideDataType(beta, dataType),
151 OverrideDataType(gamma, dataType),
152 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100153 reason);
telsoa014fcda012018-03-09 14:13:49 +0000154 break;
155 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000156 case LayerType::BatchToSpaceNd:
157 {
158 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
159 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100160 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000161
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000162 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
163 OverrideDataType(output, dataType),
164 cLayer->GetParameters(),
165 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000166 break;
167 }
mathad01b392e982021-04-07 12:07:30 +0100168 case LayerType::Cast:
169 {
170 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
171 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
172
173 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
174 OverrideDataType(output, dataType),
175 reason);
176 break;
177 }
Simon Obute51f67772021-09-03 15:50:13 +0100178 case LayerType::ChannelShuffle:
179 {
180 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
181
182 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
183 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
184
185 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
186
187 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
188 OverrideDataType(output, dataType),
189 descriptor,
190 reason);
191 break;
192 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100193 case LayerType::Comparison:
194 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100195 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100196
197 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
198 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
199 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
200
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000201 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
202 OverrideDataType(input1, dataType),
203 OverrideDataType(output, DataType::Boolean),
204 cLayer->GetParameters(),
205 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100206 break;
207 }
telsoa014fcda012018-03-09 14:13:49 +0000208 case LayerType::Constant:
209 {
210 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000211 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100212 break;
213 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000214 case LayerType::ConvertBf16ToFp32:
215 {
216 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
217 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000218 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000219 break;
220 }
telsoa01c577f2c2018-08-31 09:22:23 +0100221 case LayerType::ConvertFp16ToFp32:
222 {
223 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
224 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000225 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100226 break;
227 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000228 case LayerType::ConvertFp32ToBf16:
229 {
230 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
231 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000232 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000233 break;
234 }
telsoa01c577f2c2018-08-31 09:22:23 +0100235 case LayerType::ConvertFp32ToFp16:
236 {
237 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
238 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000239 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000240 break;
241 }
242 case LayerType::Convolution2d:
243 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100244 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100245
246 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
247 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100248 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100249 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100250
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100251 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100252
arovir01a6824102018-08-28 17:40:45 +0100253 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100254 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100255 if (descriptor.m_BiasEnabled)
256 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100257 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100258 }
259
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000260 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100261 input,
262 output,
263 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100264 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100265 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100266 reason);
telsoa014fcda012018-03-09 14:13:49 +0000267 break;
268 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100269 case LayerType::Convolution3d:
270 {
271 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
272
273 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
274 dataType);
275 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100276
277 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
278 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
279 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
280 dataType);
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100281
282 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
283
284 // Construct optional biases object based on the value of m_BiasEnabled
285 Optional<TensorInfo> biases;
286 if (descriptor.m_BiasEnabled)
287 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100288 biases = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
289 GetBiasTypeFromWeightsType(dataType));
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100290 }
291
292 result = layerSupportObject.IsConvolution3dSupported(
293 input,
294 output,
295 descriptor,
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100296 weights,
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100297 biases,
298 reason);
299 break;
300 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000301 case LayerType::Debug:
302 {
303 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
304 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
305
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000306 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000307 OverrideDataType(output, dataType),
308 reason);
309 break;
310 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100311 case LayerType::DepthToSpace:
312 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100313 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100314
315 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
316 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
317
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000318 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100319 OverrideDataType(output, dataType),
320 cLayer->GetParameters(),
321 reason);
322 break;
323 }
telsoa014fcda012018-03-09 14:13:49 +0000324 case LayerType::DepthwiseConvolution2d:
325 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100326 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
Cathal Corbett06902652022-04-14 17:55:11 +0100327 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
328 dataType);
329 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
330 const TensorInfo& weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
331 dataType);
332
333 ARMNN_ASSERT(cLayer->GetInputSlot(1).GetConnection() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100334
telsoa01c577f2c2018-08-31 09:22:23 +0100335 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100336
337 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100338 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100339 if (descriptor.m_BiasEnabled)
340 {
Cathal Corbett06902652022-04-14 17:55:11 +0100341 biases = OverrideDataType(cLayer->GetInputSlot(2).GetConnection()->GetTensorInfo(),
342 GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100343 }
telsoa01c577f2c2018-08-31 09:22:23 +0100344
Cathal Corbett06902652022-04-14 17:55:11 +0100345 result = layerSupportObject.IsDepthwiseConvolutionSupported(input,
346 output,
347 descriptor,
348 weights,
349 biases,
350 reason);
telsoa014fcda012018-03-09 14:13:49 +0000351 break;
352 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000353 case LayerType::Dequantize:
354 {
355 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
356 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
357
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000358 result = layerSupportObject.IsDequantizeSupported(input,
359 OverrideDataType(output, dataType),
360 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000361 break;
362 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000363 case LayerType::DetectionPostProcess:
364 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100365 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000366 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
367 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
368 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
369
370 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
371 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
372 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
373 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
374
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000375 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000376 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
377 scores,
378 anchors,
379 detectionBoxes,
380 detectionClasses,
381 detectionScores,
382 numDetections,
383 descriptor,
384 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000385 break;
386 }
josh minor4a3c6102020-01-06 16:40:46 -0600387 case LayerType::ElementwiseUnary:
388 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100389 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600390
391 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
392 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
393
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000394 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
395 OverrideDataType(output, dataType),
396 cLayer->GetParameters(),
397 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600398 break;
399 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100400 case LayerType::Fill:
401 {
402 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
403 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
404 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
405 const FillDescriptor& descriptor = cLayer->GetParameters();
406
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000407 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100408 OverrideDataType(input, dataType),
409 OverrideDataType(output, dataType),
410 descriptor,
411 reason);
412 break;
413 }
telsoa014fcda012018-03-09 14:13:49 +0000414 case LayerType::FakeQuantization:
415 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100416 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000417 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000418 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
419 cLayer->GetParameters(),
420 reason);
telsoa014fcda012018-03-09 14:13:49 +0000421 break;
422 }
423 case LayerType::Floor:
424 {
425 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
426 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000427 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
428 OverrideDataType(output, dataType),
429 reason);
telsoa014fcda012018-03-09 14:13:49 +0000430 break;
431 }
432 case LayerType::FullyConnected:
433 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100434 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000435 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100436 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000437
438 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
439 TensorInfo weightsInfo;
440 const TensorInfo* weightsInfoPtr = nullptr;
441
Matthew Sloyan81beae32021-07-13 19:46:11 +0100442 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000443 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100444
445 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000446 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000447 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100448 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
449 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
450 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
451
telsoa01c577f2c2018-08-31 09:22:23 +0100452 if (descriptor.m_BiasEnabled)
453 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100454 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
455 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100456 }
457 else
458 {
459 // If biases are not enabled pass a dummy tensorinfo for the validation
460 switch(input.GetDataType())
461 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000462 case DataType::BFloat16:
463 {
464 biasInfoPtr = &dummyBFloat16Bias;
465 break;
466 }
telsoa01c577f2c2018-08-31 09:22:23 +0100467 case DataType::Float16:
468 {
469 biasInfoPtr = &dummyFloat16Bias;
470 break;
471 }
472 case DataType::Float32:
473 {
474 biasInfoPtr = &dummyFloat32Bias;
475 break;
476 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000477 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000478 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000479 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000480 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100481 {
482 biasInfoPtr = &dummyQA8Bias;
483 break;
484 }
485 default:
486 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100487 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100488 }
489 }
490 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000491 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100492 OverrideDataType(input, dataType),
493 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000494 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100495 *biasInfoPtr,
496 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100497 reason);
telsoa014fcda012018-03-09 14:13:49 +0000498 break;
499 }
narpra01b89b05f2019-01-16 09:53:09 +0000500 case LayerType::Gather:
501 {
502 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
503 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
504 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100505 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
506 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000507 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
508 input1,
509 OverrideDataType(output, dataType),
510 descriptor,
511 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000512 break;
513 }
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100514 case LayerType::GatherNd:
515 {
516 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
517 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
518 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
519 result = layerSupportObject.IsGatherNdSupported(OverrideDataType(input0, dataType),
520 input1,
521 OverrideDataType(output, dataType),
522 reason);
523 break;
524 }
telsoa014fcda012018-03-09 14:13:49 +0000525 case LayerType::Input:
526 {
527 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000528 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000529 break;
530 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100531 case LayerType::InstanceNormalization:
532 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100533 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100534 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
535
536 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
537 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
538
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000539 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100540 OverrideDataType(input, dataType),
541 OverrideDataType(output, dataType),
542 descriptor,
543 reason);
544 break;
545 }
telsoa014fcda012018-03-09 14:13:49 +0000546 case LayerType::L2Normalization:
547 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100548 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100549 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
550
telsoa014fcda012018-03-09 14:13:49 +0000551 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100552 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100553
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000554 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100555 OverrideDataType(input, dataType),
556 OverrideDataType(output, dataType),
557 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100558 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100559 break;
560 }
James Conroyaba90cd2020-11-06 16:28:18 +0000561 case LayerType::LogicalBinary:
562 {
563 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
564
565 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
566 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
567 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
568
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000569 result = layerSupportObject.IsLogicalBinarySupported(input0,
570 input1,
571 output,
572 cLayer->GetParameters(),
573 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000574 break;
575 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100576 case LayerType::LogSoftmax:
577 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100578 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100579
580 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
581 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
582
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000583 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
584 OverrideDataType(output, dataType),
585 cLayer->GetParameters(),
586 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100587 break;
588 }
        case LayerType::Lstm:
        {
            // Float LSTM: which tensors must be supplied depends on the descriptor's
            // optional features (CIFG, projection, peephole, layer normalisation).
            auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            LstmInputParamsInfo paramsInfo;

            paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
            paramsInfo.m_InputToCellWeights = &inputToCellWeights;
            paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
            paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
            paramsInfo.m_ForgetGateBias = &forgetGateBias;
            paramsInfo.m_CellBias = &cellBias;
            paramsInfo.m_OutputGateBias = &outputGateBias;


            // Optional parameters
            // NOTE: these locals must outlive the IsLstmSupported() call below because
            // paramsInfo stores raw pointers to them; they are therefore declared here
            // at case scope rather than inside the feature-specific if-blocks.
            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;
            TensorInfo optInputLayerNormWeights;
            TensorInfo optForgetLayerNormWeights;
            TensorInfo optCellLayerNormWeights;
            TensorInfo optOutputLayerNormWeights;

            // CIFG disabled => the input gate has its own weights and bias.
            if(!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                paramsInfo.m_InputGateBias = &optInputGateBias;
            }

            // Projection weights are required when projection is enabled; the
            // projection bias remains optional (guarded by the nullptr check).
            if(descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    paramsInfo.m_ProjectionBias = &optProjectionBias;
                }
            }

            // Peephole connections; cell-to-input weights only exist when the input
            // gate itself exists, i.e. when CIFG is disabled.
            if(descriptor.m_PeepholeEnabled)
            {
                if(!descriptor.m_CifgEnabled)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
                                         dataType);
                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
                }
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
            }

            // Layer normalisation; the input-gate norm weights again only exist
            // when CIFG is disabled.
            if(descriptor.m_LayerNormEnabled)
            {
                if (!descriptor.m_CifgEnabled)
                {
                    optInputLayerNormWeights = OverrideDataType(
                            cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
                    paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
                }

                optForgetLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;

                optCellLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;

                optOutputLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
            }

            result = layerSupportObject.IsLstmSupported(
                    input,
                    outputStateIn,
                    cellStateIn,
                    scratchBuffer,
                    outputStateOut,
                    cellStateOut,
                    output,
                    descriptor,
                    paramsInfo,
                    reason);
            break;
        }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000733 case LayerType::Maximum:
734 {
735 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
736 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
737 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
738
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000739 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
740 OverrideDataType(input1, dataType),
741 OverrideDataType(output, dataType),
742 reason);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000743 break;
744 }
narpra01b89b05f2019-01-16 09:53:09 +0000745 case LayerType::MemCopy:
746 {
747 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
748 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000749
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000750 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
751 OverrideDataType(output, dataType),
752 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000753 break;
754 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100755 case LayerType::MemImport:
756 {
757 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
758 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
759
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000760 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
761 OverrideDataType(output, dataType),
762 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100763 break;
764 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100765 case LayerType::Merge:
766 {
767 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
768 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
769 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
770
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000771 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
772 OverrideDataType(input1, dataType),
773 OverrideDataType(output, dataType),
774 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100775 break;
776 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100777 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000778 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100779 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000780
telsoa01c577f2c2018-08-31 09:22:23 +0100781 // Get vector of all inputs.
782 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000783 {
telsoa01c577f2c2018-08-31 09:22:23 +0100784 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000785 };
Finn Williams3e54d032020-10-22 16:53:35 +0100786
787 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
788 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100789 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000790
telsoa01c577f2c2018-08-31 09:22:23 +0100791 auto getTensorInfoPtr = [](const TensorInfo& info)
792 {
793 return &info;
794 };
Finn Williams3e54d032020-10-22 16:53:35 +0100795
796 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
797 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100798 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000799
Nikhil Raj8599a412018-11-19 14:51:07 +0000800 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
801
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000802 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Jim Flynne242f2d2019-05-22 14:24:13 +0100803
804
telsoa014fcda012018-03-09 14:13:49 +0000805 break;
806 }
807 case LayerType::Multiplication:
808 {
809 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
810 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100811 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000812 result = layerSupportObject.IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100813 OverrideDataType(input0, dataType),
814 OverrideDataType(input1, dataType),
815 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100816 reason);
telsoa014fcda012018-03-09 14:13:49 +0000817 break;
818 }
819 case LayerType::Normalization:
820 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100821 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000822 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
823 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000824 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
825 OverrideDataType(output, dataType),
826 cLayer->GetParameters(),
827 reason);
telsoa014fcda012018-03-09 14:13:49 +0000828 break;
829 }
830 case LayerType::Output:
831 {
832 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000833 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000834 break;
835 }
836 case LayerType::Permute:
837 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100838 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000839 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
840 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000841 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
842 OverrideDataType(output, dataType),
843 cLayer->GetParameters(),
844 reason);
telsoa014fcda012018-03-09 14:13:49 +0000845 break;
846 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100847 case LayerType::Pad:
848 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100849 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100850 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
851 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000852 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100853 OverrideDataType(input, dataType),
854 OverrideDataType(output, dataType),
855 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100856 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100857 break;
858 }
telsoa014fcda012018-03-09 14:13:49 +0000859 case LayerType::Pooling2d:
860 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100861 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000862 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
863 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000864 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
865 OverrideDataType(output, dataType),
866 cLayer->GetParameters(),
867 reason);
telsoa014fcda012018-03-09 14:13:49 +0000868 break;
869 }
Tamás Nyíri7b885b32021-10-26 14:47:57 +0100870 case LayerType::Pooling3d:
871 {
872 auto cLayer = PolymorphicDowncast<const Pooling3dLayer*>(&layer);
873 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
874 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
875 result = layerSupportObject.IsPooling3dSupported(OverrideDataType(input, dataType),
876 OverrideDataType(output, dataType),
877 cLayer->GetParameters(),
878 reason);
879 break;
880 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000881 case LayerType::PreCompiled:
882 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100883 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000884 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000885 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
886 cLayer->GetParameters(),
887 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000888 break;
889 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000890 case LayerType::Quantize:
891 {
892 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
893 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000894 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000895 break;
896 }
        case LayerType::QLstm:
        {
            // QLstm: tensor infos are used as-is (no OverrideDataType in this case);
            // which parameters are wired into paramsInfo depends on the descriptor's
            // optional features (CIFG, projection, peephole, layer normalisation).
            auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
            const QLstmDescriptor& descriptor = cLayer->GetParameters();

            // Inputs
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();

            // Outputs
            const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();

            // Lstm parameters
            LstmInputParamsInfo paramsInfo;

            // Basic parameters
            // These three are mandatory, hence the asserts before dereferencing.
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
            paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToForgetWeights =
                &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights =
                &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();

            // CIFG disabled => the input gate has its own weights and bias.
            if(!descriptor.m_CifgEnabled)
            {
                paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
                paramsInfo.m_RecurrentToInputWeights =
                    &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
                paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
            }

            if(descriptor.m_ProjectionEnabled)
            {
                paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();

                // Projection bias is optional even if projection is enabled
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
                }
            }

            // Peephole connections; cell-to-input weights only exist when CIFG is disabled.
            if(descriptor.m_PeepholeEnabled)
            {
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_CellToInputWeights =
                        &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
                }

                paramsInfo.m_CellToForgetWeights =
                    &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
                paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
            }

            // Layer normalisation; the input-gate norm weights again require CIFG disabled.
            if(descriptor.m_LayerNormEnabled)
            {
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_InputLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
                }

                paramsInfo.m_ForgetLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
                paramsInfo.m_CellLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
                paramsInfo.m_OutputLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
            }

            result = layerSupportObject.IsQLstmSupported(input,
                                                         previousOutputIn,
                                                         previousCellStateIn,
                                                         outputStateOut,
                                                         cellStateOut,
                                                         output,
                                                         descriptor,
                                                         paramsInfo,
                                                         reason);
            break;
        }
James Conroyee18dc82019-07-17 11:27:46 +0100993 case LayerType::QuantizedLstm:
994 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100995 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100996
997 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100998 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
999 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1000 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001001
1002 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001003 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
1004 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001005
1006 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +01001007 QuantizedLstmInputParamsInfo paramsInfo;
1008
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001009 paramsInfo.m_InputToInputWeights =
1010 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
1011 paramsInfo.m_InputToForgetWeights =
1012 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
1013 paramsInfo.m_InputToCellWeights =
1014 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
1015 paramsInfo.m_InputToOutputWeights =
1016 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001017
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001018 paramsInfo.m_RecurrentToInputWeights =
1019 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
1020 paramsInfo.m_RecurrentToForgetWeights =
1021 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
1022 paramsInfo.m_RecurrentToCellWeights =
1023 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
1024 paramsInfo.m_RecurrentToOutputWeights =
1025 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001026
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001027 paramsInfo.m_InputGateBias =
1028 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
1029 paramsInfo.m_ForgetGateBias =
1030 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
1031 paramsInfo.m_CellBias =
1032 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
1033 paramsInfo.m_OutputGateBias =
1034 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +01001035
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001036 result = layerSupportObject.IsQuantizedLstmSupported(input,
1037 previousCellStateIn,
1038 previousOutputIn,
1039 cellStateOut,
1040 output,
1041 paramsInfo,
1042 reason);
James Conroyee18dc82019-07-17 11:27:46 +01001043 break;
1044 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001045 case LayerType::Division:
1046 {
1047 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1048 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1049 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001050 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001051 OverrideDataType(input0, dataType),
1052 OverrideDataType(input1, dataType),
1053 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001054 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001055 break;
1056 }
Finn Williams2605b232020-06-10 15:53:46 +01001057 case LayerType::Rank:
1058 {
1059 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1060 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001061 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
1062 OverrideDataType(output, dataType),
1063 reason);
Finn Williams2605b232020-06-10 15:53:46 +01001064 break;
1065 }
telsoa014fcda012018-03-09 14:13:49 +00001066 case LayerType::Reshape:
1067 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001068 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001069 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +00001070 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001071 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
1072 OverrideDataType(output, dataType),
1073 cLayer->GetParameters(),
1074 reason);
telsoa014fcda012018-03-09 14:13:49 +00001075 break;
1076 }
Teresa Charlina9075df2019-06-27 15:41:57 +01001077 case LayerType::Resize:
1078 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001079 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +01001080 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +01001081 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001082 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1083 OverrideDataType(output, dataType),
1084 cLayer->GetParameters(),
1085 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +01001086 break;
1087 }
Keith Davis3ae3f972021-05-21 16:33:48 +01001088 case LayerType::Shape:
1089 {
1090 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1091 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1092
1093 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1094 OverrideDataType(output, dataType),
1095 reason);
1096 break;
1097 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001098 case LayerType::Slice:
1099 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001100 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001101
1102 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1103 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1104
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001105 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1106 OverrideDataType(output, dataType),
1107 cLayer->GetParameters(),
1108 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001109 break;
1110 }
telsoa014fcda012018-03-09 14:13:49 +00001111 case LayerType::Softmax:
1112 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001113 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001114 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001115 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001116 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1117 OverrideDataType(output, dataType),
1118 cLayer->GetParameters(),
1119 reason);
telsoa014fcda012018-03-09 14:13:49 +00001120 break;
1121 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001122 case LayerType::SpaceToBatchNd:
1123 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001124 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001125 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1126 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001127 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1128 OverrideDataType(output, dataType),
1129 cLayer->GetParameters(),
1130 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001131 break;
1132 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001133 case LayerType::SpaceToDepth:
1134 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001135 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001136
1137 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1138 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1139
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001140 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1141 OverrideDataType(output, dataType),
1142 cLayer->GetParameters(),
1143 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001144 break;
1145 }
telsoa014fcda012018-03-09 14:13:49 +00001146 case LayerType::Splitter:
1147 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001148 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001149 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001150
1151 // Get vector of all outputs.
1152 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1153 {
1154 return OverrideDataType(slot.GetTensorInfo(), dataType);
1155 };
Finn Williams3e54d032020-10-22 16:53:35 +01001156 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1157 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001158 std::vector<TensorInfo> outputs(beginI, endI);
1159
1160 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1161
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001162 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1163 outputPtrs,
1164 cLayer->GetParameters(),
1165 reason);
telsoa014fcda012018-03-09 14:13:49 +00001166 break;
1167 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001168 case LayerType::Stack:
1169 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001170 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001171
1172 // Get vector of all inputs.
1173 auto getTensorInfo = [&dataType](const InputSlot& slot)
1174 {
1175 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1176 };
Finn Williams3e54d032020-10-22 16:53:35 +01001177 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1178 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001179 std::vector<TensorInfo> inputs(beginI, endI);
1180
1181 auto getTensorInfoPtr = [](const TensorInfo& info)
1182 {
1183 return &info;
1184 };
Finn Williams3e54d032020-10-22 16:53:35 +01001185 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1186 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001187 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1188
1189 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1190
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001191 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001192
1193 break;
1194 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001195 case LayerType::StandIn:
1196 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001197 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001198
1199 // Get vector of all inputs.
1200 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1201 {
1202 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1203 };
1204 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1205 {
1206 return OverrideDataType(slot.GetTensorInfo(), dataType);
1207 };
Finn Williams3e54d032020-10-22 16:53:35 +01001208 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1209 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001210 std::vector<TensorInfo> inputs(beginI, endI);
1211
Finn Williams3e54d032020-10-22 16:53:35 +01001212 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1213 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001214 std::vector<TensorInfo> outputs(beginO, endO);
1215
1216
1217 auto getTensorInfoPtr = [](const TensorInfo& info)
1218 {
1219 return &info;
1220 };
Finn Williams3e54d032020-10-22 16:53:35 +01001221 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1222 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001223 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1224
Finn Williams3e54d032020-10-22 16:53:35 +01001225 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1226 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001227 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1228
1229
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001230 result = layerSupportObject.IsStandInSupported(inputPtrs,
1231 outputPtrs,
1232 cLayer->GetParameters(),
1233 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001234 break;
1235 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001236 case LayerType::StridedSlice:
1237 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001238 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001239 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1240 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001241 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1242 OverrideDataType(output, dataType),
1243 cLayer->GetParameters(),
1244 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001245 break;
1246 }
David Beckc2044fe2018-09-05 15:00:38 +01001247 case LayerType::Subtraction:
1248 {
1249 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1250 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1251 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001252 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001253 OverrideDataType(input0, dataType),
1254 OverrideDataType(input1, dataType),
1255 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001256 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001257 break;
1258 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001259 case LayerType::Switch:
1260 {
1261 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1262 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1263 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1264 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001265 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1266 OverrideDataType(input1, dataType),
1267 OverrideDataType(output0, dataType),
1268 OverrideDataType(output1, dataType),
1269 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001270 break;
1271 }
narpra0132b90462018-09-13 11:07:48 +01001272 case LayerType::Mean:
1273 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001274 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001275 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1276 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001277 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001278 OverrideDataType(input, dataType),
1279 OverrideDataType(output, dataType),
1280 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001281 reason);
narpra0132b90462018-09-13 11:07:48 +01001282 break;
1283 }
kevmay0190539692018-11-29 08:40:19 +00001284 case LayerType::Minimum:
1285 {
1286 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1287 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1288 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001289 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1290 OverrideDataType(input1, dataType),
1291 OverrideDataType(output, dataType),
1292 reason);
kevmay0190539692018-11-29 08:40:19 +00001293 break;
1294 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001295 case LayerType::Prelu:
1296 {
1297 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1298 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1299 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001300 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1301 OverrideDataType(alpha, dataType),
1302 OverrideDataType(output, dataType),
1303 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001304 break;
1305 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001306 case LayerType::Transpose:
1307 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001308 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001309 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1310 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001311 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1312 OverrideDataType(output, dataType),
1313 cLayer->GetParameters(),
1314 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001315 break;
1316 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001317 case LayerType::TransposeConvolution2d:
1318 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001319 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001320
1321 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1322 dataType);
1323 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1324
1325 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1326
1327 Optional<TensorInfo> biases;
1328 if (descriptor.m_BiasEnabled)
1329 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001330 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001331 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1332 GetBiasTypeFromWeightsType(dataType));
1333 }
1334
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001335 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001336 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1337
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001338 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1339 output,
1340 descriptor,
1341 weights,
1342 biases,
1343 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001344
1345 break;
1346 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001347 case LayerType::Reduce:
1348 {
1349 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1350 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1351 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1352
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001353 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1354 OverrideDataType(output, dataType),
1355 cLayer->GetParameters(),
1356 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001357 break;
1358 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001359 case LayerType::UnidirectionalSequenceLstm:
1360 {
1361 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1362 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1363
1364 // All inputs.
1365 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1366 dataType);
1367 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1368 dataType);
1369 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1370 dataType);
1371 // Outputs
Mike Kelly12994962022-04-21 11:57:09 +01001372 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1373 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
1374 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001375
1376 // Basic parameters
1377 const TensorInfo& inputToForgetWeights
1378 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1379 const TensorInfo& inputToCellWeights
1380 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1381 const TensorInfo& inputToOutputWeights
1382 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1383 const TensorInfo& recurrentToForgetWeights
1384 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1385 const TensorInfo& recurrentToCellWeights
1386 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1387 const TensorInfo& recurrentToOutputWeights
1388 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1389 const TensorInfo& forgetGateBias
1390 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1391 const TensorInfo& cellBias
1392 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1393 const TensorInfo& outputGateBias
1394 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1395
1396 LstmInputParamsInfo paramsInfo;
1397
1398 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1399 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1400 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1401 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1402 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1403 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1404 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1405 paramsInfo.m_CellBias = &cellBias;
1406 paramsInfo.m_OutputGateBias = &outputGateBias;
1407
1408 // Optional parameters
1409 TensorInfo optInputToInputWeights;
1410 TensorInfo optRecurrentToInputWeights;
1411 TensorInfo optCellToInputWeights;
1412 TensorInfo optInputGateBias;
1413 TensorInfo optProjectionWeights;
1414 TensorInfo optProjectionBias;
1415 TensorInfo optCellToForgetWeights;
1416 TensorInfo optCellToOutputWeights;
1417 TensorInfo optInputLayerNormWeights;
1418 TensorInfo optForgetLayerNormWeights;
1419 TensorInfo optCellLayerNormWeights;
1420 TensorInfo optOutputLayerNormWeights;
1421
1422 if(!descriptor.m_CifgEnabled)
1423 {
1424 optInputToInputWeights =
1425 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1426 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1427
1428 optRecurrentToInputWeights =
1429 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1430 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1431 optInputGateBias =
1432 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1433 paramsInfo.m_InputGateBias = &optInputGateBias;
1434 }
1435
1436 if(descriptor.m_ProjectionEnabled)
1437 {
1438 optProjectionWeights =
1439 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1440 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1441 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1442 {
1443 optProjectionBias =
1444 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1445 paramsInfo.m_ProjectionBias = &optProjectionBias;
1446 }
1447 }
1448
1449 if(descriptor.m_PeepholeEnabled)
1450 {
1451 if(!descriptor.m_CifgEnabled)
1452 {
1453 optCellToInputWeights =
1454 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1455 dataType);
1456 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1457 }
1458 optCellToForgetWeights =
1459 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1460 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1461 optCellToOutputWeights =
1462 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1463 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1464 }
1465
1466 if(descriptor.m_LayerNormEnabled)
1467 {
1468 if (!descriptor.m_CifgEnabled)
1469 {
1470 optInputLayerNormWeights = OverrideDataType(
1471 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1472 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1473 }
1474
1475 optForgetLayerNormWeights = OverrideDataType(
1476 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1477 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1478
1479 optCellLayerNormWeights = OverrideDataType(
1480 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1481 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1482
1483 optOutputLayerNormWeights = OverrideDataType(
1484 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1485 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1486 }
1487
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001488 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1489 outputStateIn,
1490 cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01001491 outputStateOut,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001492 cellStateOut,
Mike Kelly12994962022-04-21 11:57:09 +01001493 output,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001494 descriptor,
1495 paramsInfo,
1496 reason);
1497 break;
1498 }
telsoa014fcda012018-03-09 14:13:49 +00001499 default:
1500 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001501 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001502 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001503 result = false;
1504 break;
1505 }
1506 }
telsoa014fcda012018-03-09 14:13:49 +00001507 return result;
1508}
1509
Sadik Armagan045f6be2020-09-10 13:37:32 +01001510bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1511 const IConnectableLayer& connectableLayer,
1512 Optional<DataType> dataType,
1513 std::string& outReasonIfUnsupported)
1514{
1515 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1516}
1517
David Beckdcb751f2018-10-03 11:42:42 +01001518bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001519 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001520 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001521{
Jan Eilersbb446e52020-04-02 13:56:54 +01001522 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001523 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1524}
1525
1526// TODO merge with defaulted modelOptions above
1527bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1528 Optional<DataType> dataType,
1529 std::string& outReasonIfUnsupported,
1530 const ModelOptions& modelOptions)
1531{
1532 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1533 return IsLayerConfigurationSupported(layer->GetBackendId(),
1534 connectableLayer,
1535 dataType,
1536 outReasonIfUnsupported,
1537 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001538}
1539
Sadik Armagan04a72972020-09-14 15:44:18 +01001540bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1541 const IConnectableLayer& connectableLayer,
1542 Optional<DataType> dataType,
1543 std::string& outReasonIfUnsupported,
1544 const ModelOptions& modelOptions)
1545{
1546 return IsLayerConfigurationSupported(backendId,
1547 connectableLayer,
1548 dataType,
1549 outReasonIfUnsupported,
1550 modelOptions);
1551}
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001552ARMNN_NO_DEPRECATE_WARN_BEGIN
1553std::unique_ptr<IWorkload> IWorkloadFactory::CreateWorkload(LayerType type,
1554 const QueueDescriptor& descriptor,
1555 const WorkloadInfo& info) const
1556{
1557 switch(type)
1558 {
1559 case LayerType::Activation :
1560 {
1561 auto activationQueueDescriptor = PolymorphicDowncast<const ActivationQueueDescriptor*>(&descriptor);
1562 return CreateActivation(*activationQueueDescriptor, info);
1563 }
1564 case LayerType::Addition :
1565 {
1566 auto additionQueueDescriptor = PolymorphicDowncast<const AdditionQueueDescriptor*>(&descriptor);
1567 return CreateAddition(*additionQueueDescriptor, info);
1568 }
1569 case LayerType::ArgMinMax :
1570 {
1571 auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor);
1572 return CreateArgMinMax(*argMinMaxQueueDescriptor, info);
1573 }
1574 case LayerType::BatchNormalization :
1575 {
1576 auto batchNormQueueDescriptor = PolymorphicDowncast<const BatchNormalizationQueueDescriptor*>(&descriptor);
1577 return CreateBatchNormalization(*batchNormQueueDescriptor, info);
1578 }
1579 case LayerType::BatchToSpaceNd :
1580 {
1581 auto batchToSpaceNdQueueDescriptor
1582 = PolymorphicDowncast<const BatchToSpaceNdQueueDescriptor*>(&descriptor);
1583 return CreateBatchToSpaceNd(*batchToSpaceNdQueueDescriptor, info);
1584 }
1585 case LayerType::Cast :
1586 {
1587 auto castQueueDescriptor = PolymorphicDowncast<const CastQueueDescriptor*>(&descriptor);
1588 return CreateCast(*castQueueDescriptor, info);
1589 }
1590 case LayerType::ChannelShuffle :
1591 {
1592 auto channelShuffleQueueDescriptor
1593 = PolymorphicDowncast<const ChannelShuffleQueueDescriptor*>(&descriptor);
1594 return CreateChannelShuffle(*channelShuffleQueueDescriptor, info);
1595 }
1596 case LayerType::Comparison :
1597 {
1598 auto comparisonQueueDescriptor = PolymorphicDowncast<const ComparisonQueueDescriptor*>(&descriptor);
1599 return CreateComparison(*comparisonQueueDescriptor, info);
1600 }
1601 case LayerType::Concat :
1602 {
1603 auto concatQueueDescriptor = PolymorphicDowncast<const ConcatQueueDescriptor*>(&descriptor);
1604 return CreateConcat(*concatQueueDescriptor, info);
1605 }
1606 case LayerType::Constant :
1607 {
1608 auto constantQueueDescriptor = PolymorphicDowncast<const ConstantQueueDescriptor*>(&descriptor);
1609 return CreateConstant(*constantQueueDescriptor, info);
1610 }
1611 case LayerType::ConvertBf16ToFp32 :
1612 {
1613 auto convertBf16ToFp32QueueDescriptor
1614 = PolymorphicDowncast<const ConvertBf16ToFp32QueueDescriptor*>(&descriptor);
1615 return CreateConvertBf16ToFp32(*convertBf16ToFp32QueueDescriptor, info);
1616 }
1617 case LayerType::ConvertFp16ToFp32:
1618 {
1619 auto convertFp16ToFp32QueueDescriptor
1620 = PolymorphicDowncast<const ConvertFp16ToFp32QueueDescriptor*>(&descriptor);
1621 return CreateConvertFp16ToFp32(*convertFp16ToFp32QueueDescriptor, info);
1622 }
1623 case LayerType::ConvertFp32ToBf16:
1624 {
1625 auto convertFp32ToBf16QueueDescriptor
1626 = PolymorphicDowncast<const ConvertFp32ToBf16QueueDescriptor*>(&descriptor);
1627 return CreateConvertFp32ToBf16(*convertFp32ToBf16QueueDescriptor, info);
1628 }
1629 case LayerType::ConvertFp32ToFp16:
1630 {
1631 auto convertFp32ToFp16QueueDescriptor
1632 = PolymorphicDowncast<const ConvertFp32ToFp16QueueDescriptor*>(&descriptor);
1633 return CreateConvertFp32ToFp16(*convertFp32ToFp16QueueDescriptor, info);
1634 }
1635 case LayerType::Convolution2d:
1636 {
1637 auto convolution2dQueueDescriptor = PolymorphicDowncast<const Convolution2dQueueDescriptor*>(&descriptor);
1638 return CreateConvolution2d(*convolution2dQueueDescriptor, info);
1639 }
1640 case LayerType::Convolution3d:
1641 {
1642 auto convolution3dQueueDescriptor = PolymorphicDowncast<const Convolution3dQueueDescriptor*>(&descriptor);
1643 return CreateConvolution3d(*convolution3dQueueDescriptor, info);
1644 }
1645 case LayerType::Debug:
1646 {
1647 auto debugQueueDescriptor = PolymorphicDowncast<const DebugQueueDescriptor*>(&descriptor);
1648 return CreateDebug(*debugQueueDescriptor, info);
1649 }
1650 case LayerType::DepthToSpace:
1651 {
1652 auto depthToSpaceQueueDescriptor = PolymorphicDowncast<const DepthToSpaceQueueDescriptor*>(&descriptor);
1653 return CreateDepthToSpace(*depthToSpaceQueueDescriptor, info);
1654 }
1655 case LayerType::DepthwiseConvolution2d:
1656 {
1657 auto depthwiseConvolution2DQueueDescriptor
1658 = PolymorphicDowncast<const DepthwiseConvolution2dQueueDescriptor*>(&descriptor);
1659 return CreateDepthwiseConvolution2d(*depthwiseConvolution2DQueueDescriptor, info);
1660 }
1661 case LayerType::Dequantize:
1662 {
1663 auto dequantizeQueueDescriptor = PolymorphicDowncast<const DequantizeQueueDescriptor*>(&descriptor);
1664 return CreateDequantize(*dequantizeQueueDescriptor, info);
1665 }
1666 case LayerType::DetectionPostProcess:
1667 {
1668 auto detectionPostProcessQueueDescriptor
1669 = PolymorphicDowncast<const DetectionPostProcessQueueDescriptor*>(&descriptor);
1670 return CreateDetectionPostProcess(*detectionPostProcessQueueDescriptor, info);
1671 }
1672 case LayerType::Division:
1673 {
1674 auto divisionQueueDescriptor = PolymorphicDowncast<const DivisionQueueDescriptor*>(&descriptor);
1675 return CreateDivision(*divisionQueueDescriptor, info);
1676 }
1677 case LayerType::ElementwiseUnary:
1678 {
1679 auto elementwiseUnaryQueueDescriptor
1680 = PolymorphicDowncast<const ElementwiseUnaryQueueDescriptor*>(&descriptor);
1681 return CreateElementwiseUnary(*elementwiseUnaryQueueDescriptor, info);
1682
1683 }
1684 case LayerType::FakeQuantization:
1685 {
1686 auto fakeQuantizationQueueDescriptor
1687 = PolymorphicDowncast<const FakeQuantizationQueueDescriptor*>(&descriptor);
1688 return CreateFakeQuantization(*fakeQuantizationQueueDescriptor, info);
1689 }
1690 case LayerType::Fill:
1691 {
1692 auto fillQueueDescriptor = PolymorphicDowncast<const FillQueueDescriptor*>(&descriptor);
1693 return CreateFill(*fillQueueDescriptor, info);
1694 }
1695 case LayerType::Floor:
1696 {
1697 auto floorQueueDescriptor = PolymorphicDowncast<const FloorQueueDescriptor*>(&descriptor);
1698 return CreateFloor(*floorQueueDescriptor, info);
1699 }
1700 case LayerType::FullyConnected:
1701 {
1702 auto fullyConnectedQueueDescriptor
1703 = PolymorphicDowncast<const FullyConnectedQueueDescriptor*>(&descriptor);
1704 return CreateFullyConnected(*fullyConnectedQueueDescriptor, info);
1705 }
1706 case LayerType::Gather:
1707 {
1708 auto gatherQueueDescriptor = PolymorphicDowncast<const GatherQueueDescriptor*>(&descriptor);
1709 return CreateGather(*gatherQueueDescriptor, info);
1710 }
1711 case LayerType::Input:
1712 {
1713 auto inputQueueDescriptor = PolymorphicDowncast<const InputQueueDescriptor*>(&descriptor);
1714 return CreateInput(*inputQueueDescriptor, info);
1715 }
1716 case LayerType::InstanceNormalization:
1717 {
1718 auto instanceNormalizationQueueDescriptor
1719 = PolymorphicDowncast<const InstanceNormalizationQueueDescriptor*>(&descriptor);
1720 return CreateInstanceNormalization(*instanceNormalizationQueueDescriptor, info);
1721 }
1722 case LayerType::L2Normalization:
1723 {
1724 auto l2NormalizationQueueDescriptor
1725 = PolymorphicDowncast<const L2NormalizationQueueDescriptor*>(&descriptor);
1726 return CreateL2Normalization(*l2NormalizationQueueDescriptor, info);
1727 }
1728 case LayerType::LogicalBinary:
1729 {
1730 auto logicalBinaryQueueDescriptor = PolymorphicDowncast<const LogicalBinaryQueueDescriptor*>(&descriptor);
1731 return CreateLogicalBinary(*logicalBinaryQueueDescriptor, info);
1732 }
1733 case LayerType::LogSoftmax:
1734 {
1735 auto logSoftmaxQueueDescriptor = PolymorphicDowncast<const LogSoftmaxQueueDescriptor*>(&descriptor);
1736 return CreateLogSoftmax(*logSoftmaxQueueDescriptor, info);
1737 }
1738 case LayerType::Lstm:
1739 {
1740 auto lstmQueueDescriptor = PolymorphicDowncast<const LstmQueueDescriptor*>(&descriptor);
1741 return CreateLstm(*lstmQueueDescriptor, info);
1742 }
1743 case LayerType::Maximum:
1744 {
1745 auto maximumQueueDescriptor = PolymorphicDowncast<const MaximumQueueDescriptor*>(&descriptor);
1746 return CreateMaximum(*maximumQueueDescriptor, info);
1747 }
1748 case LayerType::Mean:
1749 {
1750 auto meanQueueDescriptor = PolymorphicDowncast<const MeanQueueDescriptor*>(&descriptor);
1751 return CreateMean(*meanQueueDescriptor, info);
1752 }
1753 case LayerType::MemCopy:
1754 {
1755 auto memCopyQueueDescriptor = PolymorphicDowncast<const MemCopyQueueDescriptor*>(&descriptor);
1756 return CreateMemCopy(*memCopyQueueDescriptor, info);
1757 }
1758 case LayerType::MemImport:
1759 {
1760 auto memImportQueueDescriptor = PolymorphicDowncast<const MemImportQueueDescriptor*>(&descriptor);
1761 return CreateMemImport(*memImportQueueDescriptor, info);
1762 }
1763 case LayerType::Minimum:
1764 {
1765 auto minimumQueueDescriptor = PolymorphicDowncast<const MinimumQueueDescriptor*>(&descriptor);
1766 return CreateMinimum(*minimumQueueDescriptor, info);
1767 }
1768 case LayerType::Multiplication:
1769 {
1770 auto multiplicationQueueDescriptor
1771 = PolymorphicDowncast<const MultiplicationQueueDescriptor*>(&descriptor);
1772 return CreateMultiplication(*multiplicationQueueDescriptor, info);
1773 }
1774 case LayerType::Normalization:
1775 {
1776 auto normalizationQueueDescriptor = PolymorphicDowncast<const NormalizationQueueDescriptor*>(&descriptor);
1777 return CreateNormalization(*normalizationQueueDescriptor, info);
1778 }
1779 case LayerType::Output:
1780 {
1781 auto outputQueueDescriptor = PolymorphicDowncast<const OutputQueueDescriptor*>(&descriptor);
1782 return CreateOutput(*outputQueueDescriptor, info);
1783 }
1784 case LayerType::Pad:
1785 {
1786 auto padQueueDescriptor = PolymorphicDowncast<const PadQueueDescriptor*>(&descriptor);
1787 return CreatePad(*padQueueDescriptor, info);
1788 }
1789 case LayerType::Permute:
1790 {
1791 auto permuteQueueDescriptor = PolymorphicDowncast<const PermuteQueueDescriptor*>(&descriptor);
1792 return CreatePermute(*permuteQueueDescriptor, info);
1793 }
1794 case LayerType::Pooling2d:
1795 {
1796 auto pooling2dQueueDescriptor = PolymorphicDowncast<const Pooling2dQueueDescriptor*>(&descriptor);
1797 return CreatePooling2d(*pooling2dQueueDescriptor, info);
1798 }
1799 case LayerType::Pooling3d:
1800 {
1801 auto pooling3dQueueDescriptor = PolymorphicDowncast<const Pooling3dQueueDescriptor*>(&descriptor);
1802 return CreatePooling3d(*pooling3dQueueDescriptor, info);
1803 }
1804 case LayerType::PreCompiled:
1805 {
1806 auto preCompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
1807 return CreatePreCompiled(*preCompiledQueueDescriptor, info);
1808 }
1809 case LayerType::Prelu:
1810 {
1811 auto preluQueueDescriptor = PolymorphicDowncast<const PreluQueueDescriptor*>(&descriptor);
1812 return CreatePrelu(*preluQueueDescriptor, info);
1813 }
1814 case LayerType::QLstm:
1815 {
1816 auto qlstmQueueDescriptor = PolymorphicDowncast<const QLstmQueueDescriptor*>(&descriptor);
1817 return CreateQLstm(*qlstmQueueDescriptor, info);
1818 }
1819 case LayerType::Quantize:
1820 {
1821 auto quantizeQueueDescriptor = PolymorphicDowncast<const QuantizeQueueDescriptor*>(&descriptor);
1822 return CreateQuantize(*quantizeQueueDescriptor, info);
1823 }
1824 case LayerType::Rank:
1825 {
1826 auto rankQueueDescriptor = PolymorphicDowncast<const RankQueueDescriptor*>(&descriptor);
1827 return CreateRank(*rankQueueDescriptor, info);
1828 }
1829 case LayerType::Reduce:
1830 {
1831 auto reduceQueueDescriptor = PolymorphicDowncast<const ReduceQueueDescriptor*>(&descriptor);
1832 return CreateReduce(*reduceQueueDescriptor, info);
1833 }
1834 case LayerType::Reshape:
1835 {
1836 auto reshapeQueueDescriptor = PolymorphicDowncast<const ReshapeQueueDescriptor*>(&descriptor);
1837 return CreateReshape(*reshapeQueueDescriptor, info);
1838 }
1839 case LayerType::Resize:
1840 {
1841 auto resizeQueueDescriptor = PolymorphicDowncast<const ResizeQueueDescriptor*>(&descriptor);
1842 return CreateResize(*resizeQueueDescriptor, info);
1843 }
1844 case LayerType::Shape:
1845 {
1846 auto shapeQueueDescriptor = PolymorphicDowncast<const ShapeQueueDescriptor*>(&descriptor);
1847 return CreateShape(*shapeQueueDescriptor, info);
1848 }
1849 case LayerType::Slice:
1850 {
1851 auto sliceQueueDescriptor = PolymorphicDowncast<const SliceQueueDescriptor*>(&descriptor);
1852 return CreateSlice(*sliceQueueDescriptor, info);
1853 }
1854 case LayerType::Softmax:
1855 {
1856 auto softmaxQueueDescriptor = PolymorphicDowncast<const SoftmaxQueueDescriptor*>(&descriptor);
1857 return CreateSoftmax(*softmaxQueueDescriptor, info);
1858 }
1859 case LayerType::SpaceToBatchNd:
1860 {
1861 auto spaceToBatchNdQueueDescriptor
1862 = PolymorphicDowncast<const SpaceToBatchNdQueueDescriptor*>(&descriptor);
1863 return CreateSpaceToBatchNd(*spaceToBatchNdQueueDescriptor, info);
1864 }
1865 case LayerType::SpaceToDepth:
1866 {
1867 auto spaceToDepthQueueDescriptor = PolymorphicDowncast<const SpaceToDepthQueueDescriptor*>(&descriptor);
1868 return CreateSpaceToDepth(*spaceToDepthQueueDescriptor, info);
1869 }
1870 case LayerType::Splitter:
1871 {
1872 auto splitterQueueDescriptor = PolymorphicDowncast<const SplitterQueueDescriptor*>(&descriptor);
1873 return CreateSplitter(*splitterQueueDescriptor, info);
1874 }
1875 case LayerType::Stack:
1876 {
1877 auto stackQueueDescriptor = PolymorphicDowncast<const StackQueueDescriptor*>(&descriptor);
1878 return CreateStack(*stackQueueDescriptor, info);
1879 }
1880 case LayerType::StridedSlice:
1881 {
1882 auto stridedSliceQueueDescriptor = PolymorphicDowncast<const StridedSliceQueueDescriptor*>(&descriptor);
1883 return CreateStridedSlice(*stridedSliceQueueDescriptor, info);
1884 }
1885 case LayerType::Subtraction:
1886 {
1887 auto subtractionQueueDescriptor = PolymorphicDowncast<const SubtractionQueueDescriptor*>(&descriptor);
1888 return CreateSubtraction(*subtractionQueueDescriptor, info);
1889 }
1890 case LayerType::Transpose:
1891 {
1892 auto transposeQueueDescriptor = PolymorphicDowncast<const TransposeQueueDescriptor*>(&descriptor);
1893 return CreateTranspose(*transposeQueueDescriptor, info);
1894 }
1895 case LayerType::TransposeConvolution2d:
1896 {
1897 auto transposeConvolution2dQueueDescriptor
1898 = PolymorphicDowncast<const TransposeConvolution2dQueueDescriptor*>(&descriptor);
1899 return CreateTransposeConvolution2d(*transposeConvolution2dQueueDescriptor, info);
1900 }
1901 case LayerType::UnidirectionalSequenceLstm:
1902 {
1903 auto unidirectionalSequenceLstmQueueDescriptor
1904 = PolymorphicDowncast<const UnidirectionalSequenceLstmQueueDescriptor*>(&descriptor);
1905 return CreateUnidirectionalSequenceLstm(*unidirectionalSequenceLstmQueueDescriptor, info);
1906 }
1907 default:
1908 return nullptr;
1909 }
1910}
1911ARMNN_NO_DEPRECATE_WARN_END
Sadik Armagan04a72972020-09-14 15:44:18 +01001912
Derek Lamberti901ea112019-12-10 22:07:09 +00001913std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1914 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001915{
1916 return std::unique_ptr<IWorkload>();
1917}
1918
Derek Lamberti901ea112019-12-10 22:07:09 +00001919std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1920 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001921{
1922 return std::unique_ptr<IWorkload>();
1923}
1924
Derek Lamberti901ea112019-12-10 22:07:09 +00001925std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1926 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001927{
1928 return std::unique_ptr<IWorkload>();
1929}
1930
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001931std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001932 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001933{
1934 return std::unique_ptr<IWorkload>();
1935}
1936
Derek Lamberti901ea112019-12-10 22:07:09 +00001937std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1938 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001939{
1940 return std::unique_ptr<IWorkload>();
1941}
1942
mathad01b392e982021-04-07 12:07:30 +01001943std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1944 const WorkloadInfo& /*info*/) const
1945{
1946 return std::unique_ptr<IWorkload>();
1947}
1948
Simon Obute51f67772021-09-03 15:50:13 +01001949std::unique_ptr<IWorkload> IWorkloadFactory::CreateChannelShuffle(const ChannelShuffleQueueDescriptor& /*descriptor*/,
1950 const WorkloadInfo& /*info*/) const
1951{
1952 return std::unique_ptr<IWorkload>();
1953}
1954
Derek Lamberti901ea112019-12-10 22:07:09 +00001955std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1956 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001957{
1958 return std::unique_ptr<IWorkload>();
1959}
1960
Derek Lamberti901ea112019-12-10 22:07:09 +00001961std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1962 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001963{
1964 return std::unique_ptr<IWorkload>();
1965}
1966
Derek Lamberti901ea112019-12-10 22:07:09 +00001967std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1968 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001969{
1970 return std::unique_ptr<IWorkload>();
1971}
1972
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001973std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1974 const WorkloadInfo& /*info*/) const
1975{
1976 return std::unique_ptr<IWorkload>();
1977}
1978
Derek Lamberti901ea112019-12-10 22:07:09 +00001979std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1980 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001981{
1982 return std::unique_ptr<IWorkload>();
1983}
1984
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001985std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1986 const WorkloadInfo& /*info*/) const
1987{
1988 return std::unique_ptr<IWorkload>();
1989}
1990
Derek Lamberti901ea112019-12-10 22:07:09 +00001991std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1992 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001993{
1994 return std::unique_ptr<IWorkload>();
1995}
1996
Derek Lamberti901ea112019-12-10 22:07:09 +00001997std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1998 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001999{
2000 return std::unique_ptr<IWorkload>();
2001}
2002
Matthew Sloyanb63a3112021-09-08 13:05:51 +01002003std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution3d(const Convolution3dQueueDescriptor& /*descriptor*/,
2004 const WorkloadInfo& /*info*/) const
2005{
2006 return std::unique_ptr<IWorkload>();
2007}
2008
Derek Lamberti901ea112019-12-10 22:07:09 +00002009std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
2010 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002011{
2012 return std::unique_ptr<IWorkload>();
2013}
2014
Derek Lamberti901ea112019-12-10 22:07:09 +00002015std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
2016 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002017{
2018 return std::unique_ptr<IWorkload>();
2019}
2020
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002021std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00002022 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002023{
2024 return std::unique_ptr<IWorkload>();
2025}
2026
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002027std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00002028 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002029{
2030 return std::unique_ptr<IWorkload>();
2031}
2032
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002033std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00002034 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002035{
2036 return std::unique_ptr<IWorkload>();
2037}
2038
Derek Lamberti901ea112019-12-10 22:07:09 +00002039std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
2040 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002041{
2042 return std::unique_ptr<IWorkload>();
2043}
2044
josh minor4a3c6102020-01-06 16:40:46 -06002045std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
2046 const WorkloadInfo& /*info*/) const
2047{
2048 return std::unique_ptr<IWorkload>();
2049}
2050
Derek Lamberti901ea112019-12-10 22:07:09 +00002051std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
2052 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002053{
2054 return std::unique_ptr<IWorkload>();
2055}
2056
Ryan OSheaec6c6802020-06-05 17:17:06 +01002057std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
2058 const WorkloadInfo& /*info*/) const
2059{
2060 return std::unique_ptr<IWorkload>();
2061}
2062
Derek Lamberti901ea112019-12-10 22:07:09 +00002063std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
2064 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002065{
2066 return std::unique_ptr<IWorkload>();
2067}
2068
Derek Lamberti901ea112019-12-10 22:07:09 +00002069std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
2070 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002071{
2072 return std::unique_ptr<IWorkload>();
2073}
2074
Derek Lamberti901ea112019-12-10 22:07:09 +00002075std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
2076 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002077{
2078 return std::unique_ptr<IWorkload>();
2079}
2080
Kevin Mayce5045a2019-10-02 14:07:47 +01002081std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00002082 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
2083 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01002084{
2085 return std::unique_ptr<IWorkload>();
2086}
2087
Derek Lamberti901ea112019-12-10 22:07:09 +00002088std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
2089 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002090{
2091 return std::unique_ptr<IWorkload>();
2092}
2093
James Conroyaba90cd2020-11-06 16:28:18 +00002094std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
2095 const WorkloadInfo& /*info*/) const
2096{
2097 return std::unique_ptr<IWorkload>();
2098}
2099
2100std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
2101 const WorkloadInfo& /*info*/) const
2102{
2103 return std::unique_ptr<IWorkload>();
2104}
2105
Derek Lamberti901ea112019-12-10 22:07:09 +00002106std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
2107 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01002108{
2109 return std::unique_ptr<IWorkload>();
2110}
2111
Derek Lamberti901ea112019-12-10 22:07:09 +00002112std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
2113 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002114{
2115 return std::unique_ptr<IWorkload>();
2116}
2117
Derek Lamberti901ea112019-12-10 22:07:09 +00002118std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
2119 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002120{
2121 return std::unique_ptr<IWorkload>();
2122}
2123
Derek Lamberti901ea112019-12-10 22:07:09 +00002124std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
2125 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002126{
2127 return std::unique_ptr<IWorkload>();
2128}
2129
Derek Lamberti901ea112019-12-10 22:07:09 +00002130std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
2131 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002132{
2133 return std::unique_ptr<IWorkload>();
2134}
2135
Derek Lamberti901ea112019-12-10 22:07:09 +00002136std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
2137 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01002138{
2139 return std::unique_ptr<IWorkload>();
2140}
2141
Derek Lamberti901ea112019-12-10 22:07:09 +00002142std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
2143 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002144{
2145 return std::unique_ptr<IWorkload>();
2146}
2147
Derek Lamberti901ea112019-12-10 22:07:09 +00002148std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
2149 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002150{
2151 return std::unique_ptr<IWorkload>();
2152}
2153
Derek Lamberti901ea112019-12-10 22:07:09 +00002154std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
2155 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002156{
2157 return std::unique_ptr<IWorkload>();
2158}
2159
Derek Lamberti901ea112019-12-10 22:07:09 +00002160std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
2161 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002162{
2163 return std::unique_ptr<IWorkload>();
2164}
2165
Derek Lamberti901ea112019-12-10 22:07:09 +00002166std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
2167 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002168{
2169 return std::unique_ptr<IWorkload>();
2170}
2171
Derek Lamberti901ea112019-12-10 22:07:09 +00002172std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
2173 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002174{
2175 return std::unique_ptr<IWorkload>();
2176}
2177
Derek Lamberti901ea112019-12-10 22:07:09 +00002178std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002179 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002180{
2181 return std::unique_ptr<IWorkload>();
2182}
2183
Derek Lamberti901ea112019-12-10 22:07:09 +00002184std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
2185 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002186{
2187 return std::unique_ptr<IWorkload>();
2188}
2189
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002190std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling3d(const Pooling3dQueueDescriptor& /*descriptor*/,
2191 const WorkloadInfo& /*info*/) const
2192{
2193 return std::unique_ptr<IWorkload>();
2194}
2195
Derek Lamberti901ea112019-12-10 22:07:09 +00002196std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
2197 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002198{
2199 return std::unique_ptr<IWorkload>();
2200}
2201
Derek Lamberti901ea112019-12-10 22:07:09 +00002202std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
2203 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002204{
2205 return std::unique_ptr<IWorkload>();
2206}
2207
Derek Lamberti901ea112019-12-10 22:07:09 +00002208std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
2209 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002210{
2211 return std::unique_ptr<IWorkload>();
2212}
2213
James Conroy586a9aa2020-03-20 08:49:33 +00002214std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
2215 const WorkloadInfo& /*info*/) const
2216{
2217 return std::unique_ptr<IWorkload>();
2218}
2219
Derek Lamberti901ea112019-12-10 22:07:09 +00002220std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
2221 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01002222{
2223 return std::unique_ptr<IWorkload>();
2224}
Finn Williams2605b232020-06-10 15:53:46 +01002225std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
2226 const WorkloadInfo& /*info*/) const
2227{
2228 return std::unique_ptr<IWorkload>();
2229}
James Conroyee18dc82019-07-17 11:27:46 +01002230
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002231std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
2232 const WorkloadInfo& /*info*/) const
2233{
2234 return std::unique_ptr<IWorkload>();
2235}
2236
Derek Lamberti901ea112019-12-10 22:07:09 +00002237std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
2238 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002239{
2240 return std::unique_ptr<IWorkload>();
2241}
2242
Derek Lamberti901ea112019-12-10 22:07:09 +00002243std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
2244 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01002245{
2246 return std::unique_ptr<IWorkload>();
2247}
2248
Keith Davis3ae3f972021-05-21 16:33:48 +01002249std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
2250 const WorkloadInfo& /*info*/) const
2251{
2252 return std::unique_ptr<IWorkload>();
2253}
2254
Derek Lamberti901ea112019-12-10 22:07:09 +00002255std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
2256 const WorkloadInfo& /*info*/) const
2257{
2258 return std::unique_ptr<IWorkload>();
2259}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002260
Derek Lamberti901ea112019-12-10 22:07:09 +00002261std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
2262 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002263{
2264 return std::unique_ptr<IWorkload>();
2265}
2266
Derek Lamberti901ea112019-12-10 22:07:09 +00002267std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
2268 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002269{
2270 return std::unique_ptr<IWorkload>();
2271}
2272
Derek Lamberti901ea112019-12-10 22:07:09 +00002273std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
2274 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002275{
2276 return std::unique_ptr<IWorkload>();
2277}
2278
Derek Lamberti901ea112019-12-10 22:07:09 +00002279std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
2280 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002281{
2282 return std::unique_ptr<IWorkload>();
2283}
2284
Derek Lamberti901ea112019-12-10 22:07:09 +00002285std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
2286 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01002287{
2288 return std::unique_ptr<IWorkload>();
2289}
2290
Derek Lamberti901ea112019-12-10 22:07:09 +00002291std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
2292 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01002293{
2294 return std::unique_ptr<IWorkload>();
2295}
2296
Derek Lamberti901ea112019-12-10 22:07:09 +00002297std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
2298 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002299{
2300 return std::unique_ptr<IWorkload>();
2301}
2302
Derek Lamberti901ea112019-12-10 22:07:09 +00002303std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
2304 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01002305{
2306 return std::unique_ptr<IWorkload>();
2307}
2308
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002309std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
2310 const WorkloadInfo& /*info*/) const
2311{
2312 return std::unique_ptr<IWorkload>();
2313}
2314
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002315std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00002316 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
2317 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002318{
2319 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01002320}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002321
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01002322std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
2323 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
2324 const WorkloadInfo& /*info*/) const
2325{
2326 return std::unique_ptr<IWorkload>();
2327}
2328
} // namespace armnn