blob: 665ab3f86cfeb9ddeaed398911329c00c6c3bca8 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000010#include <armnn/backends/IBackendInternal.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Colm Donelan0c479742021-12-10 12:43:54 +000017#include <armnn/backends/WorkloadFactory.hpp>
18#include <armnn/backends/TensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
David Beck111b5d92018-11-12 14:59:37 +000020#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000021
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
telsoa01c577f2c2018-08-31 09:22:23 +010025namespace
26{
Finn Williams3e54d032020-10-22 16:53:35 +010027using LayerList = std::list<Layer*>;
28using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010029
David Beck29c75de2018-10-23 13:35:58 +010030const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
31{
32 if (!type)
33 {
34 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010035 }
36
Matthew Sloyan81beae32021-07-13 19:46:11 +010037 return TensorInfo(info.GetShape(),
38 type.value(),
39 info.GetQuantizationScale(),
40 info.GetQuantizationOffset(),
41 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
Sadik Armagana097d2a2021-11-24 15:47:28 +000046inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
47{
48 if (!weightsType)
49 {
50 return weightsType;
51 }
52
53 switch(weightsType.value())
54 {
55 case armnn::DataType::BFloat16:
56 case armnn::DataType::Float16:
57 case armnn::DataType::Float32:
58 return weightsType;
59 case armnn::DataType::QAsymmS8:
60 case armnn::DataType::QAsymmU8:
61 case armnn::DataType::QSymmS8:
62 case armnn::DataType::QSymmS16:
63 return armnn::DataType::Signed32;
64 default:
65 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
66 }
67 return armnn::EmptyOptional();
68}
69
70
Sadik Armagan045f6be2020-09-10 13:37:32 +010071bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
72 const IConnectableLayer& connectableLayer,
73 Optional<DataType> dataType,
74 std::string& outReasonIfUnsupported,
75 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000076{
David Beck33f0ae02018-10-18 15:13:56 +010077 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000078 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010079 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010080
David Beck111b5d92018-11-12 14:59:37 +000081 auto const& backendRegistry = BackendRegistryInstance();
82 if (!backendRegistry.IsBackendRegistered(backendId))
83 {
84 std::stringstream ss;
85 ss << connectableLayer.GetName() << " is not supported on " << backendId
86 << " because this backend is not registered.";
87
88 outReasonIfUnsupported = ss.str();
89 return false;
90 }
91
92 auto backendFactory = backendRegistry.GetFactory(backendId);
93 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000094 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010095
telsoa014fcda012018-03-09 14:13:49 +000096 switch(layer.GetType())
97 {
98 case LayerType::Activation:
99 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100100 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000101 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100102 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000103 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100104 OverrideDataType(input, dataType),
105 OverrideDataType(output, dataType),
106 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100107 reason);
telsoa014fcda012018-03-09 14:13:49 +0000108 break;
109 }
110 case LayerType::Addition:
111 {
112 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
113 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
114 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000115 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100116 OverrideDataType(input0, dataType),
117 OverrideDataType(input1, dataType),
118 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100119 reason);
telsoa014fcda012018-03-09 14:13:49 +0000120 break;
121 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100122 case LayerType::ArgMinMax:
123 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100124 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100125 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
126
127 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
128 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000129 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100130 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000131 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100132 descriptor,
133 reason);
134 break;
135 }
Samuel Yap4b7a34d2022-07-06 15:36:03 +0100136 case LayerType::BatchMatMul:
137 {
138 auto cLayer = PolymorphicDowncast<const BatchMatMulLayer*>(&layer);
139 const BatchMatMulDescriptor& descriptor = cLayer->GetParameters();
140
141 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
142 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
143 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
144 result = layerSupportObject.IsBatchMatMulSupported(
145 OverrideDataType(input0, dataType),
146 OverrideDataType(input1, dataType),
147 OverrideDataType(output, dataType),
148 descriptor,
149 reason);
150 break;
151 }
telsoa014fcda012018-03-09 14:13:49 +0000152 case LayerType::BatchNormalization:
153 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100154 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000155 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100156 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
157 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
158 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
159 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
160 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000161 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100162 OverrideDataType(input, dataType),
163 OverrideDataType(output, dataType),
164 OverrideDataType(mean, dataType),
165 OverrideDataType(var, dataType),
166 OverrideDataType(beta, dataType),
167 OverrideDataType(gamma, dataType),
168 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100169 reason);
telsoa014fcda012018-03-09 14:13:49 +0000170 break;
171 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000172 case LayerType::BatchToSpaceNd:
173 {
174 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
175 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100176 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000177
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000178 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
179 OverrideDataType(output, dataType),
180 cLayer->GetParameters(),
181 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000182 break;
183 }
mathad01b392e982021-04-07 12:07:30 +0100184 case LayerType::Cast:
185 {
186 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
187 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
188
189 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
190 OverrideDataType(output, dataType),
191 reason);
192 break;
193 }
Simon Obute51f67772021-09-03 15:50:13 +0100194 case LayerType::ChannelShuffle:
195 {
196 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
197
198 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
199 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
200
201 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
202
203 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
204 OverrideDataType(output, dataType),
205 descriptor,
206 reason);
207 break;
208 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100209 case LayerType::Comparison:
210 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100211 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100212
213 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
214 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
215 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
216
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000217 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
218 OverrideDataType(input1, dataType),
219 OverrideDataType(output, DataType::Boolean),
220 cLayer->GetParameters(),
221 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100222 break;
223 }
telsoa014fcda012018-03-09 14:13:49 +0000224 case LayerType::Constant:
225 {
226 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000227 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100228 break;
229 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000230 case LayerType::ConvertBf16ToFp32:
231 {
232 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
233 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000234 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000235 break;
236 }
telsoa01c577f2c2018-08-31 09:22:23 +0100237 case LayerType::ConvertFp16ToFp32:
238 {
239 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
240 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000241 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100242 break;
243 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000244 case LayerType::ConvertFp32ToBf16:
245 {
246 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
247 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000248 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000249 break;
250 }
telsoa01c577f2c2018-08-31 09:22:23 +0100251 case LayerType::ConvertFp32ToFp16:
252 {
253 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
254 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000255 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000256 break;
257 }
258 case LayerType::Convolution2d:
259 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100260 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100261
262 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
263 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100264 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100265 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
266 "Convolution2dLayer: Weights should be connected as a Constant Layer.");
267 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
268 dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100269
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100270 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100271
arovir01a6824102018-08-28 17:40:45 +0100272 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100273 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100274 if (descriptor.m_BiasEnabled)
275 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100276 ARMNN_ASSERT_MSG(layer.GetInputSlot(2).GetConnection(),
277 "Convolution2dLayer: Bias should be connected as a Constant Layer.");
278 biases = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
279 GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100280 }
281
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000282 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100283 input,
284 output,
285 descriptor,
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100286 weights,
arovir01a6824102018-08-28 17:40:45 +0100287 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100288 reason);
telsoa014fcda012018-03-09 14:13:49 +0000289 break;
290 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100291 case LayerType::Convolution3d:
292 {
293 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
294
295 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
296 dataType);
297 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100298
299 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
300 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
301 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
302 dataType);
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100303
304 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
305
306 // Construct optional biases object based on the value of m_BiasEnabled
307 Optional<TensorInfo> biases;
308 if (descriptor.m_BiasEnabled)
309 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100310 biases = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
311 GetBiasTypeFromWeightsType(dataType));
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100312 }
313
314 result = layerSupportObject.IsConvolution3dSupported(
315 input,
316 output,
317 descriptor,
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100318 weights,
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100319 biases,
320 reason);
321 break;
322 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000323 case LayerType::Debug:
324 {
325 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
326 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
327
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000328 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000329 OverrideDataType(output, dataType),
330 reason);
331 break;
332 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100333 case LayerType::DepthToSpace:
334 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100335 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100336
337 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
338 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
339
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000340 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100341 OverrideDataType(output, dataType),
342 cLayer->GetParameters(),
343 reason);
344 break;
345 }
telsoa014fcda012018-03-09 14:13:49 +0000346 case LayerType::DepthwiseConvolution2d:
347 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100348 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
Cathal Corbett06902652022-04-14 17:55:11 +0100349 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
350 dataType);
351 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
352 const TensorInfo& weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
353 dataType);
354
355 ARMNN_ASSERT(cLayer->GetInputSlot(1).GetConnection() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100356
telsoa01c577f2c2018-08-31 09:22:23 +0100357 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100358
359 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100360 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100361 if (descriptor.m_BiasEnabled)
362 {
Cathal Corbett06902652022-04-14 17:55:11 +0100363 biases = OverrideDataType(cLayer->GetInputSlot(2).GetConnection()->GetTensorInfo(),
364 GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100365 }
telsoa01c577f2c2018-08-31 09:22:23 +0100366
Cathal Corbett06902652022-04-14 17:55:11 +0100367 result = layerSupportObject.IsDepthwiseConvolutionSupported(input,
368 output,
369 descriptor,
370 weights,
371 biases,
372 reason);
telsoa014fcda012018-03-09 14:13:49 +0000373 break;
374 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000375 case LayerType::Dequantize:
376 {
377 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
378 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
379
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000380 result = layerSupportObject.IsDequantizeSupported(input,
381 OverrideDataType(output, dataType),
382 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000383 break;
384 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000385 case LayerType::DetectionPostProcess:
386 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100387 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000388 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
389 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
390 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
391
392 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
393 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
394 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
395 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
396
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000397 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000398 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
399 scores,
400 anchors,
401 detectionBoxes,
402 detectionClasses,
403 detectionScores,
404 numDetections,
405 descriptor,
406 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000407 break;
408 }
josh minor4a3c6102020-01-06 16:40:46 -0600409 case LayerType::ElementwiseUnary:
410 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100411 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600412
413 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
414 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
415
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000416 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
417 OverrideDataType(output, dataType),
418 cLayer->GetParameters(),
419 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600420 break;
421 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100422 case LayerType::Fill:
423 {
424 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
425 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
426 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
427 const FillDescriptor& descriptor = cLayer->GetParameters();
428
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000429 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100430 OverrideDataType(input, dataType),
431 OverrideDataType(output, dataType),
432 descriptor,
433 reason);
434 break;
435 }
telsoa014fcda012018-03-09 14:13:49 +0000436 case LayerType::FakeQuantization:
437 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100438 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000439 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000440 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
441 cLayer->GetParameters(),
442 reason);
telsoa014fcda012018-03-09 14:13:49 +0000443 break;
444 }
445 case LayerType::Floor:
446 {
447 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
448 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000449 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
450 OverrideDataType(output, dataType),
451 reason);
telsoa014fcda012018-03-09 14:13:49 +0000452 break;
453 }
454 case LayerType::FullyConnected:
455 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100456 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000457 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100458 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000459
460 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
461 TensorInfo weightsInfo;
462 const TensorInfo* weightsInfoPtr = nullptr;
463
Matthew Sloyan81beae32021-07-13 19:46:11 +0100464 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000465 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100466
467 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000468 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000469 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100470 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
471 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
472 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
473
telsoa01c577f2c2018-08-31 09:22:23 +0100474 if (descriptor.m_BiasEnabled)
475 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100476 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
477 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100478 }
479 else
480 {
481 // If biases are not enabled pass a dummy tensorinfo for the validation
482 switch(input.GetDataType())
483 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000484 case DataType::BFloat16:
485 {
486 biasInfoPtr = &dummyBFloat16Bias;
487 break;
488 }
telsoa01c577f2c2018-08-31 09:22:23 +0100489 case DataType::Float16:
490 {
491 biasInfoPtr = &dummyFloat16Bias;
492 break;
493 }
494 case DataType::Float32:
495 {
496 biasInfoPtr = &dummyFloat32Bias;
497 break;
498 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000499 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000500 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000501 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000502 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100503 {
504 biasInfoPtr = &dummyQA8Bias;
505 break;
506 }
507 default:
508 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100509 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100510 }
511 }
512 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000513 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100514 OverrideDataType(input, dataType),
515 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000516 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100517 *biasInfoPtr,
518 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100519 reason);
telsoa014fcda012018-03-09 14:13:49 +0000520 break;
521 }
narpra01b89b05f2019-01-16 09:53:09 +0000522 case LayerType::Gather:
523 {
524 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
525 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
526 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100527 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
528 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000529 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
530 input1,
531 OverrideDataType(output, dataType),
532 descriptor,
533 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000534 break;
535 }
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100536 case LayerType::GatherNd:
537 {
538 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
539 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
540 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
541 result = layerSupportObject.IsGatherNdSupported(OverrideDataType(input0, dataType),
542 input1,
543 OverrideDataType(output, dataType),
544 reason);
545 break;
546 }
telsoa014fcda012018-03-09 14:13:49 +0000547 case LayerType::Input:
548 {
549 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000550 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000551 break;
552 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100553 case LayerType::InstanceNormalization:
554 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100555 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100556 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
557
558 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
559 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
560
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000561 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100562 OverrideDataType(input, dataType),
563 OverrideDataType(output, dataType),
564 descriptor,
565 reason);
566 break;
567 }
telsoa014fcda012018-03-09 14:13:49 +0000568 case LayerType::L2Normalization:
569 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100570 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100571 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
572
telsoa014fcda012018-03-09 14:13:49 +0000573 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100574 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100575
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000576 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100577 OverrideDataType(input, dataType),
578 OverrideDataType(output, dataType),
579 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100580 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100581 break;
582 }
James Conroyaba90cd2020-11-06 16:28:18 +0000583 case LayerType::LogicalBinary:
584 {
585 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
586
587 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
588 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
589 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
590
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000591 result = layerSupportObject.IsLogicalBinarySupported(input0,
592 input1,
593 output,
594 cLayer->GetParameters(),
595 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000596 break;
597 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100598 case LayerType::LogSoftmax:
599 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100600 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100601
602 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
603 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
604
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000605 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
606 OverrideDataType(output, dataType),
607 cLayer->GetParameters(),
608 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100609 break;
610 }
telsoa01c577f2c2018-08-31 09:22:23 +0100611 case LayerType::Lstm:
612 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100613 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100614 const LstmDescriptor& descriptor = cLayer->GetParameters();
615
616 // All inputs.
617 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
618 dataType);
619 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
620 dataType);
621 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
622 dataType);
623 // All outputs
624 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
625 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
626 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
627 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
628
629 // Basic parameters
630 const TensorInfo& inputToForgetWeights
631 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
632 const TensorInfo& inputToCellWeights
633 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
634 const TensorInfo& inputToOutputWeights
635 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
636 const TensorInfo& recurrentToForgetWeights
637 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
638 const TensorInfo& recurrentToCellWeights
639 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
640 const TensorInfo& recurrentToOutputWeights
641 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
642 const TensorInfo& forgetGateBias
643 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
644 const TensorInfo& cellBias
645 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
646 const TensorInfo& outputGateBias
647 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
648
Jan Eilersd01a83c2019-07-03 18:20:40 +0100649 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100650
Jan Eilersd01a83c2019-07-03 18:20:40 +0100651 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
652 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
653 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
654 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
655 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
656 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
657 paramsInfo.m_ForgetGateBias = &forgetGateBias;
658 paramsInfo.m_CellBias = &cellBias;
659 paramsInfo.m_OutputGateBias = &outputGateBias;
660
661
662 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100663 TensorInfo optInputToInputWeights;
664 TensorInfo optRecurrentToInputWeights;
665 TensorInfo optCellToInputWeights;
666 TensorInfo optInputGateBias;
667 TensorInfo optProjectionWeights;
668 TensorInfo optProjectionBias;
669 TensorInfo optCellToForgetWeights;
670 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100671 TensorInfo optInputLayerNormWeights;
672 TensorInfo optForgetLayerNormWeights;
673 TensorInfo optCellLayerNormWeights;
674 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100675
676 if(!descriptor.m_CifgEnabled)
677 {
678 optInputToInputWeights =
679 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100680 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100681
682 optRecurrentToInputWeights =
683 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100684 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100685 optInputGateBias =
686 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100687 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100688 }
689
690 if(descriptor.m_ProjectionEnabled)
691 {
692 optProjectionWeights =
693 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100694 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100695 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
696 {
697 optProjectionBias =
698 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100699 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100700 }
701 }
702
703 if(descriptor.m_PeepholeEnabled)
704 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100705 if(!descriptor.m_CifgEnabled)
706 {
707 optCellToInputWeights =
708 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
709 dataType);
710 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
711 }
telsoa01c577f2c2018-08-31 09:22:23 +0100712 optCellToForgetWeights =
713 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100714 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100715 optCellToOutputWeights =
716 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100717 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100718 }
719
Jan Eilers38e05bd2019-06-26 13:10:09 +0100720 if(descriptor.m_LayerNormEnabled)
721 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100722 if (!descriptor.m_CifgEnabled)
723 {
724 optInputLayerNormWeights = OverrideDataType(
725 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
726 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
727 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100728
729 optForgetLayerNormWeights = OverrideDataType(
730 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100731 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100732
733 optCellLayerNormWeights = OverrideDataType(
734 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100735 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100736
737 optOutputLayerNormWeights = OverrideDataType(
738 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100739 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100740 }
741
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000742 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100743 input,
744 outputStateIn,
745 cellStateIn,
746 scratchBuffer,
747 outputStateOut,
748 cellStateOut,
749 output,
750 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100751 paramsInfo,
752 reason);
telsoa014fcda012018-03-09 14:13:49 +0000753 break;
754 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000755 case LayerType::Maximum:
756 {
757 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
758 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
759 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
760
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000761 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
762 OverrideDataType(input1, dataType),
763 OverrideDataType(output, dataType),
764 reason);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000765 break;
766 }
narpra01b89b05f2019-01-16 09:53:09 +0000767 case LayerType::MemCopy:
768 {
769 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
770 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000771
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000772 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
773 OverrideDataType(output, dataType),
774 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000775 break;
776 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100777 case LayerType::MemImport:
778 {
779 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
780 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
781
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000782 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
783 OverrideDataType(output, dataType),
784 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100785 break;
786 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100787 case LayerType::Merge:
788 {
789 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
790 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
791 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
792
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000793 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
794 OverrideDataType(input1, dataType),
795 OverrideDataType(output, dataType),
796 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100797 break;
798 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100799 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000800 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100801 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000802
telsoa01c577f2c2018-08-31 09:22:23 +0100803 // Get vector of all inputs.
804 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000805 {
telsoa01c577f2c2018-08-31 09:22:23 +0100806 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000807 };
Finn Williams3e54d032020-10-22 16:53:35 +0100808
809 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
810 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100811 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000812
telsoa01c577f2c2018-08-31 09:22:23 +0100813 auto getTensorInfoPtr = [](const TensorInfo& info)
814 {
815 return &info;
816 };
Finn Williams3e54d032020-10-22 16:53:35 +0100817
818 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
819 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100820 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000821
Nikhil Raj8599a412018-11-19 14:51:07 +0000822 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
823
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000824 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Jim Flynne242f2d2019-05-22 14:24:13 +0100825
826
telsoa014fcda012018-03-09 14:13:49 +0000827 break;
828 }
829 case LayerType::Multiplication:
830 {
831 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
832 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100833 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000834 result = layerSupportObject.IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100835 OverrideDataType(input0, dataType),
836 OverrideDataType(input1, dataType),
837 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100838 reason);
telsoa014fcda012018-03-09 14:13:49 +0000839 break;
840 }
841 case LayerType::Normalization:
842 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100843 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000844 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
845 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000846 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
847 OverrideDataType(output, dataType),
848 cLayer->GetParameters(),
849 reason);
telsoa014fcda012018-03-09 14:13:49 +0000850 break;
851 }
852 case LayerType::Output:
853 {
854 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000855 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000856 break;
857 }
858 case LayerType::Permute:
859 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100860 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000861 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
862 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000863 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
864 OverrideDataType(output, dataType),
865 cLayer->GetParameters(),
866 reason);
telsoa014fcda012018-03-09 14:13:49 +0000867 break;
868 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100869 case LayerType::Pad:
870 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100871 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100872 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
873 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000874 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100875 OverrideDataType(input, dataType),
876 OverrideDataType(output, dataType),
877 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100878 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100879 break;
880 }
telsoa014fcda012018-03-09 14:13:49 +0000881 case LayerType::Pooling2d:
882 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100883 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000884 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
885 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000886 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
887 OverrideDataType(output, dataType),
888 cLayer->GetParameters(),
889 reason);
telsoa014fcda012018-03-09 14:13:49 +0000890 break;
891 }
Tamás Nyíri7b885b32021-10-26 14:47:57 +0100892 case LayerType::Pooling3d:
893 {
894 auto cLayer = PolymorphicDowncast<const Pooling3dLayer*>(&layer);
895 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
896 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
897 result = layerSupportObject.IsPooling3dSupported(OverrideDataType(input, dataType),
898 OverrideDataType(output, dataType),
899 cLayer->GetParameters(),
900 reason);
901 break;
902 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000903 case LayerType::PreCompiled:
904 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100905 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000906 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000907 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
908 cLayer->GetParameters(),
909 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000910 break;
911 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000912 case LayerType::Quantize:
913 {
914 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
915 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000916 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000917 break;
918 }
James Conroy586a9aa2020-03-20 08:49:33 +0000919 case LayerType::QLstm:
920 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100921 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000922 const QLstmDescriptor& descriptor = cLayer->GetParameters();
923
924 // Inputs
925 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
926 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
927 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
928
929 // Outputs
930 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
931 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
932 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
933
934 // Lstm parameters
935 LstmInputParamsInfo paramsInfo;
936
937 // Basic parameters
Matthew Bentham6f24b1a2021-06-29 15:18:32 +0100938 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
939 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
940 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
James Conroy586a9aa2020-03-20 08:49:33 +0000941 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
942 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
943 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
944
945 paramsInfo.m_RecurrentToForgetWeights =
946 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
947 paramsInfo.m_RecurrentToCellWeights =
948 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
949 paramsInfo.m_RecurrentToOutputWeights =
950 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
951
952 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
953 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
954 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
955
956 if(!descriptor.m_CifgEnabled)
957 {
958 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
959 paramsInfo.m_RecurrentToInputWeights =
960 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
961 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
962 }
963
964 if(descriptor.m_ProjectionEnabled)
965 {
966 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +0100967
968 // Projection bias is optional even if projection is enabled
969 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
970 {
971 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
972 }
James Conroy586a9aa2020-03-20 08:49:33 +0000973 }
974
975 if(descriptor.m_PeepholeEnabled)
976 {
977 if (!descriptor.m_CifgEnabled)
978 {
979 paramsInfo.m_CellToInputWeights =
980 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
981 }
982
983 paramsInfo.m_CellToForgetWeights =
984 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
985 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
986 }
987
988 if(descriptor.m_LayerNormEnabled)
989 {
990 if (!descriptor.m_CifgEnabled)
991 {
992 paramsInfo.m_InputLayerNormWeights =
993 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
994 }
995
996 paramsInfo.m_ForgetLayerNormWeights =
997 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
998 paramsInfo.m_CellLayerNormWeights =
999 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
1000 paramsInfo.m_OutputLayerNormWeights =
1001 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
1002 }
1003
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001004 result = layerSupportObject.IsQLstmSupported(input,
1005 previousOutputIn,
1006 previousCellStateIn,
1007 outputStateOut,
1008 cellStateOut,
1009 output,
1010 descriptor,
1011 paramsInfo,
1012 reason);
James Conroy586a9aa2020-03-20 08:49:33 +00001013 break;
1014 }
James Conroyee18dc82019-07-17 11:27:46 +01001015 case LayerType::QuantizedLstm:
1016 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001017 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +01001018
1019 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001020 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1021 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1022 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001023
1024 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001025 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
1026 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001027
1028 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +01001029 QuantizedLstmInputParamsInfo paramsInfo;
1030
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001031 paramsInfo.m_InputToInputWeights =
1032 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
1033 paramsInfo.m_InputToForgetWeights =
1034 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
1035 paramsInfo.m_InputToCellWeights =
1036 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
1037 paramsInfo.m_InputToOutputWeights =
1038 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001039
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001040 paramsInfo.m_RecurrentToInputWeights =
1041 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
1042 paramsInfo.m_RecurrentToForgetWeights =
1043 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
1044 paramsInfo.m_RecurrentToCellWeights =
1045 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
1046 paramsInfo.m_RecurrentToOutputWeights =
1047 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001048
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001049 paramsInfo.m_InputGateBias =
1050 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
1051 paramsInfo.m_ForgetGateBias =
1052 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
1053 paramsInfo.m_CellBias =
1054 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
1055 paramsInfo.m_OutputGateBias =
1056 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +01001057
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001058 result = layerSupportObject.IsQuantizedLstmSupported(input,
1059 previousCellStateIn,
1060 previousOutputIn,
1061 cellStateOut,
1062 output,
1063 paramsInfo,
1064 reason);
James Conroyee18dc82019-07-17 11:27:46 +01001065 break;
1066 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001067 case LayerType::Division:
1068 {
1069 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1070 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1071 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001072 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001073 OverrideDataType(input0, dataType),
1074 OverrideDataType(input1, dataType),
1075 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001076 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001077 break;
1078 }
Finn Williams2605b232020-06-10 15:53:46 +01001079 case LayerType::Rank:
1080 {
1081 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1082 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001083 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
1084 OverrideDataType(output, dataType),
1085 reason);
Finn Williams2605b232020-06-10 15:53:46 +01001086 break;
1087 }
telsoa014fcda012018-03-09 14:13:49 +00001088 case LayerType::Reshape:
1089 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001090 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001091 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +00001092 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001093 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
1094 OverrideDataType(output, dataType),
1095 cLayer->GetParameters(),
1096 reason);
telsoa014fcda012018-03-09 14:13:49 +00001097 break;
1098 }
Teresa Charlina9075df2019-06-27 15:41:57 +01001099 case LayerType::Resize:
1100 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001101 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +01001102 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +01001103 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001104 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1105 OverrideDataType(output, dataType),
1106 cLayer->GetParameters(),
1107 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +01001108 break;
1109 }
Keith Davis3ae3f972021-05-21 16:33:48 +01001110 case LayerType::Shape:
1111 {
1112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1114
1115 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1116 OverrideDataType(output, dataType),
1117 reason);
1118 break;
1119 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001120 case LayerType::Slice:
1121 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001122 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001123
1124 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1125 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1126
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001127 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1128 OverrideDataType(output, dataType),
1129 cLayer->GetParameters(),
1130 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001131 break;
1132 }
telsoa014fcda012018-03-09 14:13:49 +00001133 case LayerType::Softmax:
1134 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001135 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001136 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001137 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001138 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1139 OverrideDataType(output, dataType),
1140 cLayer->GetParameters(),
1141 reason);
telsoa014fcda012018-03-09 14:13:49 +00001142 break;
1143 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001144 case LayerType::SpaceToBatchNd:
1145 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001146 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001147 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1148 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001149 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1150 OverrideDataType(output, dataType),
1151 cLayer->GetParameters(),
1152 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001153 break;
1154 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001155 case LayerType::SpaceToDepth:
1156 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001157 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001158
1159 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1160 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1161
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001162 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1163 OverrideDataType(output, dataType),
1164 cLayer->GetParameters(),
1165 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001166 break;
1167 }
telsoa014fcda012018-03-09 14:13:49 +00001168 case LayerType::Splitter:
1169 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001170 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001171 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001172
1173 // Get vector of all outputs.
1174 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1175 {
1176 return OverrideDataType(slot.GetTensorInfo(), dataType);
1177 };
Finn Williams3e54d032020-10-22 16:53:35 +01001178 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1179 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001180 std::vector<TensorInfo> outputs(beginI, endI);
1181
1182 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1183
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001184 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1185 outputPtrs,
1186 cLayer->GetParameters(),
1187 reason);
telsoa014fcda012018-03-09 14:13:49 +00001188 break;
1189 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001190 case LayerType::Stack:
1191 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001192 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001193
1194 // Get vector of all inputs.
1195 auto getTensorInfo = [&dataType](const InputSlot& slot)
1196 {
1197 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1198 };
Finn Williams3e54d032020-10-22 16:53:35 +01001199 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1200 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001201 std::vector<TensorInfo> inputs(beginI, endI);
1202
1203 auto getTensorInfoPtr = [](const TensorInfo& info)
1204 {
1205 return &info;
1206 };
Finn Williams3e54d032020-10-22 16:53:35 +01001207 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1208 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001209 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1210
1211 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1212
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001213 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001214
1215 break;
1216 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001217 case LayerType::StandIn:
1218 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001219 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001220
1221 // Get vector of all inputs.
1222 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1223 {
1224 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1225 };
1226 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1227 {
1228 return OverrideDataType(slot.GetTensorInfo(), dataType);
1229 };
Finn Williams3e54d032020-10-22 16:53:35 +01001230 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1231 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001232 std::vector<TensorInfo> inputs(beginI, endI);
1233
Finn Williams3e54d032020-10-22 16:53:35 +01001234 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1235 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001236 std::vector<TensorInfo> outputs(beginO, endO);
1237
1238
1239 auto getTensorInfoPtr = [](const TensorInfo& info)
1240 {
1241 return &info;
1242 };
Finn Williams3e54d032020-10-22 16:53:35 +01001243 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1244 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001245 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1246
Finn Williams3e54d032020-10-22 16:53:35 +01001247 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1248 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001249 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1250
1251
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001252 result = layerSupportObject.IsStandInSupported(inputPtrs,
1253 outputPtrs,
1254 cLayer->GetParameters(),
1255 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001256 break;
1257 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001258 case LayerType::StridedSlice:
1259 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001260 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001261 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1262 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001263 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1264 OverrideDataType(output, dataType),
1265 cLayer->GetParameters(),
1266 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001267 break;
1268 }
David Beckc2044fe2018-09-05 15:00:38 +01001269 case LayerType::Subtraction:
1270 {
1271 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1272 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1273 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001274 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001275 OverrideDataType(input0, dataType),
1276 OverrideDataType(input1, dataType),
1277 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001278 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001279 break;
1280 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001281 case LayerType::Switch:
1282 {
1283 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1284 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1285 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1286 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001287 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1288 OverrideDataType(input1, dataType),
1289 OverrideDataType(output0, dataType),
1290 OverrideDataType(output1, dataType),
1291 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001292 break;
1293 }
narpra0132b90462018-09-13 11:07:48 +01001294 case LayerType::Mean:
1295 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001296 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001297 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1298 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001299 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001300 OverrideDataType(input, dataType),
1301 OverrideDataType(output, dataType),
1302 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001303 reason);
narpra0132b90462018-09-13 11:07:48 +01001304 break;
1305 }
kevmay0190539692018-11-29 08:40:19 +00001306 case LayerType::Minimum:
1307 {
1308 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1309 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1310 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001311 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1312 OverrideDataType(input1, dataType),
1313 OverrideDataType(output, dataType),
1314 reason);
kevmay0190539692018-11-29 08:40:19 +00001315 break;
1316 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001317 case LayerType::Prelu:
1318 {
1319 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1320 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1321 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001322 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1323 OverrideDataType(alpha, dataType),
1324 OverrideDataType(output, dataType),
1325 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001326 break;
1327 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001328 case LayerType::Transpose:
1329 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001330 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001331 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1332 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001333 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1334 OverrideDataType(output, dataType),
1335 cLayer->GetParameters(),
1336 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001337 break;
1338 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001339 case LayerType::TransposeConvolution2d:
1340 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001341 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001342
1343 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1344 dataType);
1345 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1346
1347 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1348
1349 Optional<TensorInfo> biases;
1350 if (descriptor.m_BiasEnabled)
1351 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001352 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001353 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1354 GetBiasTypeFromWeightsType(dataType));
1355 }
1356
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001357 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001358 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1359
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001360 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1361 output,
1362 descriptor,
1363 weights,
1364 biases,
1365 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001366
1367 break;
1368 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001369 case LayerType::Reduce:
1370 {
1371 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1372 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1373 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1374
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001375 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1376 OverrideDataType(output, dataType),
1377 cLayer->GetParameters(),
1378 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001379 break;
1380 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001381 case LayerType::UnidirectionalSequenceLstm:
1382 {
1383 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1384 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1385
1386 // All inputs.
1387 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1388 dataType);
1389 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1390 dataType);
1391 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1392 dataType);
1393 // Outputs
Mike Kelly12994962022-04-21 11:57:09 +01001394 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1395 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
1396 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001397
1398 // Basic parameters
1399 const TensorInfo& inputToForgetWeights
1400 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1401 const TensorInfo& inputToCellWeights
1402 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1403 const TensorInfo& inputToOutputWeights
1404 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1405 const TensorInfo& recurrentToForgetWeights
1406 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1407 const TensorInfo& recurrentToCellWeights
1408 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1409 const TensorInfo& recurrentToOutputWeights
1410 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1411 const TensorInfo& forgetGateBias
1412 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1413 const TensorInfo& cellBias
1414 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1415 const TensorInfo& outputGateBias
1416 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1417
1418 LstmInputParamsInfo paramsInfo;
1419
1420 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1421 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1422 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1423 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1424 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1425 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1426 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1427 paramsInfo.m_CellBias = &cellBias;
1428 paramsInfo.m_OutputGateBias = &outputGateBias;
1429
1430 // Optional parameters
1431 TensorInfo optInputToInputWeights;
1432 TensorInfo optRecurrentToInputWeights;
1433 TensorInfo optCellToInputWeights;
1434 TensorInfo optInputGateBias;
1435 TensorInfo optProjectionWeights;
1436 TensorInfo optProjectionBias;
1437 TensorInfo optCellToForgetWeights;
1438 TensorInfo optCellToOutputWeights;
1439 TensorInfo optInputLayerNormWeights;
1440 TensorInfo optForgetLayerNormWeights;
1441 TensorInfo optCellLayerNormWeights;
1442 TensorInfo optOutputLayerNormWeights;
1443
1444 if(!descriptor.m_CifgEnabled)
1445 {
1446 optInputToInputWeights =
1447 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1448 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1449
1450 optRecurrentToInputWeights =
1451 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1452 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1453 optInputGateBias =
1454 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1455 paramsInfo.m_InputGateBias = &optInputGateBias;
1456 }
1457
1458 if(descriptor.m_ProjectionEnabled)
1459 {
1460 optProjectionWeights =
1461 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1462 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1463 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1464 {
1465 optProjectionBias =
1466 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1467 paramsInfo.m_ProjectionBias = &optProjectionBias;
1468 }
1469 }
1470
1471 if(descriptor.m_PeepholeEnabled)
1472 {
1473 if(!descriptor.m_CifgEnabled)
1474 {
1475 optCellToInputWeights =
1476 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1477 dataType);
1478 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1479 }
1480 optCellToForgetWeights =
1481 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1482 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1483 optCellToOutputWeights =
1484 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1485 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1486 }
1487
1488 if(descriptor.m_LayerNormEnabled)
1489 {
1490 if (!descriptor.m_CifgEnabled)
1491 {
1492 optInputLayerNormWeights = OverrideDataType(
1493 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1494 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1495 }
1496
1497 optForgetLayerNormWeights = OverrideDataType(
1498 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1499 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1500
1501 optCellLayerNormWeights = OverrideDataType(
1502 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1503 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1504
1505 optOutputLayerNormWeights = OverrideDataType(
1506 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1507 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1508 }
1509
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001510 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1511 outputStateIn,
1512 cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01001513 outputStateOut,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001514 cellStateOut,
Mike Kelly12994962022-04-21 11:57:09 +01001515 output,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001516 descriptor,
1517 paramsInfo,
1518 reason);
1519 break;
1520 }
telsoa014fcda012018-03-09 14:13:49 +00001521 default:
1522 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001523 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001524 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001525 result = false;
1526 break;
1527 }
1528 }
telsoa014fcda012018-03-09 14:13:49 +00001529 return result;
1530}
1531
Sadik Armagan045f6be2020-09-10 13:37:32 +01001532bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1533 const IConnectableLayer& connectableLayer,
1534 Optional<DataType> dataType,
1535 std::string& outReasonIfUnsupported)
1536{
1537 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1538}
1539
David Beckdcb751f2018-10-03 11:42:42 +01001540bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001541 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001542 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001543{
Jan Eilersbb446e52020-04-02 13:56:54 +01001544 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001545 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1546}
1547
1548// TODO merge with defaulted modelOptions above
1549bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1550 Optional<DataType> dataType,
1551 std::string& outReasonIfUnsupported,
1552 const ModelOptions& modelOptions)
1553{
1554 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1555 return IsLayerConfigurationSupported(layer->GetBackendId(),
1556 connectableLayer,
1557 dataType,
1558 outReasonIfUnsupported,
1559 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001560}
1561
Sadik Armagan04a72972020-09-14 15:44:18 +01001562bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1563 const IConnectableLayer& connectableLayer,
1564 Optional<DataType> dataType,
1565 std::string& outReasonIfUnsupported,
1566 const ModelOptions& modelOptions)
1567{
1568 return IsLayerConfigurationSupported(backendId,
1569 connectableLayer,
1570 dataType,
1571 outReasonIfUnsupported,
1572 modelOptions);
1573}
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001574ARMNN_NO_DEPRECATE_WARN_BEGIN
1575std::unique_ptr<IWorkload> IWorkloadFactory::CreateWorkload(LayerType type,
1576 const QueueDescriptor& descriptor,
1577 const WorkloadInfo& info) const
1578{
1579 switch(type)
1580 {
1581 case LayerType::Activation :
1582 {
1583 auto activationQueueDescriptor = PolymorphicDowncast<const ActivationQueueDescriptor*>(&descriptor);
1584 return CreateActivation(*activationQueueDescriptor, info);
1585 }
1586 case LayerType::Addition :
1587 {
1588 auto additionQueueDescriptor = PolymorphicDowncast<const AdditionQueueDescriptor*>(&descriptor);
1589 return CreateAddition(*additionQueueDescriptor, info);
1590 }
1591 case LayerType::ArgMinMax :
1592 {
1593 auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor);
1594 return CreateArgMinMax(*argMinMaxQueueDescriptor, info);
1595 }
1596 case LayerType::BatchNormalization :
1597 {
1598 auto batchNormQueueDescriptor = PolymorphicDowncast<const BatchNormalizationQueueDescriptor*>(&descriptor);
1599 return CreateBatchNormalization(*batchNormQueueDescriptor, info);
1600 }
1601 case LayerType::BatchToSpaceNd :
1602 {
1603 auto batchToSpaceNdQueueDescriptor
1604 = PolymorphicDowncast<const BatchToSpaceNdQueueDescriptor*>(&descriptor);
1605 return CreateBatchToSpaceNd(*batchToSpaceNdQueueDescriptor, info);
1606 }
1607 case LayerType::Cast :
1608 {
1609 auto castQueueDescriptor = PolymorphicDowncast<const CastQueueDescriptor*>(&descriptor);
1610 return CreateCast(*castQueueDescriptor, info);
1611 }
1612 case LayerType::ChannelShuffle :
1613 {
1614 auto channelShuffleQueueDescriptor
1615 = PolymorphicDowncast<const ChannelShuffleQueueDescriptor*>(&descriptor);
1616 return CreateChannelShuffle(*channelShuffleQueueDescriptor, info);
1617 }
1618 case LayerType::Comparison :
1619 {
1620 auto comparisonQueueDescriptor = PolymorphicDowncast<const ComparisonQueueDescriptor*>(&descriptor);
1621 return CreateComparison(*comparisonQueueDescriptor, info);
1622 }
1623 case LayerType::Concat :
1624 {
1625 auto concatQueueDescriptor = PolymorphicDowncast<const ConcatQueueDescriptor*>(&descriptor);
1626 return CreateConcat(*concatQueueDescriptor, info);
1627 }
1628 case LayerType::Constant :
1629 {
1630 auto constantQueueDescriptor = PolymorphicDowncast<const ConstantQueueDescriptor*>(&descriptor);
1631 return CreateConstant(*constantQueueDescriptor, info);
1632 }
1633 case LayerType::ConvertBf16ToFp32 :
1634 {
1635 auto convertBf16ToFp32QueueDescriptor
1636 = PolymorphicDowncast<const ConvertBf16ToFp32QueueDescriptor*>(&descriptor);
1637 return CreateConvertBf16ToFp32(*convertBf16ToFp32QueueDescriptor, info);
1638 }
1639 case LayerType::ConvertFp16ToFp32:
1640 {
1641 auto convertFp16ToFp32QueueDescriptor
1642 = PolymorphicDowncast<const ConvertFp16ToFp32QueueDescriptor*>(&descriptor);
1643 return CreateConvertFp16ToFp32(*convertFp16ToFp32QueueDescriptor, info);
1644 }
1645 case LayerType::ConvertFp32ToBf16:
1646 {
1647 auto convertFp32ToBf16QueueDescriptor
1648 = PolymorphicDowncast<const ConvertFp32ToBf16QueueDescriptor*>(&descriptor);
1649 return CreateConvertFp32ToBf16(*convertFp32ToBf16QueueDescriptor, info);
1650 }
1651 case LayerType::ConvertFp32ToFp16:
1652 {
1653 auto convertFp32ToFp16QueueDescriptor
1654 = PolymorphicDowncast<const ConvertFp32ToFp16QueueDescriptor*>(&descriptor);
1655 return CreateConvertFp32ToFp16(*convertFp32ToFp16QueueDescriptor, info);
1656 }
1657 case LayerType::Convolution2d:
1658 {
1659 auto convolution2dQueueDescriptor = PolymorphicDowncast<const Convolution2dQueueDescriptor*>(&descriptor);
1660 return CreateConvolution2d(*convolution2dQueueDescriptor, info);
1661 }
1662 case LayerType::Convolution3d:
1663 {
1664 auto convolution3dQueueDescriptor = PolymorphicDowncast<const Convolution3dQueueDescriptor*>(&descriptor);
1665 return CreateConvolution3d(*convolution3dQueueDescriptor, info);
1666 }
1667 case LayerType::Debug:
1668 {
1669 auto debugQueueDescriptor = PolymorphicDowncast<const DebugQueueDescriptor*>(&descriptor);
1670 return CreateDebug(*debugQueueDescriptor, info);
1671 }
1672 case LayerType::DepthToSpace:
1673 {
1674 auto depthToSpaceQueueDescriptor = PolymorphicDowncast<const DepthToSpaceQueueDescriptor*>(&descriptor);
1675 return CreateDepthToSpace(*depthToSpaceQueueDescriptor, info);
1676 }
1677 case LayerType::DepthwiseConvolution2d:
1678 {
1679 auto depthwiseConvolution2DQueueDescriptor
1680 = PolymorphicDowncast<const DepthwiseConvolution2dQueueDescriptor*>(&descriptor);
1681 return CreateDepthwiseConvolution2d(*depthwiseConvolution2DQueueDescriptor, info);
1682 }
1683 case LayerType::Dequantize:
1684 {
1685 auto dequantizeQueueDescriptor = PolymorphicDowncast<const DequantizeQueueDescriptor*>(&descriptor);
1686 return CreateDequantize(*dequantizeQueueDescriptor, info);
1687 }
1688 case LayerType::DetectionPostProcess:
1689 {
1690 auto detectionPostProcessQueueDescriptor
1691 = PolymorphicDowncast<const DetectionPostProcessQueueDescriptor*>(&descriptor);
1692 return CreateDetectionPostProcess(*detectionPostProcessQueueDescriptor, info);
1693 }
1694 case LayerType::Division:
1695 {
1696 auto divisionQueueDescriptor = PolymorphicDowncast<const DivisionQueueDescriptor*>(&descriptor);
1697 return CreateDivision(*divisionQueueDescriptor, info);
1698 }
1699 case LayerType::ElementwiseUnary:
1700 {
1701 auto elementwiseUnaryQueueDescriptor
1702 = PolymorphicDowncast<const ElementwiseUnaryQueueDescriptor*>(&descriptor);
1703 return CreateElementwiseUnary(*elementwiseUnaryQueueDescriptor, info);
1704
1705 }
1706 case LayerType::FakeQuantization:
1707 {
1708 auto fakeQuantizationQueueDescriptor
1709 = PolymorphicDowncast<const FakeQuantizationQueueDescriptor*>(&descriptor);
1710 return CreateFakeQuantization(*fakeQuantizationQueueDescriptor, info);
1711 }
1712 case LayerType::Fill:
1713 {
1714 auto fillQueueDescriptor = PolymorphicDowncast<const FillQueueDescriptor*>(&descriptor);
1715 return CreateFill(*fillQueueDescriptor, info);
1716 }
1717 case LayerType::Floor:
1718 {
1719 auto floorQueueDescriptor = PolymorphicDowncast<const FloorQueueDescriptor*>(&descriptor);
1720 return CreateFloor(*floorQueueDescriptor, info);
1721 }
1722 case LayerType::FullyConnected:
1723 {
1724 auto fullyConnectedQueueDescriptor
1725 = PolymorphicDowncast<const FullyConnectedQueueDescriptor*>(&descriptor);
1726 return CreateFullyConnected(*fullyConnectedQueueDescriptor, info);
1727 }
1728 case LayerType::Gather:
1729 {
1730 auto gatherQueueDescriptor = PolymorphicDowncast<const GatherQueueDescriptor*>(&descriptor);
1731 return CreateGather(*gatherQueueDescriptor, info);
1732 }
1733 case LayerType::Input:
1734 {
1735 auto inputQueueDescriptor = PolymorphicDowncast<const InputQueueDescriptor*>(&descriptor);
1736 return CreateInput(*inputQueueDescriptor, info);
1737 }
1738 case LayerType::InstanceNormalization:
1739 {
1740 auto instanceNormalizationQueueDescriptor
1741 = PolymorphicDowncast<const InstanceNormalizationQueueDescriptor*>(&descriptor);
1742 return CreateInstanceNormalization(*instanceNormalizationQueueDescriptor, info);
1743 }
1744 case LayerType::L2Normalization:
1745 {
1746 auto l2NormalizationQueueDescriptor
1747 = PolymorphicDowncast<const L2NormalizationQueueDescriptor*>(&descriptor);
1748 return CreateL2Normalization(*l2NormalizationQueueDescriptor, info);
1749 }
1750 case LayerType::LogicalBinary:
1751 {
1752 auto logicalBinaryQueueDescriptor = PolymorphicDowncast<const LogicalBinaryQueueDescriptor*>(&descriptor);
1753 return CreateLogicalBinary(*logicalBinaryQueueDescriptor, info);
1754 }
1755 case LayerType::LogSoftmax:
1756 {
1757 auto logSoftmaxQueueDescriptor = PolymorphicDowncast<const LogSoftmaxQueueDescriptor*>(&descriptor);
1758 return CreateLogSoftmax(*logSoftmaxQueueDescriptor, info);
1759 }
1760 case LayerType::Lstm:
1761 {
1762 auto lstmQueueDescriptor = PolymorphicDowncast<const LstmQueueDescriptor*>(&descriptor);
1763 return CreateLstm(*lstmQueueDescriptor, info);
1764 }
1765 case LayerType::Maximum:
1766 {
1767 auto maximumQueueDescriptor = PolymorphicDowncast<const MaximumQueueDescriptor*>(&descriptor);
1768 return CreateMaximum(*maximumQueueDescriptor, info);
1769 }
1770 case LayerType::Mean:
1771 {
1772 auto meanQueueDescriptor = PolymorphicDowncast<const MeanQueueDescriptor*>(&descriptor);
1773 return CreateMean(*meanQueueDescriptor, info);
1774 }
1775 case LayerType::MemCopy:
1776 {
1777 auto memCopyQueueDescriptor = PolymorphicDowncast<const MemCopyQueueDescriptor*>(&descriptor);
1778 return CreateMemCopy(*memCopyQueueDescriptor, info);
1779 }
1780 case LayerType::MemImport:
1781 {
1782 auto memImportQueueDescriptor = PolymorphicDowncast<const MemImportQueueDescriptor*>(&descriptor);
1783 return CreateMemImport(*memImportQueueDescriptor, info);
1784 }
1785 case LayerType::Minimum:
1786 {
1787 auto minimumQueueDescriptor = PolymorphicDowncast<const MinimumQueueDescriptor*>(&descriptor);
1788 return CreateMinimum(*minimumQueueDescriptor, info);
1789 }
1790 case LayerType::Multiplication:
1791 {
1792 auto multiplicationQueueDescriptor
1793 = PolymorphicDowncast<const MultiplicationQueueDescriptor*>(&descriptor);
1794 return CreateMultiplication(*multiplicationQueueDescriptor, info);
1795 }
1796 case LayerType::Normalization:
1797 {
1798 auto normalizationQueueDescriptor = PolymorphicDowncast<const NormalizationQueueDescriptor*>(&descriptor);
1799 return CreateNormalization(*normalizationQueueDescriptor, info);
1800 }
1801 case LayerType::Output:
1802 {
1803 auto outputQueueDescriptor = PolymorphicDowncast<const OutputQueueDescriptor*>(&descriptor);
1804 return CreateOutput(*outputQueueDescriptor, info);
1805 }
1806 case LayerType::Pad:
1807 {
1808 auto padQueueDescriptor = PolymorphicDowncast<const PadQueueDescriptor*>(&descriptor);
1809 return CreatePad(*padQueueDescriptor, info);
1810 }
1811 case LayerType::Permute:
1812 {
1813 auto permuteQueueDescriptor = PolymorphicDowncast<const PermuteQueueDescriptor*>(&descriptor);
1814 return CreatePermute(*permuteQueueDescriptor, info);
1815 }
1816 case LayerType::Pooling2d:
1817 {
1818 auto pooling2dQueueDescriptor = PolymorphicDowncast<const Pooling2dQueueDescriptor*>(&descriptor);
1819 return CreatePooling2d(*pooling2dQueueDescriptor, info);
1820 }
1821 case LayerType::Pooling3d:
1822 {
1823 auto pooling3dQueueDescriptor = PolymorphicDowncast<const Pooling3dQueueDescriptor*>(&descriptor);
1824 return CreatePooling3d(*pooling3dQueueDescriptor, info);
1825 }
1826 case LayerType::PreCompiled:
1827 {
1828 auto preCompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
1829 return CreatePreCompiled(*preCompiledQueueDescriptor, info);
1830 }
1831 case LayerType::Prelu:
1832 {
1833 auto preluQueueDescriptor = PolymorphicDowncast<const PreluQueueDescriptor*>(&descriptor);
1834 return CreatePrelu(*preluQueueDescriptor, info);
1835 }
1836 case LayerType::QLstm:
1837 {
1838 auto qlstmQueueDescriptor = PolymorphicDowncast<const QLstmQueueDescriptor*>(&descriptor);
1839 return CreateQLstm(*qlstmQueueDescriptor, info);
1840 }
1841 case LayerType::Quantize:
1842 {
1843 auto quantizeQueueDescriptor = PolymorphicDowncast<const QuantizeQueueDescriptor*>(&descriptor);
1844 return CreateQuantize(*quantizeQueueDescriptor, info);
1845 }
1846 case LayerType::Rank:
1847 {
1848 auto rankQueueDescriptor = PolymorphicDowncast<const RankQueueDescriptor*>(&descriptor);
1849 return CreateRank(*rankQueueDescriptor, info);
1850 }
1851 case LayerType::Reduce:
1852 {
1853 auto reduceQueueDescriptor = PolymorphicDowncast<const ReduceQueueDescriptor*>(&descriptor);
1854 return CreateReduce(*reduceQueueDescriptor, info);
1855 }
1856 case LayerType::Reshape:
1857 {
1858 auto reshapeQueueDescriptor = PolymorphicDowncast<const ReshapeQueueDescriptor*>(&descriptor);
1859 return CreateReshape(*reshapeQueueDescriptor, info);
1860 }
1861 case LayerType::Resize:
1862 {
1863 auto resizeQueueDescriptor = PolymorphicDowncast<const ResizeQueueDescriptor*>(&descriptor);
1864 return CreateResize(*resizeQueueDescriptor, info);
1865 }
1866 case LayerType::Shape:
1867 {
1868 auto shapeQueueDescriptor = PolymorphicDowncast<const ShapeQueueDescriptor*>(&descriptor);
1869 return CreateShape(*shapeQueueDescriptor, info);
1870 }
1871 case LayerType::Slice:
1872 {
1873 auto sliceQueueDescriptor = PolymorphicDowncast<const SliceQueueDescriptor*>(&descriptor);
1874 return CreateSlice(*sliceQueueDescriptor, info);
1875 }
1876 case LayerType::Softmax:
1877 {
1878 auto softmaxQueueDescriptor = PolymorphicDowncast<const SoftmaxQueueDescriptor*>(&descriptor);
1879 return CreateSoftmax(*softmaxQueueDescriptor, info);
1880 }
1881 case LayerType::SpaceToBatchNd:
1882 {
1883 auto spaceToBatchNdQueueDescriptor
1884 = PolymorphicDowncast<const SpaceToBatchNdQueueDescriptor*>(&descriptor);
1885 return CreateSpaceToBatchNd(*spaceToBatchNdQueueDescriptor, info);
1886 }
1887 case LayerType::SpaceToDepth:
1888 {
1889 auto spaceToDepthQueueDescriptor = PolymorphicDowncast<const SpaceToDepthQueueDescriptor*>(&descriptor);
1890 return CreateSpaceToDepth(*spaceToDepthQueueDescriptor, info);
1891 }
1892 case LayerType::Splitter:
1893 {
1894 auto splitterQueueDescriptor = PolymorphicDowncast<const SplitterQueueDescriptor*>(&descriptor);
1895 return CreateSplitter(*splitterQueueDescriptor, info);
1896 }
1897 case LayerType::Stack:
1898 {
1899 auto stackQueueDescriptor = PolymorphicDowncast<const StackQueueDescriptor*>(&descriptor);
1900 return CreateStack(*stackQueueDescriptor, info);
1901 }
1902 case LayerType::StridedSlice:
1903 {
1904 auto stridedSliceQueueDescriptor = PolymorphicDowncast<const StridedSliceQueueDescriptor*>(&descriptor);
1905 return CreateStridedSlice(*stridedSliceQueueDescriptor, info);
1906 }
1907 case LayerType::Subtraction:
1908 {
1909 auto subtractionQueueDescriptor = PolymorphicDowncast<const SubtractionQueueDescriptor*>(&descriptor);
1910 return CreateSubtraction(*subtractionQueueDescriptor, info);
1911 }
1912 case LayerType::Transpose:
1913 {
1914 auto transposeQueueDescriptor = PolymorphicDowncast<const TransposeQueueDescriptor*>(&descriptor);
1915 return CreateTranspose(*transposeQueueDescriptor, info);
1916 }
1917 case LayerType::TransposeConvolution2d:
1918 {
1919 auto transposeConvolution2dQueueDescriptor
1920 = PolymorphicDowncast<const TransposeConvolution2dQueueDescriptor*>(&descriptor);
1921 return CreateTransposeConvolution2d(*transposeConvolution2dQueueDescriptor, info);
1922 }
1923 case LayerType::UnidirectionalSequenceLstm:
1924 {
1925 auto unidirectionalSequenceLstmQueueDescriptor
1926 = PolymorphicDowncast<const UnidirectionalSequenceLstmQueueDescriptor*>(&descriptor);
1927 return CreateUnidirectionalSequenceLstm(*unidirectionalSequenceLstmQueueDescriptor, info);
1928 }
1929 default:
1930 return nullptr;
1931 }
1932}
1933ARMNN_NO_DEPRECATE_WARN_END
Sadik Armagan04a72972020-09-14 15:44:18 +01001934
Derek Lamberti901ea112019-12-10 22:07:09 +00001935std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1936 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001937{
1938 return std::unique_ptr<IWorkload>();
1939}
1940
Derek Lamberti901ea112019-12-10 22:07:09 +00001941std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1942 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001943{
1944 return std::unique_ptr<IWorkload>();
1945}
1946
Derek Lamberti901ea112019-12-10 22:07:09 +00001947std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1948 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001949{
1950 return std::unique_ptr<IWorkload>();
1951}
1952
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001953std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001954 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001955{
1956 return std::unique_ptr<IWorkload>();
1957}
1958
Derek Lamberti901ea112019-12-10 22:07:09 +00001959std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1960 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001961{
1962 return std::unique_ptr<IWorkload>();
1963}
1964
mathad01b392e982021-04-07 12:07:30 +01001965std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1966 const WorkloadInfo& /*info*/) const
1967{
1968 return std::unique_ptr<IWorkload>();
1969}
1970
Simon Obute51f67772021-09-03 15:50:13 +01001971std::unique_ptr<IWorkload> IWorkloadFactory::CreateChannelShuffle(const ChannelShuffleQueueDescriptor& /*descriptor*/,
1972 const WorkloadInfo& /*info*/) const
1973{
1974 return std::unique_ptr<IWorkload>();
1975}
1976
Derek Lamberti901ea112019-12-10 22:07:09 +00001977std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1978 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001979{
1980 return std::unique_ptr<IWorkload>();
1981}
1982
Derek Lamberti901ea112019-12-10 22:07:09 +00001983std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1984 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001985{
1986 return std::unique_ptr<IWorkload>();
1987}
1988
Derek Lamberti901ea112019-12-10 22:07:09 +00001989std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1990 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001991{
1992 return std::unique_ptr<IWorkload>();
1993}
1994
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001995std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1996 const WorkloadInfo& /*info*/) const
1997{
1998 return std::unique_ptr<IWorkload>();
1999}
2000
Derek Lamberti901ea112019-12-10 22:07:09 +00002001std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
2002 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002003{
2004 return std::unique_ptr<IWorkload>();
2005}
2006
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002007std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
2008 const WorkloadInfo& /*info*/) const
2009{
2010 return std::unique_ptr<IWorkload>();
2011}
2012
Derek Lamberti901ea112019-12-10 22:07:09 +00002013std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
2014 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002015{
2016 return std::unique_ptr<IWorkload>();
2017}
2018
Derek Lamberti901ea112019-12-10 22:07:09 +00002019std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
2020 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002021{
2022 return std::unique_ptr<IWorkload>();
2023}
2024
Matthew Sloyanb63a3112021-09-08 13:05:51 +01002025std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution3d(const Convolution3dQueueDescriptor& /*descriptor*/,
2026 const WorkloadInfo& /*info*/) const
2027{
2028 return std::unique_ptr<IWorkload>();
2029}
2030
Derek Lamberti901ea112019-12-10 22:07:09 +00002031std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
2032 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002033{
2034 return std::unique_ptr<IWorkload>();
2035}
2036
Derek Lamberti901ea112019-12-10 22:07:09 +00002037std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
2038 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002039{
2040 return std::unique_ptr<IWorkload>();
2041}
2042
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002043std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00002044 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002045{
2046 return std::unique_ptr<IWorkload>();
2047}
2048
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002049std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00002050 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002051{
2052 return std::unique_ptr<IWorkload>();
2053}
2054
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002055std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00002056 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002057{
2058 return std::unique_ptr<IWorkload>();
2059}
2060
Derek Lamberti901ea112019-12-10 22:07:09 +00002061std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
2062 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002063{
2064 return std::unique_ptr<IWorkload>();
2065}
2066
josh minor4a3c6102020-01-06 16:40:46 -06002067std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
2068 const WorkloadInfo& /*info*/) const
2069{
2070 return std::unique_ptr<IWorkload>();
2071}
2072
Derek Lamberti901ea112019-12-10 22:07:09 +00002073std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
2074 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002075{
2076 return std::unique_ptr<IWorkload>();
2077}
2078
Ryan OSheaec6c6802020-06-05 17:17:06 +01002079std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
2080 const WorkloadInfo& /*info*/) const
2081{
2082 return std::unique_ptr<IWorkload>();
2083}
2084
Derek Lamberti901ea112019-12-10 22:07:09 +00002085std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
2086 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002087{
2088 return std::unique_ptr<IWorkload>();
2089}
2090
Derek Lamberti901ea112019-12-10 22:07:09 +00002091std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
2092 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002093{
2094 return std::unique_ptr<IWorkload>();
2095}
2096
Derek Lamberti901ea112019-12-10 22:07:09 +00002097std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
2098 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002099{
2100 return std::unique_ptr<IWorkload>();
2101}
2102
Kevin Mayce5045a2019-10-02 14:07:47 +01002103std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00002104 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
2105 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01002106{
2107 return std::unique_ptr<IWorkload>();
2108}
2109
Derek Lamberti901ea112019-12-10 22:07:09 +00002110std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
2111 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002112{
2113 return std::unique_ptr<IWorkload>();
2114}
2115
James Conroyaba90cd2020-11-06 16:28:18 +00002116std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
2117 const WorkloadInfo& /*info*/) const
2118{
2119 return std::unique_ptr<IWorkload>();
2120}
2121
2122std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
2123 const WorkloadInfo& /*info*/) const
2124{
2125 return std::unique_ptr<IWorkload>();
2126}
2127
Derek Lamberti901ea112019-12-10 22:07:09 +00002128std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
2129 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01002130{
2131 return std::unique_ptr<IWorkload>();
2132}
2133
Derek Lamberti901ea112019-12-10 22:07:09 +00002134std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
2135 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002136{
2137 return std::unique_ptr<IWorkload>();
2138}
2139
Derek Lamberti901ea112019-12-10 22:07:09 +00002140std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
2141 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002142{
2143 return std::unique_ptr<IWorkload>();
2144}
2145
Derek Lamberti901ea112019-12-10 22:07:09 +00002146std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
2147 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002148{
2149 return std::unique_ptr<IWorkload>();
2150}
2151
Derek Lamberti901ea112019-12-10 22:07:09 +00002152std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
2153 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002154{
2155 return std::unique_ptr<IWorkload>();
2156}
2157
Derek Lamberti901ea112019-12-10 22:07:09 +00002158std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
2159 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01002160{
2161 return std::unique_ptr<IWorkload>();
2162}
2163
Derek Lamberti901ea112019-12-10 22:07:09 +00002164std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
2165 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002166{
2167 return std::unique_ptr<IWorkload>();
2168}
2169
Derek Lamberti901ea112019-12-10 22:07:09 +00002170std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
2171 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002172{
2173 return std::unique_ptr<IWorkload>();
2174}
2175
Derek Lamberti901ea112019-12-10 22:07:09 +00002176std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
2177 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002178{
2179 return std::unique_ptr<IWorkload>();
2180}
2181
Derek Lamberti901ea112019-12-10 22:07:09 +00002182std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
2183 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002184{
2185 return std::unique_ptr<IWorkload>();
2186}
2187
Derek Lamberti901ea112019-12-10 22:07:09 +00002188std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
2189 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002190{
2191 return std::unique_ptr<IWorkload>();
2192}
2193
Derek Lamberti901ea112019-12-10 22:07:09 +00002194std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
2195 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002196{
2197 return std::unique_ptr<IWorkload>();
2198}
2199
Derek Lamberti901ea112019-12-10 22:07:09 +00002200std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002201 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002202{
2203 return std::unique_ptr<IWorkload>();
2204}
2205
Derek Lamberti901ea112019-12-10 22:07:09 +00002206std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
2207 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002208{
2209 return std::unique_ptr<IWorkload>();
2210}
2211
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002212std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling3d(const Pooling3dQueueDescriptor& /*descriptor*/,
2213 const WorkloadInfo& /*info*/) const
2214{
2215 return std::unique_ptr<IWorkload>();
2216}
2217
Derek Lamberti901ea112019-12-10 22:07:09 +00002218std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
2219 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002220{
2221 return std::unique_ptr<IWorkload>();
2222}
2223
Derek Lamberti901ea112019-12-10 22:07:09 +00002224std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
2225 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002226{
2227 return std::unique_ptr<IWorkload>();
2228}
2229
Derek Lamberti901ea112019-12-10 22:07:09 +00002230std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
2231 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002232{
2233 return std::unique_ptr<IWorkload>();
2234}
2235
James Conroy586a9aa2020-03-20 08:49:33 +00002236std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
2237 const WorkloadInfo& /*info*/) const
2238{
2239 return std::unique_ptr<IWorkload>();
2240}
2241
Derek Lamberti901ea112019-12-10 22:07:09 +00002242std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
2243 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01002244{
2245 return std::unique_ptr<IWorkload>();
2246}
Finn Williams2605b232020-06-10 15:53:46 +01002247std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
2248 const WorkloadInfo& /*info*/) const
2249{
2250 return std::unique_ptr<IWorkload>();
2251}
James Conroyee18dc82019-07-17 11:27:46 +01002252
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002253std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
2254 const WorkloadInfo& /*info*/) const
2255{
2256 return std::unique_ptr<IWorkload>();
2257}
2258
Derek Lamberti901ea112019-12-10 22:07:09 +00002259std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
2260 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002261{
2262 return std::unique_ptr<IWorkload>();
2263}
2264
Derek Lamberti901ea112019-12-10 22:07:09 +00002265std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
2266 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01002267{
2268 return std::unique_ptr<IWorkload>();
2269}
2270
Keith Davis3ae3f972021-05-21 16:33:48 +01002271std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
2272 const WorkloadInfo& /*info*/) const
2273{
2274 return std::unique_ptr<IWorkload>();
2275}
2276
Derek Lamberti901ea112019-12-10 22:07:09 +00002277std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
2278 const WorkloadInfo& /*info*/) const
2279{
2280 return std::unique_ptr<IWorkload>();
2281}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002282
Derek Lamberti901ea112019-12-10 22:07:09 +00002283std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
2284 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002285{
2286 return std::unique_ptr<IWorkload>();
2287}
2288
Derek Lamberti901ea112019-12-10 22:07:09 +00002289std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
2290 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002291{
2292 return std::unique_ptr<IWorkload>();
2293}
2294
Derek Lamberti901ea112019-12-10 22:07:09 +00002295std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
2296 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002297{
2298 return std::unique_ptr<IWorkload>();
2299}
2300
Derek Lamberti901ea112019-12-10 22:07:09 +00002301std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
2302 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002303{
2304 return std::unique_ptr<IWorkload>();
2305}
2306
Derek Lamberti901ea112019-12-10 22:07:09 +00002307std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
2308 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01002309{
2310 return std::unique_ptr<IWorkload>();
2311}
2312
Derek Lamberti901ea112019-12-10 22:07:09 +00002313std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
2314 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01002315{
2316 return std::unique_ptr<IWorkload>();
2317}
2318
Derek Lamberti901ea112019-12-10 22:07:09 +00002319std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
2320 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002321{
2322 return std::unique_ptr<IWorkload>();
2323}
2324
Derek Lamberti901ea112019-12-10 22:07:09 +00002325std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
2326 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01002327{
2328 return std::unique_ptr<IWorkload>();
2329}
2330
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002331std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
2332 const WorkloadInfo& /*info*/) const
2333{
2334 return std::unique_ptr<IWorkload>();
2335}
2336
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002337std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00002338 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
2339 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002340{
2341 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01002342}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002343
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01002344std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
2345 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
2346 const WorkloadInfo& /*info*/) const
2347{
2348 return std::unique_ptr<IWorkload>();
2349}
2350
Francis Murtaghbf354142022-08-12 13:54:17 +01002351std::unique_ptr<IWorkload> IWorkloadFactory::CreateInput(
2352 const InputQueueDescriptor& /*descriptor*/,
2353 const WorkloadInfo& /*info*/) const
2354{
2355 return std::unique_ptr<IWorkload>();
2356}
2357
} // namespace armnn