blob: 00263eca04d364ca8e8b2f6aa092b32d6f834789 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000017#include <backendsCommon/WorkloadFactory.hpp>
James Conroy1f58f032021-04-27 17:13:27 +010018#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000019
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
David Beck111b5d92018-11-12 14:59:37 +000022#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000023
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
telsoa01c577f2c2018-08-31 09:22:23 +010027namespace
28{
Finn Williams3e54d032020-10-22 16:53:35 +010029using LayerList = std::list<Layer*>;
30using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010031
David Beck29c75de2018-10-23 13:35:58 +010032const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
33{
34 if (!type)
35 {
36 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010037 }
38
Matthew Sloyan81beae32021-07-13 19:46:11 +010039 return TensorInfo(info.GetShape(),
40 type.value(),
41 info.GetQuantizationScale(),
42 info.GetQuantizationOffset(),
43 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010044}
45
David Beck29c75de2018-10-23 13:35:58 +010046} // anonymous namespace
47
Sadik Armagan045f6be2020-09-10 13:37:32 +010048bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
49 const IConnectableLayer& connectableLayer,
50 Optional<DataType> dataType,
51 std::string& outReasonIfUnsupported,
52 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000053{
David Beck33f0ae02018-10-18 15:13:56 +010054 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000055 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010056 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010057
David Beck111b5d92018-11-12 14:59:37 +000058 auto const& backendRegistry = BackendRegistryInstance();
59 if (!backendRegistry.IsBackendRegistered(backendId))
60 {
61 std::stringstream ss;
62 ss << connectableLayer.GetName() << " is not supported on " << backendId
63 << " because this backend is not registered.";
64
65 outReasonIfUnsupported = ss.str();
66 return false;
67 }
68
69 auto backendFactory = backendRegistry.GetFactory(backendId);
70 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000071 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010072
telsoa014fcda012018-03-09 14:13:49 +000073 switch(layer.GetType())
74 {
75 case LayerType::Activation:
76 {
Jan Eilersbb446e52020-04-02 13:56:54 +010077 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000078 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010079 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000080 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010081 OverrideDataType(input, dataType),
82 OverrideDataType(output, dataType),
83 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010084 reason);
telsoa014fcda012018-03-09 14:13:49 +000085 break;
86 }
87 case LayerType::Addition:
88 {
89 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
90 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
91 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000092 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010093 OverrideDataType(input0, dataType),
94 OverrideDataType(input1, dataType),
95 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010096 reason);
telsoa014fcda012018-03-09 14:13:49 +000097 break;
98 }
Nikhil Rajee391d52019-09-05 17:50:44 +010099 case LayerType::ArgMinMax:
100 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100101 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100102 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
103
104 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000106 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100107 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000108 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100109 descriptor,
110 reason);
111 break;
112 }
telsoa014fcda012018-03-09 14:13:49 +0000113 case LayerType::BatchNormalization:
114 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100115 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000122 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100130 reason);
telsoa014fcda012018-03-09 14:13:49 +0000131 break;
132 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000133 case LayerType::BatchToSpaceNd:
134 {
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100137 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000138
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000139 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
142 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000143 break;
144 }
mathad01b392e982021-04-07 12:07:30 +0100145 case LayerType::Cast:
146 {
147 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
148 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
149
150 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
151 OverrideDataType(output, dataType),
152 reason);
153 break;
154 }
Simon Obute51f67772021-09-03 15:50:13 +0100155 case LayerType::ChannelShuffle:
156 {
157 auto cLayer = PolymorphicDowncast<const ChannelShuffleLayer*>(&layer);
158
159 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
161
162 const ChannelShuffleDescriptor descriptor = cLayer->GetParameters();
163
164 result = layerSupportObject.IsChannelShuffleSupported(OverrideDataType(input, dataType),
165 OverrideDataType(output, dataType),
166 descriptor,
167 reason);
168 break;
169 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100170 case LayerType::Comparison:
171 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100172 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100173
174 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
175 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
176 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
177
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000178 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
179 OverrideDataType(input1, dataType),
180 OverrideDataType(output, DataType::Boolean),
181 cLayer->GetParameters(),
182 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100183 break;
184 }
telsoa014fcda012018-03-09 14:13:49 +0000185 case LayerType::Constant:
186 {
187 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000188 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 break;
190 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000191 case LayerType::ConvertBf16ToFp32:
192 {
193 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
194 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000195 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000196 break;
197 }
telsoa01c577f2c2018-08-31 09:22:23 +0100198 case LayerType::ConvertFp16ToFp32:
199 {
200 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
201 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000202 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100203 break;
204 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000205 case LayerType::ConvertFp32ToBf16:
206 {
207 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
208 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000209 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000210 break;
211 }
telsoa01c577f2c2018-08-31 09:22:23 +0100212 case LayerType::ConvertFp32ToFp16:
213 {
214 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
215 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000216 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000217 break;
218 }
219 case LayerType::Convolution2d:
220 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100221 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100222
223 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
224 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100225 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100226 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100227
arovir01a6824102018-08-28 17:40:45 +0100228 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100229
arovir01a6824102018-08-28 17:40:45 +0100230 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100231 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100232 if (descriptor.m_BiasEnabled)
233 {
David Beck5eec11d2018-10-04 15:43:17 +0100234 biases =
235 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100236 }
237
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000238 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100239 input,
240 output,
241 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100242 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100243 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100244 reason);
telsoa014fcda012018-03-09 14:13:49 +0000245 break;
246 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000247 case LayerType::Debug:
248 {
249 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
250 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
251
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000252 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000253 OverrideDataType(output, dataType),
254 reason);
255 break;
256 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100257 case LayerType::DepthToSpace:
258 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100259 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100260
261 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
262 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
263
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000264 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100265 OverrideDataType(output, dataType),
266 cLayer->GetParameters(),
267 reason);
268 break;
269 }
telsoa014fcda012018-03-09 14:13:49 +0000270 case LayerType::DepthwiseConvolution2d:
271 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100272 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100273 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
274 dataType);
275 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100276 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100277
telsoa01c577f2c2018-08-31 09:22:23 +0100278 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100279
280 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100281 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100282 if (descriptor.m_BiasEnabled)
283 {
David Beck5eec11d2018-10-04 15:43:17 +0100284 biases =
285 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100286 }
telsoa01c577f2c2018-08-31 09:22:23 +0100287
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000288 result = layerSupportObject.IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100289 input,
290 output,
291 descriptor,
292 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100293 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100294 reason);
telsoa014fcda012018-03-09 14:13:49 +0000295 break;
296 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000297 case LayerType::Dequantize:
298 {
299 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
300 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
301
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000302 result = layerSupportObject.IsDequantizeSupported(input,
303 OverrideDataType(output, dataType),
304 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000305 break;
306 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000307 case LayerType::DetectionPostProcess:
308 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100309 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000310 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
311 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
312 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
313
314 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
315 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
316 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
317 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
318
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000319 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000320 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
321 scores,
322 anchors,
323 detectionBoxes,
324 detectionClasses,
325 detectionScores,
326 numDetections,
327 descriptor,
328 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000329 break;
330 }
josh minor4a3c6102020-01-06 16:40:46 -0600331 case LayerType::ElementwiseUnary:
332 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100333 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600334
335 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
336 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
337
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000338 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
339 OverrideDataType(output, dataType),
340 cLayer->GetParameters(),
341 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600342 break;
343 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100344 case LayerType::Fill:
345 {
346 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
347 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
348 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
349 const FillDescriptor& descriptor = cLayer->GetParameters();
350
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000351 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100352 OverrideDataType(input, dataType),
353 OverrideDataType(output, dataType),
354 descriptor,
355 reason);
356 break;
357 }
telsoa014fcda012018-03-09 14:13:49 +0000358 case LayerType::FakeQuantization:
359 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100360 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000361 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000362 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
363 cLayer->GetParameters(),
364 reason);
telsoa014fcda012018-03-09 14:13:49 +0000365 break;
366 }
367 case LayerType::Floor:
368 {
369 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
370 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000371 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
372 OverrideDataType(output, dataType),
373 reason);
telsoa014fcda012018-03-09 14:13:49 +0000374 break;
375 }
376 case LayerType::FullyConnected:
377 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100378 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000379 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100380 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000381
382 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
383 TensorInfo weightsInfo;
384 const TensorInfo* weightsInfoPtr = nullptr;
385
Matthew Sloyan81beae32021-07-13 19:46:11 +0100386 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000387 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100388
389 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000390 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000391 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100392 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
393 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
394 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
395
telsoa01c577f2c2018-08-31 09:22:23 +0100396 if (descriptor.m_BiasEnabled)
397 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100398 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
399 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100400 }
401 else
402 {
403 // If biases are not enabled pass a dummy tensorinfo for the validation
404 switch(input.GetDataType())
405 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000406 case DataType::BFloat16:
407 {
408 biasInfoPtr = &dummyBFloat16Bias;
409 break;
410 }
telsoa01c577f2c2018-08-31 09:22:23 +0100411 case DataType::Float16:
412 {
413 biasInfoPtr = &dummyFloat16Bias;
414 break;
415 }
416 case DataType::Float32:
417 {
418 biasInfoPtr = &dummyFloat32Bias;
419 break;
420 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000421 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000422 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000423 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000424 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100425 {
426 biasInfoPtr = &dummyQA8Bias;
427 break;
428 }
429 default:
430 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100431 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100432 }
433 }
434 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000435 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100436 OverrideDataType(input, dataType),
437 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000438 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100439 *biasInfoPtr,
440 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100441 reason);
telsoa014fcda012018-03-09 14:13:49 +0000442 break;
443 }
narpra01b89b05f2019-01-16 09:53:09 +0000444 case LayerType::Gather:
445 {
446 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
447 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
448 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100449 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
450 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000451 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
452 input1,
453 OverrideDataType(output, dataType),
454 descriptor,
455 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000456 break;
457 }
telsoa014fcda012018-03-09 14:13:49 +0000458 case LayerType::Input:
459 {
460 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000461 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000462 break;
463 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100464 case LayerType::InstanceNormalization:
465 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100466 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100467 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
468
469 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
470 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
471
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000472 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100473 OverrideDataType(input, dataType),
474 OverrideDataType(output, dataType),
475 descriptor,
476 reason);
477 break;
478 }
telsoa014fcda012018-03-09 14:13:49 +0000479 case LayerType::L2Normalization:
480 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100481 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100482 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
483
telsoa014fcda012018-03-09 14:13:49 +0000484 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100485 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100486
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000487 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100488 OverrideDataType(input, dataType),
489 OverrideDataType(output, dataType),
490 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100491 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100492 break;
493 }
James Conroyaba90cd2020-11-06 16:28:18 +0000494 case LayerType::LogicalBinary:
495 {
496 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
497
498 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
499 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
500 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
501
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000502 result = layerSupportObject.IsLogicalBinarySupported(input0,
503 input1,
504 output,
505 cLayer->GetParameters(),
506 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000507 break;
508 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100509 case LayerType::LogSoftmax:
510 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100511 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100512
513 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
514 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
515
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000516 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
517 OverrideDataType(output, dataType),
518 cLayer->GetParameters(),
519 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100520 break;
521 }
telsoa01c577f2c2018-08-31 09:22:23 +0100522 case LayerType::Lstm:
523 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100524 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100525 const LstmDescriptor& descriptor = cLayer->GetParameters();
526
527 // All inputs.
528 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
529 dataType);
530 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
531 dataType);
532 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
533 dataType);
534 // All outputs
535 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
536 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
537 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
538 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
539
540 // Basic parameters
541 const TensorInfo& inputToForgetWeights
542 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
543 const TensorInfo& inputToCellWeights
544 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
545 const TensorInfo& inputToOutputWeights
546 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
547 const TensorInfo& recurrentToForgetWeights
548 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
549 const TensorInfo& recurrentToCellWeights
550 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
551 const TensorInfo& recurrentToOutputWeights
552 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
553 const TensorInfo& forgetGateBias
554 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
555 const TensorInfo& cellBias
556 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
557 const TensorInfo& outputGateBias
558 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
559
Jan Eilersd01a83c2019-07-03 18:20:40 +0100560 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100561
Jan Eilersd01a83c2019-07-03 18:20:40 +0100562 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
563 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
564 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
565 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
566 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
567 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
568 paramsInfo.m_ForgetGateBias = &forgetGateBias;
569 paramsInfo.m_CellBias = &cellBias;
570 paramsInfo.m_OutputGateBias = &outputGateBias;
571
572
573 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100574 TensorInfo optInputToInputWeights;
575 TensorInfo optRecurrentToInputWeights;
576 TensorInfo optCellToInputWeights;
577 TensorInfo optInputGateBias;
578 TensorInfo optProjectionWeights;
579 TensorInfo optProjectionBias;
580 TensorInfo optCellToForgetWeights;
581 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100582 TensorInfo optInputLayerNormWeights;
583 TensorInfo optForgetLayerNormWeights;
584 TensorInfo optCellLayerNormWeights;
585 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100586
587 if(!descriptor.m_CifgEnabled)
588 {
589 optInputToInputWeights =
590 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100591 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100592
593 optRecurrentToInputWeights =
594 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100595 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100596 optInputGateBias =
597 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100598 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100599 }
600
601 if(descriptor.m_ProjectionEnabled)
602 {
603 optProjectionWeights =
604 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100605 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100606 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
607 {
608 optProjectionBias =
609 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100610 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100611 }
612 }
613
614 if(descriptor.m_PeepholeEnabled)
615 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100616 if(!descriptor.m_CifgEnabled)
617 {
618 optCellToInputWeights =
619 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
620 dataType);
621 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
622 }
telsoa01c577f2c2018-08-31 09:22:23 +0100623 optCellToForgetWeights =
624 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100625 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100626 optCellToOutputWeights =
627 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100628 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100629 }
630
Jan Eilers38e05bd2019-06-26 13:10:09 +0100631 if(descriptor.m_LayerNormEnabled)
632 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100633 if (!descriptor.m_CifgEnabled)
634 {
635 optInputLayerNormWeights = OverrideDataType(
636 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
637 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
638 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100639
640 optForgetLayerNormWeights = OverrideDataType(
641 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100642 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100643
644 optCellLayerNormWeights = OverrideDataType(
645 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100646 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100647
648 optOutputLayerNormWeights = OverrideDataType(
649 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100650 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100651 }
652
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000653 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100654 input,
655 outputStateIn,
656 cellStateIn,
657 scratchBuffer,
658 outputStateOut,
659 cellStateOut,
660 output,
661 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100662 paramsInfo,
663 reason);
telsoa014fcda012018-03-09 14:13:49 +0000664 break;
665 }
        case LayerType::Maximum:
        {
            // Elementwise maximum of two tensors. All tensor infos are checked with
            // their data type overridden by the requested 'dataType'.
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
                                                           OverrideDataType(input1, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        case LayerType::MemCopy:
        {
            // Tensor copy: one input, one output; checked with the overridden data type.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        case LayerType::MemImport:
        {
            // Tensor import: one input, one output; checked with the overridden data type.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             reason);
            break;
        }
        case LayerType::Merge:
        {
            // Merge takes two inputs and produces one output; all checked with the
            // overridden data type.
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
                                                         OverrideDataType(input1, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
            break;
        }
        case LayerType::Concat:
        {
            auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };

            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
            // Materialise the (type-overridden) infos first so that the pointer vector
            // built below points into stable storage for the duration of the call.
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };

            auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        case LayerType::Multiplication:
        {
            // Elementwise multiplication of two tensors; checked with the overridden data type.
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsMultiplicationSupported(
                OverrideDataType(input0, dataType),
                OverrideDataType(input1, dataType),
                OverrideDataType(output, dataType),
                reason);
            break;
        }
        case LayerType::Normalization:
        {
            // Normalization parameters come from the layer's descriptor.
            auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Output:
        {
            // Output layers consume one tensor; only that tensor info is checked.
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPadSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                cLayer->GetParameters(),
                reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
                                                             OverrideDataType(output, dataType),
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::PreCompiled:
        {
            // Pre-compiled layers only have their input and parameters checked
            // (no output tensor info is passed to the support query).
            auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
                                                               cLayer->GetParameters(),
                                                               reason);
            break;
        }
        case LayerType::Quantize:
        {
            // Note: unlike most cases, Quantize passes the tensor infos through
            // without applying OverrideDataType.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsQuantizeSupported(input, output, reason);
            break;
        }
        case LayerType::QLstm:
        {
            // Gathers all mandatory and (descriptor-gated) optional QLSTM tensor infos
            // into a LstmInputParamsInfo before querying backend support.
            auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
            const QLstmDescriptor& descriptor = cLayer->GetParameters();

            // Inputs
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();

            // Outputs
            const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();

            // Lstm parameters
            LstmInputParamsInfo paramsInfo;

            // Basic parameters (always required).
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
            paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToForgetWeights =
                &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights =
                &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();

            // Input-gate parameters are only present when CIFG is disabled.
            if(!descriptor.m_CifgEnabled)
            {
                paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
                paramsInfo.m_RecurrentToInputWeights =
                    &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
                paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
            }

            if(descriptor.m_ProjectionEnabled)
            {
                paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();

                // Projection bias is optional even if projection is enabled
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights only exist when CIFG is disabled.
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_CellToInputWeights =
                        &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
                }

                paramsInfo.m_CellToForgetWeights =
                    &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
                paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input layer-norm weights only exist when CIFG is disabled.
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_InputLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
                }

                paramsInfo.m_ForgetLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
                paramsInfo.m_CellLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
                paramsInfo.m_OutputLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
            }

            result = layerSupportObject.IsQLstmSupported(input,
                                                         previousOutputIn,
                                                         previousCellStateIn,
                                                         outputStateOut,
                                                         cellStateOut,
                                                         output,
                                                         descriptor,
                                                         paramsInfo,
                                                         reason);
            break;
        }
James Conroyee18dc82019-07-17 11:27:46 +0100915 case LayerType::QuantizedLstm:
916 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100917 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100918
919 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100920 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
921 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
922 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100923
924 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100925 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
926 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100927
928 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100929 QuantizedLstmInputParamsInfo paramsInfo;
930
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100931 paramsInfo.m_InputToInputWeights =
932 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
933 paramsInfo.m_InputToForgetWeights =
934 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
935 paramsInfo.m_InputToCellWeights =
936 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
937 paramsInfo.m_InputToOutputWeights =
938 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100939
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100940 paramsInfo.m_RecurrentToInputWeights =
941 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
942 paramsInfo.m_RecurrentToForgetWeights =
943 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
944 paramsInfo.m_RecurrentToCellWeights =
945 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
946 paramsInfo.m_RecurrentToOutputWeights =
947 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100948
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100949 paramsInfo.m_InputGateBias =
950 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
951 paramsInfo.m_ForgetGateBias =
952 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
953 paramsInfo.m_CellBias =
954 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
955 paramsInfo.m_OutputGateBias =
956 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100957
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000958 result = layerSupportObject.IsQuantizedLstmSupported(input,
959 previousCellStateIn,
960 previousOutputIn,
961 cellStateOut,
962 output,
963 paramsInfo,
964 reason);
James Conroyee18dc82019-07-17 11:27:46 +0100965 break;
966 }
        case LayerType::Division:
        {
            // Elementwise division of two tensors; checked with the overridden data type.
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsDivisionSupported(
                OverrideDataType(input0, dataType),
                OverrideDataType(input1, dataType),
                OverrideDataType(output, dataType),
                reason);
            break;
        }
        case LayerType::Rank:
        {
            // Rank produces a scalar describing the input's number of dimensions.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
                                                        OverrideDataType(output, dataType),
                                                        reason);
            break;
        }
        case LayerType::Reshape:
        {
            auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        case LayerType::Resize:
        {
            auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          cLayer->GetParameters(),
                                                          reason);
            break;
        }
        case LayerType::Shape:
        {
            // Shape produces a tensor describing the input's dimensions.
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
                                                         OverrideDataType(output, dataType),
                                                         reason);
            break;
        }
        case LayerType::Slice:
        {
            auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
                                                         OverrideDataType(output, dataType),
                                                         cLayer->GetParameters(),
                                                         reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::SpaceToDepth:
        {
            auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
                                                                OverrideDataType(output, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();

            // Get vector of all outputs.
            auto getTensorInfo = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
            // Materialise the infos first; the reference_wrapper vector below refers
            // into this storage.
            std::vector<TensorInfo> outputs(beginI, endI);

            const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());

            result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
                                                            outputPtrs,
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Stack:
        {
            auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
            // Materialise the (type-overridden) infos first so the pointer vector
            // below points into stable storage.
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        case LayerType::StandIn:
        {
            // StandIn layers have arbitrary numbers of inputs and outputs, so both
            // sides are collected into pointer vectors for the support query.
            auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfoIn = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
            auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
            std::vector<TensorInfo> outputs(beginO, endO);

            // Materialised infos above give the pointer vectors below stable storage.
            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);

            auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
            auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);

            result = layerSupportObject.IsStandInSupported(inputPtrs,
                                                           outputPtrs,
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        case LayerType::StridedSlice:
        {
            auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
                                                                OverrideDataType(output, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
        case LayerType::Subtraction:
        {
            // Elementwise subtraction of two tensors; checked with the overridden data type.
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsSubtractionSupported(
                OverrideDataType(input0, dataType),
                OverrideDataType(input1, dataType),
                OverrideDataType(output, dataType),
                reason);
            break;
        }
        case LayerType::Switch:
        {
            // Switch takes two inputs and produces two outputs.
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
            result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output0, dataType),
                                                          OverrideDataType(output1, dataType),
                                                          reason);
            break;
        }
        case LayerType::Mean:
        {
            auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject.IsMeanSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                cLayer->GetParameters(),
                reason);
            break;
        }
kevmay0190539692018-11-29 08:40:19 +00001206 case LayerType::Minimum:
1207 {
1208 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1209 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1210 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001211 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1212 OverrideDataType(input1, dataType),
1213 OverrideDataType(output, dataType),
1214 reason);
kevmay0190539692018-11-29 08:40:19 +00001215 break;
1216 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001217 case LayerType::Prelu:
1218 {
1219 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1220 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1221 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001222 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1223 OverrideDataType(alpha, dataType),
1224 OverrideDataType(output, dataType),
1225 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001226 break;
1227 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001228 case LayerType::Transpose:
1229 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001230 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001231 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1232 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001233 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1234 OverrideDataType(output, dataType),
1235 cLayer->GetParameters(),
1236 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001237 break;
1238 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001239 case LayerType::TransposeConvolution2d:
1240 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001241 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001242
1243 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1244 dataType);
1245 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1246
1247 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1248
1249 Optional<TensorInfo> biases;
1250 if (descriptor.m_BiasEnabled)
1251 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001252 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001253 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1254 GetBiasTypeFromWeightsType(dataType));
1255 }
1256
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001257 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001258 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1259
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001260 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1261 output,
1262 descriptor,
1263 weights,
1264 biases,
1265 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001266
1267 break;
1268 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001269 case LayerType::Reduce:
1270 {
1271 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1272 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1273 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1274
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001275 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1276 OverrideDataType(output, dataType),
1277 cLayer->GetParameters(),
1278 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001279 break;
1280 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001281 case LayerType::UnidirectionalSequenceLstm:
1282 {
1283 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1284 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1285
1286 // All inputs.
1287 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1288 dataType);
1289 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1290 dataType);
1291 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1292 dataType);
1293 // Outputs
1294 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1295
1296 // Basic parameters
1297 const TensorInfo& inputToForgetWeights
1298 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1299 const TensorInfo& inputToCellWeights
1300 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1301 const TensorInfo& inputToOutputWeights
1302 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1303 const TensorInfo& recurrentToForgetWeights
1304 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1305 const TensorInfo& recurrentToCellWeights
1306 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1307 const TensorInfo& recurrentToOutputWeights
1308 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1309 const TensorInfo& forgetGateBias
1310 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1311 const TensorInfo& cellBias
1312 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1313 const TensorInfo& outputGateBias
1314 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1315
1316 LstmInputParamsInfo paramsInfo;
1317
1318 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1319 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1320 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1321 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1322 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1323 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1324 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1325 paramsInfo.m_CellBias = &cellBias;
1326 paramsInfo.m_OutputGateBias = &outputGateBias;
1327
1328 // Optional parameters
1329 TensorInfo optInputToInputWeights;
1330 TensorInfo optRecurrentToInputWeights;
1331 TensorInfo optCellToInputWeights;
1332 TensorInfo optInputGateBias;
1333 TensorInfo optProjectionWeights;
1334 TensorInfo optProjectionBias;
1335 TensorInfo optCellToForgetWeights;
1336 TensorInfo optCellToOutputWeights;
1337 TensorInfo optInputLayerNormWeights;
1338 TensorInfo optForgetLayerNormWeights;
1339 TensorInfo optCellLayerNormWeights;
1340 TensorInfo optOutputLayerNormWeights;
1341
1342 if(!descriptor.m_CifgEnabled)
1343 {
1344 optInputToInputWeights =
1345 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1346 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1347
1348 optRecurrentToInputWeights =
1349 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1350 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1351 optInputGateBias =
1352 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1353 paramsInfo.m_InputGateBias = &optInputGateBias;
1354 }
1355
1356 if(descriptor.m_ProjectionEnabled)
1357 {
1358 optProjectionWeights =
1359 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1360 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1361 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1362 {
1363 optProjectionBias =
1364 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1365 paramsInfo.m_ProjectionBias = &optProjectionBias;
1366 }
1367 }
1368
1369 if(descriptor.m_PeepholeEnabled)
1370 {
1371 if(!descriptor.m_CifgEnabled)
1372 {
1373 optCellToInputWeights =
1374 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1375 dataType);
1376 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1377 }
1378 optCellToForgetWeights =
1379 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1380 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1381 optCellToOutputWeights =
1382 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1383 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1384 }
1385
1386 if(descriptor.m_LayerNormEnabled)
1387 {
1388 if (!descriptor.m_CifgEnabled)
1389 {
1390 optInputLayerNormWeights = OverrideDataType(
1391 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1392 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1393 }
1394
1395 optForgetLayerNormWeights = OverrideDataType(
1396 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1397 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1398
1399 optCellLayerNormWeights = OverrideDataType(
1400 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1401 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1402
1403 optOutputLayerNormWeights = OverrideDataType(
1404 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1405 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1406 }
1407
1408 Optional<TensorInfo> hiddenStateOut;
1409 Optional<TensorInfo> cellStateOut;
1410
1411 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1412 outputStateIn,
1413 cellStateIn,
1414 output,
1415 hiddenStateOut,
1416 cellStateOut,
1417 descriptor,
1418 paramsInfo,
1419 reason);
1420 break;
1421 }
telsoa014fcda012018-03-09 14:13:49 +00001422 default:
1423 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001424 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001425 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001426 result = false;
1427 break;
1428 }
1429 }
telsoa014fcda012018-03-09 14:13:49 +00001430 return result;
1431}
1432
Sadik Armagan045f6be2020-09-10 13:37:32 +01001433bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1434 const IConnectableLayer& connectableLayer,
1435 Optional<DataType> dataType,
1436 std::string& outReasonIfUnsupported)
1437{
1438 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1439}
1440
David Beckdcb751f2018-10-03 11:42:42 +01001441bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001442 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001443 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001444{
Jan Eilersbb446e52020-04-02 13:56:54 +01001445 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001446 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1447}
1448
1449// TODO merge with defaulted modelOptions above
1450bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1451 Optional<DataType> dataType,
1452 std::string& outReasonIfUnsupported,
1453 const ModelOptions& modelOptions)
1454{
1455 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1456 return IsLayerConfigurationSupported(layer->GetBackendId(),
1457 connectableLayer,
1458 dataType,
1459 outReasonIfUnsupported,
1460 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001461}
1462
Sadik Armagan04a72972020-09-14 15:44:18 +01001463bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1464 const IConnectableLayer& connectableLayer,
1465 Optional<DataType> dataType,
1466 std::string& outReasonIfUnsupported,
1467 const ModelOptions& modelOptions)
1468{
1469 return IsLayerConfigurationSupported(backendId,
1470 connectableLayer,
1471 dataType,
1472 outReasonIfUnsupported,
1473 modelOptions);
1474}
1475
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001476// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001477std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1478 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001479{
1480 return std::unique_ptr<IWorkload>();
1481}
1482
Derek Lamberti901ea112019-12-10 22:07:09 +00001483std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1484 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001485{
1486 return std::unique_ptr<IWorkload>();
1487}
1488
Derek Lamberti901ea112019-12-10 22:07:09 +00001489std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1490 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001491{
1492 return std::unique_ptr<IWorkload>();
1493}
1494
Derek Lamberti901ea112019-12-10 22:07:09 +00001495std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1496 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001497{
1498 return std::unique_ptr<IWorkload>();
1499}
1500
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001501std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001502 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001503{
1504 return std::unique_ptr<IWorkload>();
1505}
1506
Derek Lamberti901ea112019-12-10 22:07:09 +00001507std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1508 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001509{
1510 return std::unique_ptr<IWorkload>();
1511}
1512
mathad01b392e982021-04-07 12:07:30 +01001513std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1514 const WorkloadInfo& /*info*/) const
1515{
1516 return std::unique_ptr<IWorkload>();
1517}
1518
Simon Obute51f67772021-09-03 15:50:13 +01001519std::unique_ptr<IWorkload> IWorkloadFactory::CreateChannelShuffle(const ChannelShuffleQueueDescriptor& /*descriptor*/,
1520 const WorkloadInfo& /*info*/) const
1521{
1522 return std::unique_ptr<IWorkload>();
1523}
1524
Derek Lamberti901ea112019-12-10 22:07:09 +00001525std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1526 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001527{
1528 return std::unique_ptr<IWorkload>();
1529}
1530
Derek Lamberti901ea112019-12-10 22:07:09 +00001531std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1532 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001533{
1534 return std::unique_ptr<IWorkload>();
1535}
1536
Derek Lamberti901ea112019-12-10 22:07:09 +00001537std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1538 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001539{
1540 return std::unique_ptr<IWorkload>();
1541}
1542
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001543std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1544 const WorkloadInfo& /*info*/) const
1545{
1546 return std::unique_ptr<IWorkload>();
1547}
1548
Derek Lamberti901ea112019-12-10 22:07:09 +00001549std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1550 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001551{
1552 return std::unique_ptr<IWorkload>();
1553}
1554
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001555std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1556 const WorkloadInfo& /*info*/) const
1557{
1558 return std::unique_ptr<IWorkload>();
1559}
1560
Derek Lamberti901ea112019-12-10 22:07:09 +00001561std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1562 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001563{
1564 return std::unique_ptr<IWorkload>();
1565}
1566
Derek Lamberti901ea112019-12-10 22:07:09 +00001567std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1568 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001569{
1570 return std::unique_ptr<IWorkload>();
1571}
1572
Derek Lamberti901ea112019-12-10 22:07:09 +00001573std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1574 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001575{
1576 return std::unique_ptr<IWorkload>();
1577}
1578
Derek Lamberti901ea112019-12-10 22:07:09 +00001579std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1580 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001581{
1582 return std::unique_ptr<IWorkload>();
1583}
1584
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001585std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001586 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001587{
1588 return std::unique_ptr<IWorkload>();
1589}
1590
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001591std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001592 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001593{
1594 return std::unique_ptr<IWorkload>();
1595}
1596
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001597std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001598 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001599{
1600 return std::unique_ptr<IWorkload>();
1601}
1602
Derek Lamberti901ea112019-12-10 22:07:09 +00001603std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1604 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001605{
1606 return std::unique_ptr<IWorkload>();
1607}
1608
josh minor4a3c6102020-01-06 16:40:46 -06001609std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1610 const WorkloadInfo& /*info*/) const
1611{
1612 return std::unique_ptr<IWorkload>();
1613}
1614
Derek Lamberti901ea112019-12-10 22:07:09 +00001615std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1616 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001617{
1618 return std::unique_ptr<IWorkload>();
1619}
1620
Derek Lamberti901ea112019-12-10 22:07:09 +00001621std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1622 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001623{
1624 return std::unique_ptr<IWorkload>();
1625}
1626
Ryan OSheaec6c6802020-06-05 17:17:06 +01001627std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1628 const WorkloadInfo& /*info*/) const
1629{
1630 return std::unique_ptr<IWorkload>();
1631}
1632
Derek Lamberti901ea112019-12-10 22:07:09 +00001633std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1634 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001635{
1636 return std::unique_ptr<IWorkload>();
1637}
1638
Derek Lamberti901ea112019-12-10 22:07:09 +00001639std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1640 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001641{
1642 return std::unique_ptr<IWorkload>();
1643}
1644
Derek Lamberti901ea112019-12-10 22:07:09 +00001645std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1646 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001647{
1648 return std::unique_ptr<IWorkload>();
1649}
1650
Derek Lamberti901ea112019-12-10 22:07:09 +00001651std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1652 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001653{
1654 return std::unique_ptr<IWorkload>();
1655}
1656
Kevin Mayce5045a2019-10-02 14:07:47 +01001657std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001658 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1659 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001660{
1661 return std::unique_ptr<IWorkload>();
1662}
1663
Derek Lamberti901ea112019-12-10 22:07:09 +00001664std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1665 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001666{
1667 return std::unique_ptr<IWorkload>();
1668}
1669
James Conroyaba90cd2020-11-06 16:28:18 +00001670std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
1671 const WorkloadInfo& /*info*/) const
1672{
1673 return std::unique_ptr<IWorkload>();
1674}
1675
1676std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1677 const WorkloadInfo& /*info*/) const
1678{
1679 return std::unique_ptr<IWorkload>();
1680}
1681
Derek Lamberti901ea112019-12-10 22:07:09 +00001682std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1683 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001684{
1685 return std::unique_ptr<IWorkload>();
1686}
1687
Derek Lamberti901ea112019-12-10 22:07:09 +00001688std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1689 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001690{
1691 return std::unique_ptr<IWorkload>();
1692}
1693
Derek Lamberti901ea112019-12-10 22:07:09 +00001694std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1695 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001696{
1697 return std::unique_ptr<IWorkload>();
1698}
1699
Derek Lamberti901ea112019-12-10 22:07:09 +00001700std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1701 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001702{
1703 return std::unique_ptr<IWorkload>();
1704}
1705
Derek Lamberti901ea112019-12-10 22:07:09 +00001706std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1707 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001708{
1709 return std::unique_ptr<IWorkload>();
1710}
1711
Derek Lamberti901ea112019-12-10 22:07:09 +00001712std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1713 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001714{
1715 return std::unique_ptr<IWorkload>();
1716}
1717
Derek Lamberti901ea112019-12-10 22:07:09 +00001718std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1719 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001720{
1721 return std::unique_ptr<IWorkload>();
1722}
1723
Derek Lamberti901ea112019-12-10 22:07:09 +00001724std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1725 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001726{
1727 return std::unique_ptr<IWorkload>();
1728}
1729
Derek Lamberti901ea112019-12-10 22:07:09 +00001730std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1731 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001732{
1733 return std::unique_ptr<IWorkload>();
1734}
1735
Derek Lamberti901ea112019-12-10 22:07:09 +00001736std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1737 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001738{
1739 return std::unique_ptr<IWorkload>();
1740}
1741
Derek Lamberti901ea112019-12-10 22:07:09 +00001742std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1743 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001744{
1745 return std::unique_ptr<IWorkload>();
1746}
1747
Derek Lamberti901ea112019-12-10 22:07:09 +00001748std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1749 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001750{
1751 return std::unique_ptr<IWorkload>();
1752}
1753
Derek Lamberti901ea112019-12-10 22:07:09 +00001754std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1755 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001756{
1757 return std::unique_ptr<IWorkload>();
1758}
1759
Derek Lamberti901ea112019-12-10 22:07:09 +00001760std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001761 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001762{
1763 return std::unique_ptr<IWorkload>();
1764}
1765
Derek Lamberti901ea112019-12-10 22:07:09 +00001766std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1767 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001768{
1769 return std::unique_ptr<IWorkload>();
1770}
1771
Derek Lamberti901ea112019-12-10 22:07:09 +00001772std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1773 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001774{
1775 return std::unique_ptr<IWorkload>();
1776}
1777
Derek Lamberti901ea112019-12-10 22:07:09 +00001778std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1779 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001780{
1781 return std::unique_ptr<IWorkload>();
1782}
1783
Derek Lamberti901ea112019-12-10 22:07:09 +00001784std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1785 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001786{
1787 return std::unique_ptr<IWorkload>();
1788}
1789
James Conroy586a9aa2020-03-20 08:49:33 +00001790std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1791 const WorkloadInfo& /*info*/) const
1792{
1793 return std::unique_ptr<IWorkload>();
1794}
1795
Derek Lamberti901ea112019-12-10 22:07:09 +00001796std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1797 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001798{
1799 return std::unique_ptr<IWorkload>();
1800}
Finn Williams2605b232020-06-10 15:53:46 +01001801std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
1802 const WorkloadInfo& /*info*/) const
1803{
1804 return std::unique_ptr<IWorkload>();
1805}
James Conroyee18dc82019-07-17 11:27:46 +01001806
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001807std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
1808 const WorkloadInfo& /*info*/) const
1809{
1810 return std::unique_ptr<IWorkload>();
1811}
1812
Derek Lamberti901ea112019-12-10 22:07:09 +00001813std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1814 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001815{
1816 return std::unique_ptr<IWorkload>();
1817}
1818
Derek Lamberti901ea112019-12-10 22:07:09 +00001819std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1820 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001821{
1822 return std::unique_ptr<IWorkload>();
1823}
1824
Derek Lamberti901ea112019-12-10 22:07:09 +00001825std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1826 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001827{
1828 return std::unique_ptr<IWorkload>();
1829}
1830
Derek Lamberti901ea112019-12-10 22:07:09 +00001831std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1832 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001833{
1834 return std::unique_ptr<IWorkload>();
1835}
1836
Keith Davis3ae3f972021-05-21 16:33:48 +01001837std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
1838 const WorkloadInfo& /*info*/) const
1839{
1840 return std::unique_ptr<IWorkload>();
1841}
1842
Derek Lamberti901ea112019-12-10 22:07:09 +00001843std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1844 const WorkloadInfo& /*info*/) const
1845{
1846 return std::unique_ptr<IWorkload>();
1847}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001848
Derek Lamberti901ea112019-12-10 22:07:09 +00001849std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1850 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001851{
1852 return std::unique_ptr<IWorkload>();
1853}
1854
Derek Lamberti901ea112019-12-10 22:07:09 +00001855std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1856 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001857{
1858 return std::unique_ptr<IWorkload>();
1859}
1860
Derek Lamberti901ea112019-12-10 22:07:09 +00001861std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1862 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001863{
1864 return std::unique_ptr<IWorkload>();
1865}
1866
Derek Lamberti901ea112019-12-10 22:07:09 +00001867std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1868 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001869{
1870 return std::unique_ptr<IWorkload>();
1871}
1872
Derek Lamberti901ea112019-12-10 22:07:09 +00001873std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1874 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001875{
1876 return std::unique_ptr<IWorkload>();
1877}
1878
Derek Lamberti901ea112019-12-10 22:07:09 +00001879std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1880 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001881{
1882 return std::unique_ptr<IWorkload>();
1883}
1884
Derek Lamberti901ea112019-12-10 22:07:09 +00001885std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1886 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001887{
1888 return std::unique_ptr<IWorkload>();
1889}
1890
Derek Lamberti901ea112019-12-10 22:07:09 +00001891std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1892 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001893{
1894 return std::unique_ptr<IWorkload>();
1895}
1896
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001897std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1898 const WorkloadInfo& /*info*/) const
1899{
1900 return std::unique_ptr<IWorkload>();
1901}
1902
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001903std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001904 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1905 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001906{
1907 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001908}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001909
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001910std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
1911 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
1912 const WorkloadInfo& /*info*/) const
1913{
1914 return std::unique_ptr<IWorkload>();
1915}
1916
} // namespace armnn