blob: 3660e6e721a5901b76d118829d4637881f6561ac [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000010#include <armnn/backends/IBackendInternal.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Colm Donelan0c479742021-12-10 12:43:54 +000017#include <armnn/backends/WorkloadFactory.hpp>
18#include <armnn/backends/TensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
David Beck111b5d92018-11-12 14:59:37 +000020#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000021
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
telsoa01c577f2c2018-08-31 09:22:23 +010025namespace
26{
Finn Williams3e54d032020-10-22 16:53:35 +010027using LayerList = std::list<Layer*>;
28using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010029
David Beck29c75de2018-10-23 13:35:58 +010030const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
31{
32 if (!type)
33 {
34 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010035 }
36
Matthew Sloyan81beae32021-07-13 19:46:11 +010037 return TensorInfo(info.GetShape(),
38 type.value(),
39 info.GetQuantizationScale(),
40 info.GetQuantizationOffset(),
41 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010042}
43
David Beck29c75de2018-10-23 13:35:58 +010044} // anonymous namespace
45
Sadik Armagana097d2a2021-11-24 15:47:28 +000046inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
47{
48 if (!weightsType)
49 {
50 return weightsType;
51 }
52
53 switch(weightsType.value())
54 {
55 case armnn::DataType::BFloat16:
56 case armnn::DataType::Float16:
57 case armnn::DataType::Float32:
58 return weightsType;
59 case armnn::DataType::QAsymmS8:
60 case armnn::DataType::QAsymmU8:
61 case armnn::DataType::QSymmS8:
62 case armnn::DataType::QSymmS16:
63 return armnn::DataType::Signed32;
64 default:
65 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
66 }
67 return armnn::EmptyOptional();
68}
69
70
Sadik Armagan045f6be2020-09-10 13:37:32 +010071bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
72 const IConnectableLayer& connectableLayer,
73 Optional<DataType> dataType,
74 std::string& outReasonIfUnsupported,
75 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000076{
David Beck33f0ae02018-10-18 15:13:56 +010077 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000078 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010079 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010080
David Beck111b5d92018-11-12 14:59:37 +000081 auto const& backendRegistry = BackendRegistryInstance();
82 if (!backendRegistry.IsBackendRegistered(backendId))
83 {
84 std::stringstream ss;
85 ss << connectableLayer.GetName() << " is not supported on " << backendId
86 << " because this backend is not registered.";
87
88 outReasonIfUnsupported = ss.str();
89 return false;
90 }
91
92 auto backendFactory = backendRegistry.GetFactory(backendId);
93 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000094 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010095
telsoa014fcda012018-03-09 14:13:49 +000096 switch(layer.GetType())
97 {
98 case LayerType::Activation:
99 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100100 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000101 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100102 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000103 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100104 OverrideDataType(input, dataType),
105 OverrideDataType(output, dataType),
106 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100107 reason);
telsoa014fcda012018-03-09 14:13:49 +0000108 break;
109 }
110 case LayerType::Addition:
111 {
112 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
113 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
114 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000115 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100116 OverrideDataType(input0, dataType),
117 OverrideDataType(input1, dataType),
118 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100119 reason);
telsoa014fcda012018-03-09 14:13:49 +0000120 break;
121 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100122 case LayerType::ArgMinMax:
123 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100124 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100125 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
126
127 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
128 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000129 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100130 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000131 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100132 descriptor,
133 reason);
134 break;
135 }
telsoa014fcda012018-03-09 14:13:49 +0000136 case LayerType::BatchNormalization:
137 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100138 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000139 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100140 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
141 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
142 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
143 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
144 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000145 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100146 OverrideDataType(input, dataType),
147 OverrideDataType(output, dataType),
148 OverrideDataType(mean, dataType),
149 OverrideDataType(var, dataType),
150 OverrideDataType(beta, dataType),
151 OverrideDataType(gamma, dataType),
152 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100153 reason);
telsoa014fcda012018-03-09 14:13:49 +0000154 break;
155 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000156 case LayerType::BatchToSpaceNd:
157 {
158 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
159 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100160 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000161
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000162 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
163 OverrideDataType(output, dataType),
164 cLayer->GetParameters(),
165 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000166 break;
167 }
mathad01b392e982021-04-07 12:07:30 +0100168 case LayerType::Cast:
169 {
170 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
171 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
172
173 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
174 OverrideDataType(output, dataType),
175 reason);
176 break;
177 }
Simon Obute51f67772021-09-03 15:50:13 +0100178 case LayerType::ChannelShuffle:
179 {
180 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
181
182 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
183 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
184
185 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
186
187 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
188 OverrideDataType(output, dataType),
189 descriptor,
190 reason);
191 break;
192 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100193 case LayerType::Comparison:
194 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100195 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100196
197 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
198 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
199 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
200
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000201 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
202 OverrideDataType(input1, dataType),
203 OverrideDataType(output, DataType::Boolean),
204 cLayer->GetParameters(),
205 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100206 break;
207 }
telsoa014fcda012018-03-09 14:13:49 +0000208 case LayerType::Constant:
209 {
210 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000211 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100212 break;
213 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000214 case LayerType::ConvertBf16ToFp32:
215 {
216 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
217 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000218 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000219 break;
220 }
telsoa01c577f2c2018-08-31 09:22:23 +0100221 case LayerType::ConvertFp16ToFp32:
222 {
223 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
224 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000225 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100226 break;
227 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000228 case LayerType::ConvertFp32ToBf16:
229 {
230 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
231 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000232 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000233 break;
234 }
telsoa01c577f2c2018-08-31 09:22:23 +0100235 case LayerType::ConvertFp32ToFp16:
236 {
237 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
238 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000239 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000240 break;
241 }
242 case LayerType::Convolution2d:
243 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100244 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100245
246 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
247 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100248 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Keith Davis2cddc722022-04-07 11:32:00 +0100249 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
250 "Convolution2dLayer: Weights should be connected as a Constant Layer.");
251 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
252 dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100253
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100254 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100255
arovir01a6824102018-08-28 17:40:45 +0100256 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100257 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100258 if (descriptor.m_BiasEnabled)
259 {
Keith Davis2cddc722022-04-07 11:32:00 +0100260 ARMNN_ASSERT_MSG(layer.GetInputSlot(2).GetConnection(),
261 "Convolution2dLayer: Bias should be connected as a Constant Layer.");
262 biases = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
263 GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100264 }
265
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000266 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100267 input,
268 output,
269 descriptor,
Keith Davis2cddc722022-04-07 11:32:00 +0100270 weights,
arovir01a6824102018-08-28 17:40:45 +0100271 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100272 reason);
telsoa014fcda012018-03-09 14:13:49 +0000273 break;
274 }
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100275 case LayerType::Convolution3d:
276 {
277 auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer);
278
279 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
280 dataType);
281 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100282
283 ARMNN_ASSERT_MSG(layer.GetInputSlot(1).GetConnection(),
284 "Convolution3dLayer: Weights should be connected as a Constant Layer.");
285 const TensorInfo weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
286 dataType);
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100287
288 const Convolution3dDescriptor& descriptor = cLayer->GetParameters();
289
290 // Construct optional biases object based on the value of m_BiasEnabled
291 Optional<TensorInfo> biases;
292 if (descriptor.m_BiasEnabled)
293 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100294 biases = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
295 GetBiasTypeFromWeightsType(dataType));
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100296 }
297
298 result = layerSupportObject.IsConvolution3dSupported(
299 input,
300 output,
301 descriptor,
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +0100302 weights,
Matthew Sloyanb63a3112021-09-08 13:05:51 +0100303 biases,
304 reason);
305 break;
306 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000307 case LayerType::Debug:
308 {
309 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
310 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
311
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000312 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000313 OverrideDataType(output, dataType),
314 reason);
315 break;
316 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100317 case LayerType::DepthToSpace:
318 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100319 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100320
321 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
322 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
323
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000324 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100325 OverrideDataType(output, dataType),
326 cLayer->GetParameters(),
327 reason);
328 break;
329 }
telsoa014fcda012018-03-09 14:13:49 +0000330 case LayerType::DepthwiseConvolution2d:
331 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100332 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
Cathal Corbett06902652022-04-14 17:55:11 +0100333 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
334 dataType);
335 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
336 const TensorInfo& weights = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
337 dataType);
338
339 ARMNN_ASSERT(cLayer->GetInputSlot(1).GetConnection() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100340
telsoa01c577f2c2018-08-31 09:22:23 +0100341 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100342
343 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100344 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100345 if (descriptor.m_BiasEnabled)
346 {
Cathal Corbett06902652022-04-14 17:55:11 +0100347 biases = OverrideDataType(cLayer->GetInputSlot(2).GetConnection()->GetTensorInfo(),
348 GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100349 }
telsoa01c577f2c2018-08-31 09:22:23 +0100350
Cathal Corbett06902652022-04-14 17:55:11 +0100351 result = layerSupportObject.IsDepthwiseConvolutionSupported(input,
352 output,
353 descriptor,
354 weights,
355 biases,
356 reason);
telsoa014fcda012018-03-09 14:13:49 +0000357 break;
358 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000359 case LayerType::Dequantize:
360 {
361 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
362 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
363
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000364 result = layerSupportObject.IsDequantizeSupported(input,
365 OverrideDataType(output, dataType),
366 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000367 break;
368 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000369 case LayerType::DetectionPostProcess:
370 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100371 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000372 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
373 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
374 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
375
376 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
377 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
378 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
379 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
380
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000381 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000382 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
383 scores,
384 anchors,
385 detectionBoxes,
386 detectionClasses,
387 detectionScores,
388 numDetections,
389 descriptor,
390 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000391 break;
392 }
josh minor4a3c6102020-01-06 16:40:46 -0600393 case LayerType::ElementwiseUnary:
394 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100395 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600396
397 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
398 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
399
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000400 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
401 OverrideDataType(output, dataType),
402 cLayer->GetParameters(),
403 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600404 break;
405 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100406 case LayerType::Fill:
407 {
408 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
409 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
410 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
411 const FillDescriptor& descriptor = cLayer->GetParameters();
412
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000413 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100414 OverrideDataType(input, dataType),
415 OverrideDataType(output, dataType),
416 descriptor,
417 reason);
418 break;
419 }
telsoa014fcda012018-03-09 14:13:49 +0000420 case LayerType::FakeQuantization:
421 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100422 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000423 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000424 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
425 cLayer->GetParameters(),
426 reason);
telsoa014fcda012018-03-09 14:13:49 +0000427 break;
428 }
429 case LayerType::Floor:
430 {
431 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
432 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000433 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
434 OverrideDataType(output, dataType),
435 reason);
telsoa014fcda012018-03-09 14:13:49 +0000436 break;
437 }
438 case LayerType::FullyConnected:
439 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100440 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000441 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100442 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000443
444 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
445 TensorInfo weightsInfo;
446 const TensorInfo* weightsInfoPtr = nullptr;
447
Matthew Sloyan81beae32021-07-13 19:46:11 +0100448 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000449 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100450
451 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000452 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000453 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100454 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
455 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
456 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
457
telsoa01c577f2c2018-08-31 09:22:23 +0100458 if (descriptor.m_BiasEnabled)
459 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100460 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
461 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100462 }
463 else
464 {
465 // If biases are not enabled pass a dummy tensorinfo for the validation
466 switch(input.GetDataType())
467 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000468 case DataType::BFloat16:
469 {
470 biasInfoPtr = &dummyBFloat16Bias;
471 break;
472 }
telsoa01c577f2c2018-08-31 09:22:23 +0100473 case DataType::Float16:
474 {
475 biasInfoPtr = &dummyFloat16Bias;
476 break;
477 }
478 case DataType::Float32:
479 {
480 biasInfoPtr = &dummyFloat32Bias;
481 break;
482 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000483 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000484 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000485 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000486 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100487 {
488 biasInfoPtr = &dummyQA8Bias;
489 break;
490 }
491 default:
492 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100493 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100494 }
495 }
496 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000497 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100498 OverrideDataType(input, dataType),
499 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000500 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100501 *biasInfoPtr,
502 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100503 reason);
telsoa014fcda012018-03-09 14:13:49 +0000504 break;
505 }
narpra01b89b05f2019-01-16 09:53:09 +0000506 case LayerType::Gather:
507 {
508 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
509 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
510 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100511 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
512 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000513 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
514 input1,
515 OverrideDataType(output, dataType),
516 descriptor,
517 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000518 break;
519 }
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100520 case LayerType::GatherNd:
521 {
522 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
523 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
524 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
525 result = layerSupportObject.IsGatherNdSupported(OverrideDataType(input0, dataType),
526 input1,
527 OverrideDataType(output, dataType),
528 reason);
529 break;
530 }
telsoa014fcda012018-03-09 14:13:49 +0000531 case LayerType::Input:
532 {
533 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000534 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000535 break;
536 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100537 case LayerType::InstanceNormalization:
538 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100539 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100540 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
541
542 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
543 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
544
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000545 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100546 OverrideDataType(input, dataType),
547 OverrideDataType(output, dataType),
548 descriptor,
549 reason);
550 break;
551 }
telsoa014fcda012018-03-09 14:13:49 +0000552 case LayerType::L2Normalization:
553 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100554 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100555 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
556
telsoa014fcda012018-03-09 14:13:49 +0000557 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100558 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100559
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000560 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100561 OverrideDataType(input, dataType),
562 OverrideDataType(output, dataType),
563 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100564 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100565 break;
566 }
James Conroyaba90cd2020-11-06 16:28:18 +0000567 case LayerType::LogicalBinary:
568 {
569 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
570
571 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
572 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
573 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
574
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000575 result = layerSupportObject.IsLogicalBinarySupported(input0,
576 input1,
577 output,
578 cLayer->GetParameters(),
579 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000580 break;
581 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100582 case LayerType::LogSoftmax:
583 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100584 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100585
586 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
587 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
588
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000589 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
590 OverrideDataType(output, dataType),
591 cLayer->GetParameters(),
592 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100593 break;
594 }
telsoa01c577f2c2018-08-31 09:22:23 +0100595 case LayerType::Lstm:
596 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100597 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100598 const LstmDescriptor& descriptor = cLayer->GetParameters();
599
600 // All inputs.
601 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
602 dataType);
603 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
604 dataType);
605 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
606 dataType);
607 // All outputs
608 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
609 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
610 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
611 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
612
613 // Basic parameters
614 const TensorInfo& inputToForgetWeights
615 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
616 const TensorInfo& inputToCellWeights
617 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
618 const TensorInfo& inputToOutputWeights
619 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
620 const TensorInfo& recurrentToForgetWeights
621 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
622 const TensorInfo& recurrentToCellWeights
623 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
624 const TensorInfo& recurrentToOutputWeights
625 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
626 const TensorInfo& forgetGateBias
627 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
628 const TensorInfo& cellBias
629 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
630 const TensorInfo& outputGateBias
631 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
632
Jan Eilersd01a83c2019-07-03 18:20:40 +0100633 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100634
Jan Eilersd01a83c2019-07-03 18:20:40 +0100635 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
636 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
637 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
638 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
639 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
640 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
641 paramsInfo.m_ForgetGateBias = &forgetGateBias;
642 paramsInfo.m_CellBias = &cellBias;
643 paramsInfo.m_OutputGateBias = &outputGateBias;
644
645
646 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100647 TensorInfo optInputToInputWeights;
648 TensorInfo optRecurrentToInputWeights;
649 TensorInfo optCellToInputWeights;
650 TensorInfo optInputGateBias;
651 TensorInfo optProjectionWeights;
652 TensorInfo optProjectionBias;
653 TensorInfo optCellToForgetWeights;
654 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100655 TensorInfo optInputLayerNormWeights;
656 TensorInfo optForgetLayerNormWeights;
657 TensorInfo optCellLayerNormWeights;
658 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100659
660 if(!descriptor.m_CifgEnabled)
661 {
662 optInputToInputWeights =
663 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100664 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100665
666 optRecurrentToInputWeights =
667 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100668 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100669 optInputGateBias =
670 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100671 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100672 }
673
674 if(descriptor.m_ProjectionEnabled)
675 {
676 optProjectionWeights =
677 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100678 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100679 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
680 {
681 optProjectionBias =
682 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100683 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100684 }
685 }
686
687 if(descriptor.m_PeepholeEnabled)
688 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100689 if(!descriptor.m_CifgEnabled)
690 {
691 optCellToInputWeights =
692 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
693 dataType);
694 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
695 }
telsoa01c577f2c2018-08-31 09:22:23 +0100696 optCellToForgetWeights =
697 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100698 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100699 optCellToOutputWeights =
700 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100701 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100702 }
703
Jan Eilers38e05bd2019-06-26 13:10:09 +0100704 if(descriptor.m_LayerNormEnabled)
705 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100706 if (!descriptor.m_CifgEnabled)
707 {
708 optInputLayerNormWeights = OverrideDataType(
709 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
710 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
711 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100712
713 optForgetLayerNormWeights = OverrideDataType(
714 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100715 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100716
717 optCellLayerNormWeights = OverrideDataType(
718 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100719 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100720
721 optOutputLayerNormWeights = OverrideDataType(
722 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100723 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100724 }
725
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000726 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100727 input,
728 outputStateIn,
729 cellStateIn,
730 scratchBuffer,
731 outputStateOut,
732 cellStateOut,
733 output,
734 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100735 paramsInfo,
736 reason);
telsoa014fcda012018-03-09 14:13:49 +0000737 break;
738 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000739 case LayerType::Maximum:
740 {
741 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
742 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
743 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
744
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000745 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
746 OverrideDataType(input1, dataType),
747 OverrideDataType(output, dataType),
748 reason);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000749 break;
750 }
narpra01b89b05f2019-01-16 09:53:09 +0000751 case LayerType::MemCopy:
752 {
753 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
754 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000755
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000756 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
757 OverrideDataType(output, dataType),
758 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000759 break;
760 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100761 case LayerType::MemImport:
762 {
763 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
764 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
765
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000766 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
767 OverrideDataType(output, dataType),
768 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100769 break;
770 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100771 case LayerType::Merge:
772 {
773 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
774 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
775 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
776
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000777 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
778 OverrideDataType(input1, dataType),
779 OverrideDataType(output, dataType),
780 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100781 break;
782 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100783 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000784 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100785 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000786
telsoa01c577f2c2018-08-31 09:22:23 +0100787 // Get vector of all inputs.
788 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000789 {
telsoa01c577f2c2018-08-31 09:22:23 +0100790 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000791 };
Finn Williams3e54d032020-10-22 16:53:35 +0100792
793 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
794 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100795 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000796
telsoa01c577f2c2018-08-31 09:22:23 +0100797 auto getTensorInfoPtr = [](const TensorInfo& info)
798 {
799 return &info;
800 };
Finn Williams3e54d032020-10-22 16:53:35 +0100801
802 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
803 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100804 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000805
Nikhil Raj8599a412018-11-19 14:51:07 +0000806 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
807
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000808 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Jim Flynne242f2d2019-05-22 14:24:13 +0100809
810
telsoa014fcda012018-03-09 14:13:49 +0000811 break;
812 }
813 case LayerType::Multiplication:
814 {
815 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
816 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100817 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000818 result = layerSupportObject.IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100819 OverrideDataType(input0, dataType),
820 OverrideDataType(input1, dataType),
821 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100822 reason);
telsoa014fcda012018-03-09 14:13:49 +0000823 break;
824 }
825 case LayerType::Normalization:
826 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100827 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000828 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
829 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000830 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
831 OverrideDataType(output, dataType),
832 cLayer->GetParameters(),
833 reason);
telsoa014fcda012018-03-09 14:13:49 +0000834 break;
835 }
836 case LayerType::Output:
837 {
838 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000839 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000840 break;
841 }
842 case LayerType::Permute:
843 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100844 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000845 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
846 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000847 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
848 OverrideDataType(output, dataType),
849 cLayer->GetParameters(),
850 reason);
telsoa014fcda012018-03-09 14:13:49 +0000851 break;
852 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100853 case LayerType::Pad:
854 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100855 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100856 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
857 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000858 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100859 OverrideDataType(input, dataType),
860 OverrideDataType(output, dataType),
861 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100862 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100863 break;
864 }
telsoa014fcda012018-03-09 14:13:49 +0000865 case LayerType::Pooling2d:
866 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100867 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000868 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
869 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000870 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
871 OverrideDataType(output, dataType),
872 cLayer->GetParameters(),
873 reason);
telsoa014fcda012018-03-09 14:13:49 +0000874 break;
875 }
Tamás Nyíri7b885b32021-10-26 14:47:57 +0100876 case LayerType::Pooling3d:
877 {
878 auto cLayer = PolymorphicDowncast<const Pooling3dLayer*>(&layer);
879 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
880 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
881 result = layerSupportObject.IsPooling3dSupported(OverrideDataType(input, dataType),
882 OverrideDataType(output, dataType),
883 cLayer->GetParameters(),
884 reason);
885 break;
886 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000887 case LayerType::PreCompiled:
888 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100889 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000890 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000891 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
892 cLayer->GetParameters(),
893 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000894 break;
895 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000896 case LayerType::Quantize:
897 {
898 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
899 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000900 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000901 break;
902 }
James Conroy586a9aa2020-03-20 08:49:33 +0000903 case LayerType::QLstm:
904 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100905 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000906 const QLstmDescriptor& descriptor = cLayer->GetParameters();
907
908 // Inputs
909 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
910 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
911 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
912
913 // Outputs
914 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
915 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
916 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
917
918 // Lstm parameters
919 LstmInputParamsInfo paramsInfo;
920
921 // Basic parameters
Matthew Bentham6f24b1a2021-06-29 15:18:32 +0100922 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
923 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
924 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
James Conroy586a9aa2020-03-20 08:49:33 +0000925 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
926 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
927 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
928
929 paramsInfo.m_RecurrentToForgetWeights =
930 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
931 paramsInfo.m_RecurrentToCellWeights =
932 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
933 paramsInfo.m_RecurrentToOutputWeights =
934 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
935
936 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
937 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
938 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
939
940 if(!descriptor.m_CifgEnabled)
941 {
942 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
943 paramsInfo.m_RecurrentToInputWeights =
944 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
945 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
946 }
947
948 if(descriptor.m_ProjectionEnabled)
949 {
950 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +0100951
952 // Projection bias is optional even if projection is enabled
953 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
954 {
955 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
956 }
James Conroy586a9aa2020-03-20 08:49:33 +0000957 }
958
959 if(descriptor.m_PeepholeEnabled)
960 {
961 if (!descriptor.m_CifgEnabled)
962 {
963 paramsInfo.m_CellToInputWeights =
964 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
965 }
966
967 paramsInfo.m_CellToForgetWeights =
968 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
969 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
970 }
971
972 if(descriptor.m_LayerNormEnabled)
973 {
974 if (!descriptor.m_CifgEnabled)
975 {
976 paramsInfo.m_InputLayerNormWeights =
977 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
978 }
979
980 paramsInfo.m_ForgetLayerNormWeights =
981 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
982 paramsInfo.m_CellLayerNormWeights =
983 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
984 paramsInfo.m_OutputLayerNormWeights =
985 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
986 }
987
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000988 result = layerSupportObject.IsQLstmSupported(input,
989 previousOutputIn,
990 previousCellStateIn,
991 outputStateOut,
992 cellStateOut,
993 output,
994 descriptor,
995 paramsInfo,
996 reason);
James Conroy586a9aa2020-03-20 08:49:33 +0000997 break;
998 }
James Conroyee18dc82019-07-17 11:27:46 +0100999 case LayerType::QuantizedLstm:
1000 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001001 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +01001002
1003 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001004 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1005 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1006 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001007
1008 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001009 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
1010 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001011
1012 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +01001013 QuantizedLstmInputParamsInfo paramsInfo;
1014
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001015 paramsInfo.m_InputToInputWeights =
1016 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
1017 paramsInfo.m_InputToForgetWeights =
1018 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
1019 paramsInfo.m_InputToCellWeights =
1020 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
1021 paramsInfo.m_InputToOutputWeights =
1022 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001023
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001024 paramsInfo.m_RecurrentToInputWeights =
1025 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
1026 paramsInfo.m_RecurrentToForgetWeights =
1027 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
1028 paramsInfo.m_RecurrentToCellWeights =
1029 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
1030 paramsInfo.m_RecurrentToOutputWeights =
1031 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +01001032
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001033 paramsInfo.m_InputGateBias =
1034 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
1035 paramsInfo.m_ForgetGateBias =
1036 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
1037 paramsInfo.m_CellBias =
1038 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
1039 paramsInfo.m_OutputGateBias =
1040 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +01001041
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001042 result = layerSupportObject.IsQuantizedLstmSupported(input,
1043 previousCellStateIn,
1044 previousOutputIn,
1045 cellStateOut,
1046 output,
1047 paramsInfo,
1048 reason);
James Conroyee18dc82019-07-17 11:27:46 +01001049 break;
1050 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001051 case LayerType::Division:
1052 {
1053 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1054 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1055 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001056 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001057 OverrideDataType(input0, dataType),
1058 OverrideDataType(input1, dataType),
1059 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001060 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001061 break;
1062 }
Finn Williams2605b232020-06-10 15:53:46 +01001063 case LayerType::Rank:
1064 {
1065 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1066 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001067 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
1068 OverrideDataType(output, dataType),
1069 reason);
Finn Williams2605b232020-06-10 15:53:46 +01001070 break;
1071 }
telsoa014fcda012018-03-09 14:13:49 +00001072 case LayerType::Reshape:
1073 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001074 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001075 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +00001076 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001077 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
1078 OverrideDataType(output, dataType),
1079 cLayer->GetParameters(),
1080 reason);
telsoa014fcda012018-03-09 14:13:49 +00001081 break;
1082 }
Teresa Charlina9075df2019-06-27 15:41:57 +01001083 case LayerType::Resize:
1084 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001085 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +01001086 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +01001087 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001088 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1089 OverrideDataType(output, dataType),
1090 cLayer->GetParameters(),
1091 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +01001092 break;
1093 }
Keith Davis3ae3f972021-05-21 16:33:48 +01001094 case LayerType::Shape:
1095 {
1096 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1097 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1098
1099 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1100 OverrideDataType(output, dataType),
1101 reason);
1102 break;
1103 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001104 case LayerType::Slice:
1105 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001106 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001107
1108 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1109 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1110
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001111 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1112 OverrideDataType(output, dataType),
1113 cLayer->GetParameters(),
1114 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001115 break;
1116 }
telsoa014fcda012018-03-09 14:13:49 +00001117 case LayerType::Softmax:
1118 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001119 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001120 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001121 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001122 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1123 OverrideDataType(output, dataType),
1124 cLayer->GetParameters(),
1125 reason);
telsoa014fcda012018-03-09 14:13:49 +00001126 break;
1127 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001128 case LayerType::SpaceToBatchNd:
1129 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001130 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001133 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1134 OverrideDataType(output, dataType),
1135 cLayer->GetParameters(),
1136 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001137 break;
1138 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001139 case LayerType::SpaceToDepth:
1140 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001141 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001142
1143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1145
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001146 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1147 OverrideDataType(output, dataType),
1148 cLayer->GetParameters(),
1149 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001150 break;
1151 }
telsoa014fcda012018-03-09 14:13:49 +00001152 case LayerType::Splitter:
1153 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001154 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001155 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001156
1157 // Get vector of all outputs.
1158 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1159 {
1160 return OverrideDataType(slot.GetTensorInfo(), dataType);
1161 };
Finn Williams3e54d032020-10-22 16:53:35 +01001162 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1163 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001164 std::vector<TensorInfo> outputs(beginI, endI);
1165
1166 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1167
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001168 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1169 outputPtrs,
1170 cLayer->GetParameters(),
1171 reason);
telsoa014fcda012018-03-09 14:13:49 +00001172 break;
1173 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001174 case LayerType::Stack:
1175 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001176 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001177
1178 // Get vector of all inputs.
1179 auto getTensorInfo = [&dataType](const InputSlot& slot)
1180 {
1181 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1182 };
Finn Williams3e54d032020-10-22 16:53:35 +01001183 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1184 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001185 std::vector<TensorInfo> inputs(beginI, endI);
1186
1187 auto getTensorInfoPtr = [](const TensorInfo& info)
1188 {
1189 return &info;
1190 };
Finn Williams3e54d032020-10-22 16:53:35 +01001191 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1192 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001193 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1194
1195 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1196
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001197 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001198
1199 break;
1200 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001201 case LayerType::StandIn:
1202 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001203 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001204
1205 // Get vector of all inputs.
1206 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1207 {
1208 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1209 };
1210 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1211 {
1212 return OverrideDataType(slot.GetTensorInfo(), dataType);
1213 };
Finn Williams3e54d032020-10-22 16:53:35 +01001214 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1215 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001216 std::vector<TensorInfo> inputs(beginI, endI);
1217
Finn Williams3e54d032020-10-22 16:53:35 +01001218 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1219 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001220 std::vector<TensorInfo> outputs(beginO, endO);
1221
1222
1223 auto getTensorInfoPtr = [](const TensorInfo& info)
1224 {
1225 return &info;
1226 };
Finn Williams3e54d032020-10-22 16:53:35 +01001227 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1228 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001229 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1230
Finn Williams3e54d032020-10-22 16:53:35 +01001231 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1232 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001233 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1234
1235
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001236 result = layerSupportObject.IsStandInSupported(inputPtrs,
1237 outputPtrs,
1238 cLayer->GetParameters(),
1239 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001240 break;
1241 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001242 case LayerType::StridedSlice:
1243 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001244 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001245 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1246 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001247 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1248 OverrideDataType(output, dataType),
1249 cLayer->GetParameters(),
1250 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001251 break;
1252 }
David Beckc2044fe2018-09-05 15:00:38 +01001253 case LayerType::Subtraction:
1254 {
1255 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1256 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1257 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001258 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001259 OverrideDataType(input0, dataType),
1260 OverrideDataType(input1, dataType),
1261 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001262 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001263 break;
1264 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001265 case LayerType::Switch:
1266 {
1267 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1268 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1269 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1270 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001271 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1272 OverrideDataType(input1, dataType),
1273 OverrideDataType(output0, dataType),
1274 OverrideDataType(output1, dataType),
1275 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001276 break;
1277 }
narpra0132b90462018-09-13 11:07:48 +01001278 case LayerType::Mean:
1279 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001280 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001281 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1282 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001283 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001284 OverrideDataType(input, dataType),
1285 OverrideDataType(output, dataType),
1286 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001287 reason);
narpra0132b90462018-09-13 11:07:48 +01001288 break;
1289 }
kevmay0190539692018-11-29 08:40:19 +00001290 case LayerType::Minimum:
1291 {
1292 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1293 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1294 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001295 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1296 OverrideDataType(input1, dataType),
1297 OverrideDataType(output, dataType),
1298 reason);
kevmay0190539692018-11-29 08:40:19 +00001299 break;
1300 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001301 case LayerType::Prelu:
1302 {
1303 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1304 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1305 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001306 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1307 OverrideDataType(alpha, dataType),
1308 OverrideDataType(output, dataType),
1309 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001310 break;
1311 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001312 case LayerType::Transpose:
1313 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001314 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001315 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1316 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001317 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1318 OverrideDataType(output, dataType),
1319 cLayer->GetParameters(),
1320 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001321 break;
1322 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001323 case LayerType::TransposeConvolution2d:
1324 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001325 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001326
1327 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1328 dataType);
1329 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1330
1331 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1332
1333 Optional<TensorInfo> biases;
1334 if (descriptor.m_BiasEnabled)
1335 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001336 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001337 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1338 GetBiasTypeFromWeightsType(dataType));
1339 }
1340
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001341 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001342 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1343
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001344 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1345 output,
1346 descriptor,
1347 weights,
1348 biases,
1349 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001350
1351 break;
1352 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001353 case LayerType::Reduce:
1354 {
1355 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1356 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1357 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1358
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001359 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1360 OverrideDataType(output, dataType),
1361 cLayer->GetParameters(),
1362 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001363 break;
1364 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001365 case LayerType::UnidirectionalSequenceLstm:
1366 {
1367 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1368 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1369
1370 // All inputs.
1371 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1372 dataType);
1373 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1374 dataType);
1375 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1376 dataType);
1377 // Outputs
Mike Kelly12994962022-04-21 11:57:09 +01001378 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1379 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
1380 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001381
1382 // Basic parameters
1383 const TensorInfo& inputToForgetWeights
1384 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1385 const TensorInfo& inputToCellWeights
1386 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1387 const TensorInfo& inputToOutputWeights
1388 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1389 const TensorInfo& recurrentToForgetWeights
1390 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1391 const TensorInfo& recurrentToCellWeights
1392 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1393 const TensorInfo& recurrentToOutputWeights
1394 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1395 const TensorInfo& forgetGateBias
1396 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1397 const TensorInfo& cellBias
1398 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1399 const TensorInfo& outputGateBias
1400 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1401
1402 LstmInputParamsInfo paramsInfo;
1403
1404 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1405 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1406 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1407 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1408 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1409 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1410 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1411 paramsInfo.m_CellBias = &cellBias;
1412 paramsInfo.m_OutputGateBias = &outputGateBias;
1413
1414 // Optional parameters
1415 TensorInfo optInputToInputWeights;
1416 TensorInfo optRecurrentToInputWeights;
1417 TensorInfo optCellToInputWeights;
1418 TensorInfo optInputGateBias;
1419 TensorInfo optProjectionWeights;
1420 TensorInfo optProjectionBias;
1421 TensorInfo optCellToForgetWeights;
1422 TensorInfo optCellToOutputWeights;
1423 TensorInfo optInputLayerNormWeights;
1424 TensorInfo optForgetLayerNormWeights;
1425 TensorInfo optCellLayerNormWeights;
1426 TensorInfo optOutputLayerNormWeights;
1427
1428 if(!descriptor.m_CifgEnabled)
1429 {
1430 optInputToInputWeights =
1431 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1432 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1433
1434 optRecurrentToInputWeights =
1435 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1436 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1437 optInputGateBias =
1438 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1439 paramsInfo.m_InputGateBias = &optInputGateBias;
1440 }
1441
1442 if(descriptor.m_ProjectionEnabled)
1443 {
1444 optProjectionWeights =
1445 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1446 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1447 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1448 {
1449 optProjectionBias =
1450 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1451 paramsInfo.m_ProjectionBias = &optProjectionBias;
1452 }
1453 }
1454
1455 if(descriptor.m_PeepholeEnabled)
1456 {
1457 if(!descriptor.m_CifgEnabled)
1458 {
1459 optCellToInputWeights =
1460 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1461 dataType);
1462 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1463 }
1464 optCellToForgetWeights =
1465 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1466 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1467 optCellToOutputWeights =
1468 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1469 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1470 }
1471
1472 if(descriptor.m_LayerNormEnabled)
1473 {
1474 if (!descriptor.m_CifgEnabled)
1475 {
1476 optInputLayerNormWeights = OverrideDataType(
1477 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1478 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1479 }
1480
1481 optForgetLayerNormWeights = OverrideDataType(
1482 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1483 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1484
1485 optCellLayerNormWeights = OverrideDataType(
1486 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1487 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1488
1489 optOutputLayerNormWeights = OverrideDataType(
1490 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1491 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1492 }
1493
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001494 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1495 outputStateIn,
1496 cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01001497 outputStateOut,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001498 cellStateOut,
Mike Kelly12994962022-04-21 11:57:09 +01001499 output,
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001500 descriptor,
1501 paramsInfo,
1502 reason);
1503 break;
1504 }
telsoa014fcda012018-03-09 14:13:49 +00001505 default:
1506 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001507 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001508 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001509 result = false;
1510 break;
1511 }
1512 }
telsoa014fcda012018-03-09 14:13:49 +00001513 return result;
1514}
1515
Sadik Armagan045f6be2020-09-10 13:37:32 +01001516bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1517 const IConnectableLayer& connectableLayer,
1518 Optional<DataType> dataType,
1519 std::string& outReasonIfUnsupported)
1520{
1521 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1522}
1523
David Beckdcb751f2018-10-03 11:42:42 +01001524bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001525 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001526 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001527{
Jan Eilersbb446e52020-04-02 13:56:54 +01001528 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001529 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1530}
1531
1532// TODO merge with defaulted modelOptions above
1533bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1534 Optional<DataType> dataType,
1535 std::string& outReasonIfUnsupported,
1536 const ModelOptions& modelOptions)
1537{
1538 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1539 return IsLayerConfigurationSupported(layer->GetBackendId(),
1540 connectableLayer,
1541 dataType,
1542 outReasonIfUnsupported,
1543 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001544}
1545
Sadik Armagan04a72972020-09-14 15:44:18 +01001546bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1547 const IConnectableLayer& connectableLayer,
1548 Optional<DataType> dataType,
1549 std::string& outReasonIfUnsupported,
1550 const ModelOptions& modelOptions)
1551{
1552 return IsLayerConfigurationSupported(backendId,
1553 connectableLayer,
1554 dataType,
1555 outReasonIfUnsupported,
1556 modelOptions);
1557}
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001558ARMNN_NO_DEPRECATE_WARN_BEGIN
1559std::unique_ptr<IWorkload> IWorkloadFactory::CreateWorkload(LayerType type,
1560 const QueueDescriptor& descriptor,
1561 const WorkloadInfo& info) const
1562{
1563 switch(type)
1564 {
1565 case LayerType::Activation :
1566 {
1567 auto activationQueueDescriptor = PolymorphicDowncast<const ActivationQueueDescriptor*>(&descriptor);
1568 return CreateActivation(*activationQueueDescriptor, info);
1569 }
1570 case LayerType::Addition :
1571 {
1572 auto additionQueueDescriptor = PolymorphicDowncast<const AdditionQueueDescriptor*>(&descriptor);
1573 return CreateAddition(*additionQueueDescriptor, info);
1574 }
1575 case LayerType::ArgMinMax :
1576 {
1577 auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor);
1578 return CreateArgMinMax(*argMinMaxQueueDescriptor, info);
1579 }
1580 case LayerType::BatchNormalization :
1581 {
1582 auto batchNormQueueDescriptor = PolymorphicDowncast<const BatchNormalizationQueueDescriptor*>(&descriptor);
1583 return CreateBatchNormalization(*batchNormQueueDescriptor, info);
1584 }
1585 case LayerType::BatchToSpaceNd :
1586 {
1587 auto batchToSpaceNdQueueDescriptor
1588 = PolymorphicDowncast<const BatchToSpaceNdQueueDescriptor*>(&descriptor);
1589 return CreateBatchToSpaceNd(*batchToSpaceNdQueueDescriptor, info);
1590 }
1591 case LayerType::Cast :
1592 {
1593 auto castQueueDescriptor = PolymorphicDowncast<const CastQueueDescriptor*>(&descriptor);
1594 return CreateCast(*castQueueDescriptor, info);
1595 }
1596 case LayerType::ChannelShuffle :
1597 {
1598 auto channelShuffleQueueDescriptor
1599 = PolymorphicDowncast<const ChannelShuffleQueueDescriptor*>(&descriptor);
1600 return CreateChannelShuffle(*channelShuffleQueueDescriptor, info);
1601 }
1602 case LayerType::Comparison :
1603 {
1604 auto comparisonQueueDescriptor = PolymorphicDowncast<const ComparisonQueueDescriptor*>(&descriptor);
1605 return CreateComparison(*comparisonQueueDescriptor, info);
1606 }
1607 case LayerType::Concat :
1608 {
1609 auto concatQueueDescriptor = PolymorphicDowncast<const ConcatQueueDescriptor*>(&descriptor);
1610 return CreateConcat(*concatQueueDescriptor, info);
1611 }
1612 case LayerType::Constant :
1613 {
1614 auto constantQueueDescriptor = PolymorphicDowncast<const ConstantQueueDescriptor*>(&descriptor);
1615 return CreateConstant(*constantQueueDescriptor, info);
1616 }
1617 case LayerType::ConvertBf16ToFp32 :
1618 {
1619 auto convertBf16ToFp32QueueDescriptor
1620 = PolymorphicDowncast<const ConvertBf16ToFp32QueueDescriptor*>(&descriptor);
1621 return CreateConvertBf16ToFp32(*convertBf16ToFp32QueueDescriptor, info);
1622 }
1623 case LayerType::ConvertFp16ToFp32:
1624 {
1625 auto convertFp16ToFp32QueueDescriptor
1626 = PolymorphicDowncast<const ConvertFp16ToFp32QueueDescriptor*>(&descriptor);
1627 return CreateConvertFp16ToFp32(*convertFp16ToFp32QueueDescriptor, info);
1628 }
1629 case LayerType::ConvertFp32ToBf16:
1630 {
1631 auto convertFp32ToBf16QueueDescriptor
1632 = PolymorphicDowncast<const ConvertFp32ToBf16QueueDescriptor*>(&descriptor);
1633 return CreateConvertFp32ToBf16(*convertFp32ToBf16QueueDescriptor, info);
1634 }
1635 case LayerType::ConvertFp32ToFp16:
1636 {
1637 auto convertFp32ToFp16QueueDescriptor
1638 = PolymorphicDowncast<const ConvertFp32ToFp16QueueDescriptor*>(&descriptor);
1639 return CreateConvertFp32ToFp16(*convertFp32ToFp16QueueDescriptor, info);
1640 }
1641 case LayerType::Convolution2d:
1642 {
1643 auto convolution2dQueueDescriptor = PolymorphicDowncast<const Convolution2dQueueDescriptor*>(&descriptor);
1644 return CreateConvolution2d(*convolution2dQueueDescriptor, info);
1645 }
1646 case LayerType::Convolution3d:
1647 {
1648 auto convolution3dQueueDescriptor = PolymorphicDowncast<const Convolution3dQueueDescriptor*>(&descriptor);
1649 return CreateConvolution3d(*convolution3dQueueDescriptor, info);
1650 }
1651 case LayerType::Debug:
1652 {
1653 auto debugQueueDescriptor = PolymorphicDowncast<const DebugQueueDescriptor*>(&descriptor);
1654 return CreateDebug(*debugQueueDescriptor, info);
1655 }
1656 case LayerType::DepthToSpace:
1657 {
1658 auto depthToSpaceQueueDescriptor = PolymorphicDowncast<const DepthToSpaceQueueDescriptor*>(&descriptor);
1659 return CreateDepthToSpace(*depthToSpaceQueueDescriptor, info);
1660 }
1661 case LayerType::DepthwiseConvolution2d:
1662 {
1663 auto depthwiseConvolution2DQueueDescriptor
1664 = PolymorphicDowncast<const DepthwiseConvolution2dQueueDescriptor*>(&descriptor);
1665 return CreateDepthwiseConvolution2d(*depthwiseConvolution2DQueueDescriptor, info);
1666 }
1667 case LayerType::Dequantize:
1668 {
1669 auto dequantizeQueueDescriptor = PolymorphicDowncast<const DequantizeQueueDescriptor*>(&descriptor);
1670 return CreateDequantize(*dequantizeQueueDescriptor, info);
1671 }
1672 case LayerType::DetectionPostProcess:
1673 {
1674 auto detectionPostProcessQueueDescriptor
1675 = PolymorphicDowncast<const DetectionPostProcessQueueDescriptor*>(&descriptor);
1676 return CreateDetectionPostProcess(*detectionPostProcessQueueDescriptor, info);
1677 }
1678 case LayerType::Division:
1679 {
1680 auto divisionQueueDescriptor = PolymorphicDowncast<const DivisionQueueDescriptor*>(&descriptor);
1681 return CreateDivision(*divisionQueueDescriptor, info);
1682 }
1683 case LayerType::ElementwiseUnary:
1684 {
1685 auto elementwiseUnaryQueueDescriptor
1686 = PolymorphicDowncast<const ElementwiseUnaryQueueDescriptor*>(&descriptor);
1687 return CreateElementwiseUnary(*elementwiseUnaryQueueDescriptor, info);
1688
1689 }
1690 case LayerType::FakeQuantization:
1691 {
1692 auto fakeQuantizationQueueDescriptor
1693 = PolymorphicDowncast<const FakeQuantizationQueueDescriptor*>(&descriptor);
1694 return CreateFakeQuantization(*fakeQuantizationQueueDescriptor, info);
1695 }
1696 case LayerType::Fill:
1697 {
1698 auto fillQueueDescriptor = PolymorphicDowncast<const FillQueueDescriptor*>(&descriptor);
1699 return CreateFill(*fillQueueDescriptor, info);
1700 }
1701 case LayerType::Floor:
1702 {
1703 auto floorQueueDescriptor = PolymorphicDowncast<const FloorQueueDescriptor*>(&descriptor);
1704 return CreateFloor(*floorQueueDescriptor, info);
1705 }
1706 case LayerType::FullyConnected:
1707 {
1708 auto fullyConnectedQueueDescriptor
1709 = PolymorphicDowncast<const FullyConnectedQueueDescriptor*>(&descriptor);
1710 return CreateFullyConnected(*fullyConnectedQueueDescriptor, info);
1711 }
1712 case LayerType::Gather:
1713 {
1714 auto gatherQueueDescriptor = PolymorphicDowncast<const GatherQueueDescriptor*>(&descriptor);
1715 return CreateGather(*gatherQueueDescriptor, info);
1716 }
1717 case LayerType::Input:
1718 {
1719 auto inputQueueDescriptor = PolymorphicDowncast<const InputQueueDescriptor*>(&descriptor);
1720 return CreateInput(*inputQueueDescriptor, info);
1721 }
1722 case LayerType::InstanceNormalization:
1723 {
1724 auto instanceNormalizationQueueDescriptor
1725 = PolymorphicDowncast<const InstanceNormalizationQueueDescriptor*>(&descriptor);
1726 return CreateInstanceNormalization(*instanceNormalizationQueueDescriptor, info);
1727 }
1728 case LayerType::L2Normalization:
1729 {
1730 auto l2NormalizationQueueDescriptor
1731 = PolymorphicDowncast<const L2NormalizationQueueDescriptor*>(&descriptor);
1732 return CreateL2Normalization(*l2NormalizationQueueDescriptor, info);
1733 }
1734 case LayerType::LogicalBinary:
1735 {
1736 auto logicalBinaryQueueDescriptor = PolymorphicDowncast<const LogicalBinaryQueueDescriptor*>(&descriptor);
1737 return CreateLogicalBinary(*logicalBinaryQueueDescriptor, info);
1738 }
1739 case LayerType::LogSoftmax:
1740 {
1741 auto logSoftmaxQueueDescriptor = PolymorphicDowncast<const LogSoftmaxQueueDescriptor*>(&descriptor);
1742 return CreateLogSoftmax(*logSoftmaxQueueDescriptor, info);
1743 }
1744 case LayerType::Lstm:
1745 {
1746 auto lstmQueueDescriptor = PolymorphicDowncast<const LstmQueueDescriptor*>(&descriptor);
1747 return CreateLstm(*lstmQueueDescriptor, info);
1748 }
1749 case LayerType::Maximum:
1750 {
1751 auto maximumQueueDescriptor = PolymorphicDowncast<const MaximumQueueDescriptor*>(&descriptor);
1752 return CreateMaximum(*maximumQueueDescriptor, info);
1753 }
1754 case LayerType::Mean:
1755 {
1756 auto meanQueueDescriptor = PolymorphicDowncast<const MeanQueueDescriptor*>(&descriptor);
1757 return CreateMean(*meanQueueDescriptor, info);
1758 }
1759 case LayerType::MemCopy:
1760 {
1761 auto memCopyQueueDescriptor = PolymorphicDowncast<const MemCopyQueueDescriptor*>(&descriptor);
1762 return CreateMemCopy(*memCopyQueueDescriptor, info);
1763 }
1764 case LayerType::MemImport:
1765 {
1766 auto memImportQueueDescriptor = PolymorphicDowncast<const MemImportQueueDescriptor*>(&descriptor);
1767 return CreateMemImport(*memImportQueueDescriptor, info);
1768 }
1769 case LayerType::Minimum:
1770 {
1771 auto minimumQueueDescriptor = PolymorphicDowncast<const MinimumQueueDescriptor*>(&descriptor);
1772 return CreateMinimum(*minimumQueueDescriptor, info);
1773 }
1774 case LayerType::Multiplication:
1775 {
1776 auto multiplicationQueueDescriptor
1777 = PolymorphicDowncast<const MultiplicationQueueDescriptor*>(&descriptor);
1778 return CreateMultiplication(*multiplicationQueueDescriptor, info);
1779 }
1780 case LayerType::Normalization:
1781 {
1782 auto normalizationQueueDescriptor = PolymorphicDowncast<const NormalizationQueueDescriptor*>(&descriptor);
1783 return CreateNormalization(*normalizationQueueDescriptor, info);
1784 }
1785 case LayerType::Output:
1786 {
1787 auto outputQueueDescriptor = PolymorphicDowncast<const OutputQueueDescriptor*>(&descriptor);
1788 return CreateOutput(*outputQueueDescriptor, info);
1789 }
1790 case LayerType::Pad:
1791 {
1792 auto padQueueDescriptor = PolymorphicDowncast<const PadQueueDescriptor*>(&descriptor);
1793 return CreatePad(*padQueueDescriptor, info);
1794 }
1795 case LayerType::Permute:
1796 {
1797 auto permuteQueueDescriptor = PolymorphicDowncast<const PermuteQueueDescriptor*>(&descriptor);
1798 return CreatePermute(*permuteQueueDescriptor, info);
1799 }
1800 case LayerType::Pooling2d:
1801 {
1802 auto pooling2dQueueDescriptor = PolymorphicDowncast<const Pooling2dQueueDescriptor*>(&descriptor);
1803 return CreatePooling2d(*pooling2dQueueDescriptor, info);
1804 }
1805 case LayerType::Pooling3d:
1806 {
1807 auto pooling3dQueueDescriptor = PolymorphicDowncast<const Pooling3dQueueDescriptor*>(&descriptor);
1808 return CreatePooling3d(*pooling3dQueueDescriptor, info);
1809 }
1810 case LayerType::PreCompiled:
1811 {
1812 auto preCompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
1813 return CreatePreCompiled(*preCompiledQueueDescriptor, info);
1814 }
1815 case LayerType::Prelu:
1816 {
1817 auto preluQueueDescriptor = PolymorphicDowncast<const PreluQueueDescriptor*>(&descriptor);
1818 return CreatePrelu(*preluQueueDescriptor, info);
1819 }
1820 case LayerType::QLstm:
1821 {
1822 auto qlstmQueueDescriptor = PolymorphicDowncast<const QLstmQueueDescriptor*>(&descriptor);
1823 return CreateQLstm(*qlstmQueueDescriptor, info);
1824 }
1825 case LayerType::Quantize:
1826 {
1827 auto quantizeQueueDescriptor = PolymorphicDowncast<const QuantizeQueueDescriptor*>(&descriptor);
1828 return CreateQuantize(*quantizeQueueDescriptor, info);
1829 }
1830 case LayerType::Rank:
1831 {
1832 auto rankQueueDescriptor = PolymorphicDowncast<const RankQueueDescriptor*>(&descriptor);
1833 return CreateRank(*rankQueueDescriptor, info);
1834 }
1835 case LayerType::Reduce:
1836 {
1837 auto reduceQueueDescriptor = PolymorphicDowncast<const ReduceQueueDescriptor*>(&descriptor);
1838 return CreateReduce(*reduceQueueDescriptor, info);
1839 }
1840 case LayerType::Reshape:
1841 {
1842 auto reshapeQueueDescriptor = PolymorphicDowncast<const ReshapeQueueDescriptor*>(&descriptor);
1843 return CreateReshape(*reshapeQueueDescriptor, info);
1844 }
1845 case LayerType::Resize:
1846 {
1847 auto resizeQueueDescriptor = PolymorphicDowncast<const ResizeQueueDescriptor*>(&descriptor);
1848 return CreateResize(*resizeQueueDescriptor, info);
1849 }
1850 case LayerType::Shape:
1851 {
1852 auto shapeQueueDescriptor = PolymorphicDowncast<const ShapeQueueDescriptor*>(&descriptor);
1853 return CreateShape(*shapeQueueDescriptor, info);
1854 }
1855 case LayerType::Slice:
1856 {
1857 auto sliceQueueDescriptor = PolymorphicDowncast<const SliceQueueDescriptor*>(&descriptor);
1858 return CreateSlice(*sliceQueueDescriptor, info);
1859 }
1860 case LayerType::Softmax:
1861 {
1862 auto softmaxQueueDescriptor = PolymorphicDowncast<const SoftmaxQueueDescriptor*>(&descriptor);
1863 return CreateSoftmax(*softmaxQueueDescriptor, info);
1864 }
1865 case LayerType::SpaceToBatchNd:
1866 {
1867 auto spaceToBatchNdQueueDescriptor
1868 = PolymorphicDowncast<const SpaceToBatchNdQueueDescriptor*>(&descriptor);
1869 return CreateSpaceToBatchNd(*spaceToBatchNdQueueDescriptor, info);
1870 }
1871 case LayerType::SpaceToDepth:
1872 {
1873 auto spaceToDepthQueueDescriptor = PolymorphicDowncast<const SpaceToDepthQueueDescriptor*>(&descriptor);
1874 return CreateSpaceToDepth(*spaceToDepthQueueDescriptor, info);
1875 }
1876 case LayerType::Splitter:
1877 {
1878 auto splitterQueueDescriptor = PolymorphicDowncast<const SplitterQueueDescriptor*>(&descriptor);
1879 return CreateSplitter(*splitterQueueDescriptor, info);
1880 }
1881 case LayerType::Stack:
1882 {
1883 auto stackQueueDescriptor = PolymorphicDowncast<const StackQueueDescriptor*>(&descriptor);
1884 return CreateStack(*stackQueueDescriptor, info);
1885 }
1886 case LayerType::StridedSlice:
1887 {
1888 auto stridedSliceQueueDescriptor = PolymorphicDowncast<const StridedSliceQueueDescriptor*>(&descriptor);
1889 return CreateStridedSlice(*stridedSliceQueueDescriptor, info);
1890 }
1891 case LayerType::Subtraction:
1892 {
1893 auto subtractionQueueDescriptor = PolymorphicDowncast<const SubtractionQueueDescriptor*>(&descriptor);
1894 return CreateSubtraction(*subtractionQueueDescriptor, info);
1895 }
1896 case LayerType::Transpose:
1897 {
1898 auto transposeQueueDescriptor = PolymorphicDowncast<const TransposeQueueDescriptor*>(&descriptor);
1899 return CreateTranspose(*transposeQueueDescriptor, info);
1900 }
1901 case LayerType::TransposeConvolution2d:
1902 {
1903 auto transposeConvolution2dQueueDescriptor
1904 = PolymorphicDowncast<const TransposeConvolution2dQueueDescriptor*>(&descriptor);
1905 return CreateTransposeConvolution2d(*transposeConvolution2dQueueDescriptor, info);
1906 }
1907 case LayerType::UnidirectionalSequenceLstm:
1908 {
1909 auto unidirectionalSequenceLstmQueueDescriptor
1910 = PolymorphicDowncast<const UnidirectionalSequenceLstmQueueDescriptor*>(&descriptor);
1911 return CreateUnidirectionalSequenceLstm(*unidirectionalSequenceLstmQueueDescriptor, info);
1912 }
1913 default:
1914 return nullptr;
1915 }
1916}
1917ARMNN_NO_DEPRECATE_WARN_END
Sadik Armagan04a72972020-09-14 15:44:18 +01001918
Derek Lamberti901ea112019-12-10 22:07:09 +00001919std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1920 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001921{
1922 return std::unique_ptr<IWorkload>();
1923}
1924
Derek Lamberti901ea112019-12-10 22:07:09 +00001925std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1926 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001927{
1928 return std::unique_ptr<IWorkload>();
1929}
1930
Derek Lamberti901ea112019-12-10 22:07:09 +00001931std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1932 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001933{
1934 return std::unique_ptr<IWorkload>();
1935}
1936
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001937std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001938 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001939{
1940 return std::unique_ptr<IWorkload>();
1941}
1942
Derek Lamberti901ea112019-12-10 22:07:09 +00001943std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1944 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001945{
1946 return std::unique_ptr<IWorkload>();
1947}
1948
mathad01b392e982021-04-07 12:07:30 +01001949std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1950 const WorkloadInfo& /*info*/) const
1951{
1952 return std::unique_ptr<IWorkload>();
1953}
1954
Simon Obute51f67772021-09-03 15:50:13 +01001955std::unique_ptr<IWorkload> IWorkloadFactory::CreateChannelShuffle(const ChannelShuffleQueueDescriptor& /*descriptor*/,
1956 const WorkloadInfo& /*info*/) const
1957{
1958 return std::unique_ptr<IWorkload>();
1959}
1960
Derek Lamberti901ea112019-12-10 22:07:09 +00001961std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1962 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001963{
1964 return std::unique_ptr<IWorkload>();
1965}
1966
Derek Lamberti901ea112019-12-10 22:07:09 +00001967std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1968 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001969{
1970 return std::unique_ptr<IWorkload>();
1971}
1972
Derek Lamberti901ea112019-12-10 22:07:09 +00001973std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1974 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001975{
1976 return std::unique_ptr<IWorkload>();
1977}
1978
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001979std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1980 const WorkloadInfo& /*info*/) const
1981{
1982 return std::unique_ptr<IWorkload>();
1983}
1984
Derek Lamberti901ea112019-12-10 22:07:09 +00001985std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1986 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001987{
1988 return std::unique_ptr<IWorkload>();
1989}
1990
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001991std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1992 const WorkloadInfo& /*info*/) const
1993{
1994 return std::unique_ptr<IWorkload>();
1995}
1996
Derek Lamberti901ea112019-12-10 22:07:09 +00001997std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1998 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001999{
2000 return std::unique_ptr<IWorkload>();
2001}
2002
Derek Lamberti901ea112019-12-10 22:07:09 +00002003std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
2004 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002005{
2006 return std::unique_ptr<IWorkload>();
2007}
2008
Matthew Sloyanb63a3112021-09-08 13:05:51 +01002009std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution3d(const Convolution3dQueueDescriptor& /*descriptor*/,
2010 const WorkloadInfo& /*info*/) const
2011{
2012 return std::unique_ptr<IWorkload>();
2013}
2014
Derek Lamberti901ea112019-12-10 22:07:09 +00002015std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
2016 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002017{
2018 return std::unique_ptr<IWorkload>();
2019}
2020
Derek Lamberti901ea112019-12-10 22:07:09 +00002021std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
2022 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002023{
2024 return std::unique_ptr<IWorkload>();
2025}
2026
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002027std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00002028 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002029{
2030 return std::unique_ptr<IWorkload>();
2031}
2032
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002033std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00002034 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002035{
2036 return std::unique_ptr<IWorkload>();
2037}
2038
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002039std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00002040 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002041{
2042 return std::unique_ptr<IWorkload>();
2043}
2044
Derek Lamberti901ea112019-12-10 22:07:09 +00002045std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
2046 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002047{
2048 return std::unique_ptr<IWorkload>();
2049}
2050
josh minor4a3c6102020-01-06 16:40:46 -06002051std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
2052 const WorkloadInfo& /*info*/) const
2053{
2054 return std::unique_ptr<IWorkload>();
2055}
2056
Derek Lamberti901ea112019-12-10 22:07:09 +00002057std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
2058 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002059{
2060 return std::unique_ptr<IWorkload>();
2061}
2062
Ryan OSheaec6c6802020-06-05 17:17:06 +01002063std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
2064 const WorkloadInfo& /*info*/) const
2065{
2066 return std::unique_ptr<IWorkload>();
2067}
2068
Derek Lamberti901ea112019-12-10 22:07:09 +00002069std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
2070 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002071{
2072 return std::unique_ptr<IWorkload>();
2073}
2074
Derek Lamberti901ea112019-12-10 22:07:09 +00002075std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
2076 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002077{
2078 return std::unique_ptr<IWorkload>();
2079}
2080
Derek Lamberti901ea112019-12-10 22:07:09 +00002081std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
2082 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002083{
2084 return std::unique_ptr<IWorkload>();
2085}
2086
Kevin Mayce5045a2019-10-02 14:07:47 +01002087std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00002088 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
2089 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01002090{
2091 return std::unique_ptr<IWorkload>();
2092}
2093
Derek Lamberti901ea112019-12-10 22:07:09 +00002094std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
2095 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002096{
2097 return std::unique_ptr<IWorkload>();
2098}
2099
James Conroyaba90cd2020-11-06 16:28:18 +00002100std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
2101 const WorkloadInfo& /*info*/) const
2102{
2103 return std::unique_ptr<IWorkload>();
2104}
2105
2106std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
2107 const WorkloadInfo& /*info*/) const
2108{
2109 return std::unique_ptr<IWorkload>();
2110}
2111
Derek Lamberti901ea112019-12-10 22:07:09 +00002112std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
2113 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01002114{
2115 return std::unique_ptr<IWorkload>();
2116}
2117
Derek Lamberti901ea112019-12-10 22:07:09 +00002118std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
2119 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002120{
2121 return std::unique_ptr<IWorkload>();
2122}
2123
Derek Lamberti901ea112019-12-10 22:07:09 +00002124std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
2125 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002126{
2127 return std::unique_ptr<IWorkload>();
2128}
2129
Derek Lamberti901ea112019-12-10 22:07:09 +00002130std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
2131 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002132{
2133 return std::unique_ptr<IWorkload>();
2134}
2135
Derek Lamberti901ea112019-12-10 22:07:09 +00002136std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
2137 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002138{
2139 return std::unique_ptr<IWorkload>();
2140}
2141
Derek Lamberti901ea112019-12-10 22:07:09 +00002142std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
2143 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01002144{
2145 return std::unique_ptr<IWorkload>();
2146}
2147
Derek Lamberti901ea112019-12-10 22:07:09 +00002148std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
2149 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002150{
2151 return std::unique_ptr<IWorkload>();
2152}
2153
Derek Lamberti901ea112019-12-10 22:07:09 +00002154std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
2155 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002156{
2157 return std::unique_ptr<IWorkload>();
2158}
2159
Derek Lamberti901ea112019-12-10 22:07:09 +00002160std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
2161 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002162{
2163 return std::unique_ptr<IWorkload>();
2164}
2165
Derek Lamberti901ea112019-12-10 22:07:09 +00002166std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
2167 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002168{
2169 return std::unique_ptr<IWorkload>();
2170}
2171
Derek Lamberti901ea112019-12-10 22:07:09 +00002172std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
2173 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002174{
2175 return std::unique_ptr<IWorkload>();
2176}
2177
Derek Lamberti901ea112019-12-10 22:07:09 +00002178std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
2179 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002180{
2181 return std::unique_ptr<IWorkload>();
2182}
2183
Derek Lamberti901ea112019-12-10 22:07:09 +00002184std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002185 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002186{
2187 return std::unique_ptr<IWorkload>();
2188}
2189
Derek Lamberti901ea112019-12-10 22:07:09 +00002190std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
2191 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002192{
2193 return std::unique_ptr<IWorkload>();
2194}
2195
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002196std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling3d(const Pooling3dQueueDescriptor& /*descriptor*/,
2197 const WorkloadInfo& /*info*/) const
2198{
2199 return std::unique_ptr<IWorkload>();
2200}
2201
Derek Lamberti901ea112019-12-10 22:07:09 +00002202std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
2203 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002204{
2205 return std::unique_ptr<IWorkload>();
2206}
2207
Derek Lamberti901ea112019-12-10 22:07:09 +00002208std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
2209 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002210{
2211 return std::unique_ptr<IWorkload>();
2212}
2213
Derek Lamberti901ea112019-12-10 22:07:09 +00002214std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
2215 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002216{
2217 return std::unique_ptr<IWorkload>();
2218}
2219
James Conroy586a9aa2020-03-20 08:49:33 +00002220std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
2221 const WorkloadInfo& /*info*/) const
2222{
2223 return std::unique_ptr<IWorkload>();
2224}
2225
Derek Lamberti901ea112019-12-10 22:07:09 +00002226std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
2227 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01002228{
2229 return std::unique_ptr<IWorkload>();
2230}
Finn Williams2605b232020-06-10 15:53:46 +01002231std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
2232 const WorkloadInfo& /*info*/) const
2233{
2234 return std::unique_ptr<IWorkload>();
2235}
James Conroyee18dc82019-07-17 11:27:46 +01002236
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002237std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
2238 const WorkloadInfo& /*info*/) const
2239{
2240 return std::unique_ptr<IWorkload>();
2241}
2242
Derek Lamberti901ea112019-12-10 22:07:09 +00002243std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
2244 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002245{
2246 return std::unique_ptr<IWorkload>();
2247}
2248
Derek Lamberti901ea112019-12-10 22:07:09 +00002249std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
2250 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01002251{
2252 return std::unique_ptr<IWorkload>();
2253}
2254
Keith Davis3ae3f972021-05-21 16:33:48 +01002255std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
2256 const WorkloadInfo& /*info*/) const
2257{
2258 return std::unique_ptr<IWorkload>();
2259}
2260
Derek Lamberti901ea112019-12-10 22:07:09 +00002261std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
2262 const WorkloadInfo& /*info*/) const
2263{
2264 return std::unique_ptr<IWorkload>();
2265}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002266
Derek Lamberti901ea112019-12-10 22:07:09 +00002267std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
2268 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002269{
2270 return std::unique_ptr<IWorkload>();
2271}
2272
Derek Lamberti901ea112019-12-10 22:07:09 +00002273std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
2274 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002275{
2276 return std::unique_ptr<IWorkload>();
2277}
2278
Derek Lamberti901ea112019-12-10 22:07:09 +00002279std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
2280 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002281{
2282 return std::unique_ptr<IWorkload>();
2283}
2284
Derek Lamberti901ea112019-12-10 22:07:09 +00002285std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
2286 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002287{
2288 return std::unique_ptr<IWorkload>();
2289}
2290
Derek Lamberti901ea112019-12-10 22:07:09 +00002291std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
2292 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01002293{
2294 return std::unique_ptr<IWorkload>();
2295}
2296
Derek Lamberti901ea112019-12-10 22:07:09 +00002297std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
2298 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01002299{
2300 return std::unique_ptr<IWorkload>();
2301}
2302
Derek Lamberti901ea112019-12-10 22:07:09 +00002303std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
2304 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002305{
2306 return std::unique_ptr<IWorkload>();
2307}
2308
Derek Lamberti901ea112019-12-10 22:07:09 +00002309std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
2310 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01002311{
2312 return std::unique_ptr<IWorkload>();
2313}
2314
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002315std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
2316 const WorkloadInfo& /*info*/) const
2317{
2318 return std::unique_ptr<IWorkload>();
2319}
2320
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002321std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00002322 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
2323 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002324{
2325 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01002326}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002327
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01002328std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
2329 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
2330 const WorkloadInfo& /*info*/) const
2331{
2332 return std::unique_ptr<IWorkload>();
2333}
2334
} // namespace armnn