blob: 3f5972dab6b853a4cdcff8d8c51a8030a8d23eac [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
Francis Murtaghcae45682021-04-26 10:07:49 +010011#include <armnn/backends/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000017#include <backendsCommon/WorkloadFactory.hpp>
James Conroy1f58f032021-04-27 17:13:27 +010018#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000019
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
David Beck111b5d92018-11-12 14:59:37 +000022#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000023
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
telsoa01c577f2c2018-08-31 09:22:23 +010027namespace
28{
Finn Williams3e54d032020-10-22 16:53:35 +010029using LayerList = std::list<Layer*>;
30using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010031
David Beck29c75de2018-10-23 13:35:58 +010032const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
33{
34 if (!type)
35 {
36 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010037 }
38
Matthew Sloyan81beae32021-07-13 19:46:11 +010039 return TensorInfo(info.GetShape(),
40 type.value(),
41 info.GetQuantizationScale(),
42 info.GetQuantizationOffset(),
43 info.IsConstant());
telsoa01c577f2c2018-08-31 09:22:23 +010044}
45
David Beck29c75de2018-10-23 13:35:58 +010046} // anonymous namespace
47
Sadik Armagan045f6be2020-09-10 13:37:32 +010048bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
49 const IConnectableLayer& connectableLayer,
50 Optional<DataType> dataType,
51 std::string& outReasonIfUnsupported,
52 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000053{
David Beck33f0ae02018-10-18 15:13:56 +010054 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000055 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010056 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010057
David Beck111b5d92018-11-12 14:59:37 +000058 auto const& backendRegistry = BackendRegistryInstance();
59 if (!backendRegistry.IsBackendRegistered(backendId))
60 {
61 std::stringstream ss;
62 ss << connectableLayer.GetName() << " is not supported on " << backendId
63 << " because this backend is not registered.";
64
65 outReasonIfUnsupported = ss.str();
66 return false;
67 }
68
69 auto backendFactory = backendRegistry.GetFactory(backendId);
70 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000071 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010072
telsoa014fcda012018-03-09 14:13:49 +000073 switch(layer.GetType())
74 {
75 case LayerType::Activation:
76 {
Jan Eilersbb446e52020-04-02 13:56:54 +010077 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000078 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010079 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000080 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010081 OverrideDataType(input, dataType),
82 OverrideDataType(output, dataType),
83 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010084 reason);
telsoa014fcda012018-03-09 14:13:49 +000085 break;
86 }
87 case LayerType::Addition:
88 {
89 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
90 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
91 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000092 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010093 OverrideDataType(input0, dataType),
94 OverrideDataType(input1, dataType),
95 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010096 reason);
telsoa014fcda012018-03-09 14:13:49 +000097 break;
98 }
Nikhil Rajee391d52019-09-05 17:50:44 +010099 case LayerType::ArgMinMax:
100 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100101 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +0100102 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
103
104 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000106 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100107 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000108 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100109 descriptor,
110 reason);
111 break;
112 }
telsoa014fcda012018-03-09 14:13:49 +0000113 case LayerType::BatchNormalization:
114 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100115 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000122 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100130 reason);
telsoa014fcda012018-03-09 14:13:49 +0000131 break;
132 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000133 case LayerType::BatchToSpaceNd:
134 {
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100137 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000138
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000139 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
142 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000143 break;
144 }
mathad01b392e982021-04-07 12:07:30 +0100145 case LayerType::Cast:
146 {
147 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
148 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
149
150 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
151 OverrideDataType(output, dataType),
152 reason);
153 break;
154 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100155 case LayerType::Comparison:
156 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100157 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100158
159 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
161 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
162
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000163 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
164 OverrideDataType(input1, dataType),
165 OverrideDataType(output, DataType::Boolean),
166 cLayer->GetParameters(),
167 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100168 break;
169 }
telsoa014fcda012018-03-09 14:13:49 +0000170 case LayerType::Constant:
171 {
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000173 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100174 break;
175 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000176 case LayerType::ConvertBf16ToFp32:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000180 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000181 break;
182 }
telsoa01c577f2c2018-08-31 09:22:23 +0100183 case LayerType::ConvertFp16ToFp32:
184 {
185 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
186 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000187 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100188 break;
189 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000190 case LayerType::ConvertFp32ToBf16:
191 {
192 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
193 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000194 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000195 break;
196 }
telsoa01c577f2c2018-08-31 09:22:23 +0100197 case LayerType::ConvertFp32ToFp16:
198 {
199 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
200 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000201 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000202 break;
203 }
204 case LayerType::Convolution2d:
205 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100206 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100207
208 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
209 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100210 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100211 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100212
arovir01a6824102018-08-28 17:40:45 +0100213 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100214
arovir01a6824102018-08-28 17:40:45 +0100215 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100216 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100217 if (descriptor.m_BiasEnabled)
218 {
David Beck5eec11d2018-10-04 15:43:17 +0100219 biases =
220 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100221 }
222
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000223 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100224 input,
225 output,
226 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100227 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100228 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100229 reason);
telsoa014fcda012018-03-09 14:13:49 +0000230 break;
231 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000232 case LayerType::Debug:
233 {
234 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
235 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
236
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000237 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000238 OverrideDataType(output, dataType),
239 reason);
240 break;
241 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100242 case LayerType::DepthToSpace:
243 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100244 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100245
246 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
247 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
248
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000249 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100250 OverrideDataType(output, dataType),
251 cLayer->GetParameters(),
252 reason);
253 break;
254 }
telsoa014fcda012018-03-09 14:13:49 +0000255 case LayerType::DepthwiseConvolution2d:
256 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100257 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100258 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
259 dataType);
260 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100261 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100262
telsoa01c577f2c2018-08-31 09:22:23 +0100263 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100264
265 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100266 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100267 if (descriptor.m_BiasEnabled)
268 {
David Beck5eec11d2018-10-04 15:43:17 +0100269 biases =
270 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100271 }
telsoa01c577f2c2018-08-31 09:22:23 +0100272
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000273 result = layerSupportObject.IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100274 input,
275 output,
276 descriptor,
277 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100278 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100279 reason);
telsoa014fcda012018-03-09 14:13:49 +0000280 break;
281 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000282 case LayerType::Dequantize:
283 {
284 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
285 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
286
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000287 result = layerSupportObject.IsDequantizeSupported(input,
288 OverrideDataType(output, dataType),
289 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000290 break;
291 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000292 case LayerType::DetectionPostProcess:
293 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100294 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000295 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
296 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
297 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
298
299 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
300 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
301 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
302 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
303
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000304 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000305 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
306 scores,
307 anchors,
308 detectionBoxes,
309 detectionClasses,
310 detectionScores,
311 numDetections,
312 descriptor,
313 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000314 break;
315 }
josh minor4a3c6102020-01-06 16:40:46 -0600316 case LayerType::ElementwiseUnary:
317 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100318 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600319
320 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
321 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
322
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000323 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
324 OverrideDataType(output, dataType),
325 cLayer->GetParameters(),
326 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600327 break;
328 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100329 case LayerType::Fill:
330 {
331 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
332 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
333 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
334 const FillDescriptor& descriptor = cLayer->GetParameters();
335
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000336 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100337 OverrideDataType(input, dataType),
338 OverrideDataType(output, dataType),
339 descriptor,
340 reason);
341 break;
342 }
telsoa014fcda012018-03-09 14:13:49 +0000343 case LayerType::FakeQuantization:
344 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100345 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000346 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000347 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
348 cLayer->GetParameters(),
349 reason);
telsoa014fcda012018-03-09 14:13:49 +0000350 break;
351 }
352 case LayerType::Floor:
353 {
354 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
355 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000356 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
357 OverrideDataType(output, dataType),
358 reason);
telsoa014fcda012018-03-09 14:13:49 +0000359 break;
360 }
361 case LayerType::FullyConnected:
362 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100363 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000364 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100365 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000366
367 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
368 TensorInfo weightsInfo;
369 const TensorInfo* weightsInfoPtr = nullptr;
370
Matthew Sloyan81beae32021-07-13 19:46:11 +0100371 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000372 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100373
374 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000375 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000376 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100377 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
378 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
379 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
380
telsoa01c577f2c2018-08-31 09:22:23 +0100381 if (descriptor.m_BiasEnabled)
382 {
Matthew Sloyan81beae32021-07-13 19:46:11 +0100383 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
384 biasInfoPtr = &biasInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100385 }
386 else
387 {
388 // If biases are not enabled pass a dummy tensorinfo for the validation
389 switch(input.GetDataType())
390 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000391 case DataType::BFloat16:
392 {
393 biasInfoPtr = &dummyBFloat16Bias;
394 break;
395 }
telsoa01c577f2c2018-08-31 09:22:23 +0100396 case DataType::Float16:
397 {
398 biasInfoPtr = &dummyFloat16Bias;
399 break;
400 }
401 case DataType::Float32:
402 {
403 biasInfoPtr = &dummyFloat32Bias;
404 break;
405 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000406 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000407 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000408 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000409 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100410 {
411 biasInfoPtr = &dummyQA8Bias;
412 break;
413 }
414 default:
415 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100416 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100417 }
418 }
419 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000420 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100421 OverrideDataType(input, dataType),
422 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000423 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100424 *biasInfoPtr,
425 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100426 reason);
telsoa014fcda012018-03-09 14:13:49 +0000427 break;
428 }
narpra01b89b05f2019-01-16 09:53:09 +0000429 case LayerType::Gather:
430 {
431 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
432 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
433 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100434 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
435 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000436 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
437 input1,
438 OverrideDataType(output, dataType),
439 descriptor,
440 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000441 break;
442 }
telsoa014fcda012018-03-09 14:13:49 +0000443 case LayerType::Input:
444 {
445 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000446 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000447 break;
448 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100449 case LayerType::InstanceNormalization:
450 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100451 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100452 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
453
454 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
455 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
456
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000457 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100458 OverrideDataType(input, dataType),
459 OverrideDataType(output, dataType),
460 descriptor,
461 reason);
462 break;
463 }
telsoa014fcda012018-03-09 14:13:49 +0000464 case LayerType::L2Normalization:
465 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100466 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100467 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
468
telsoa014fcda012018-03-09 14:13:49 +0000469 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100470 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100471
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000472 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100473 OverrideDataType(input, dataType),
474 OverrideDataType(output, dataType),
475 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100476 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100477 break;
478 }
James Conroyaba90cd2020-11-06 16:28:18 +0000479 case LayerType::LogicalBinary:
480 {
481 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
482
483 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
484 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
485 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
486
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000487 result = layerSupportObject.IsLogicalBinarySupported(input0,
488 input1,
489 output,
490 cLayer->GetParameters(),
491 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000492 break;
493 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100494 case LayerType::LogSoftmax:
495 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100496 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100497
498 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
499 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
500
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000501 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
502 OverrideDataType(output, dataType),
503 cLayer->GetParameters(),
504 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100505 break;
506 }
telsoa01c577f2c2018-08-31 09:22:23 +0100507 case LayerType::Lstm:
508 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100509 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100510 const LstmDescriptor& descriptor = cLayer->GetParameters();
511
512 // All inputs.
513 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
514 dataType);
515 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
516 dataType);
517 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
518 dataType);
519 // All outputs
520 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
521 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
522 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
523 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
524
525 // Basic parameters
526 const TensorInfo& inputToForgetWeights
527 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
528 const TensorInfo& inputToCellWeights
529 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
530 const TensorInfo& inputToOutputWeights
531 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
532 const TensorInfo& recurrentToForgetWeights
533 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
534 const TensorInfo& recurrentToCellWeights
535 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
536 const TensorInfo& recurrentToOutputWeights
537 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
538 const TensorInfo& forgetGateBias
539 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
540 const TensorInfo& cellBias
541 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
542 const TensorInfo& outputGateBias
543 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
544
Jan Eilersd01a83c2019-07-03 18:20:40 +0100545 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100546
Jan Eilersd01a83c2019-07-03 18:20:40 +0100547 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
548 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
549 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
550 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
551 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
552 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
553 paramsInfo.m_ForgetGateBias = &forgetGateBias;
554 paramsInfo.m_CellBias = &cellBias;
555 paramsInfo.m_OutputGateBias = &outputGateBias;
556
557
558 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100559 TensorInfo optInputToInputWeights;
560 TensorInfo optRecurrentToInputWeights;
561 TensorInfo optCellToInputWeights;
562 TensorInfo optInputGateBias;
563 TensorInfo optProjectionWeights;
564 TensorInfo optProjectionBias;
565 TensorInfo optCellToForgetWeights;
566 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100567 TensorInfo optInputLayerNormWeights;
568 TensorInfo optForgetLayerNormWeights;
569 TensorInfo optCellLayerNormWeights;
570 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100571
572 if(!descriptor.m_CifgEnabled)
573 {
574 optInputToInputWeights =
575 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100576 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100577
578 optRecurrentToInputWeights =
579 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100580 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100581 optInputGateBias =
582 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100583 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100584 }
585
586 if(descriptor.m_ProjectionEnabled)
587 {
588 optProjectionWeights =
589 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100590 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100591 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
592 {
593 optProjectionBias =
594 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100595 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100596 }
597 }
598
599 if(descriptor.m_PeepholeEnabled)
600 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100601 if(!descriptor.m_CifgEnabled)
602 {
603 optCellToInputWeights =
604 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
605 dataType);
606 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
607 }
telsoa01c577f2c2018-08-31 09:22:23 +0100608 optCellToForgetWeights =
609 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100610 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100611 optCellToOutputWeights =
612 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100613 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100614 }
615
Jan Eilers38e05bd2019-06-26 13:10:09 +0100616 if(descriptor.m_LayerNormEnabled)
617 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100618 if (!descriptor.m_CifgEnabled)
619 {
620 optInputLayerNormWeights = OverrideDataType(
621 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
622 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
623 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100624
625 optForgetLayerNormWeights = OverrideDataType(
626 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100627 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100628
629 optCellLayerNormWeights = OverrideDataType(
630 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100631 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100632
633 optOutputLayerNormWeights = OverrideDataType(
634 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100635 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100636 }
637
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000638 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100639 input,
640 outputStateIn,
641 cellStateIn,
642 scratchBuffer,
643 outputStateOut,
644 cellStateOut,
645 output,
646 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100647 paramsInfo,
648 reason);
telsoa014fcda012018-03-09 14:13:49 +0000649 break;
650 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000651 case LayerType::Maximum:
652 {
653 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
654 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
655 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
656
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000657 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
658 OverrideDataType(input1, dataType),
659 OverrideDataType(output, dataType),
660 reason);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000661 break;
662 }
narpra01b89b05f2019-01-16 09:53:09 +0000663 case LayerType::MemCopy:
664 {
665 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
666 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000667
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000668 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
669 OverrideDataType(output, dataType),
670 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000671 break;
672 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100673 case LayerType::MemImport:
674 {
675 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
676 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
677
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000678 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
679 OverrideDataType(output, dataType),
680 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100681 break;
682 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100683 case LayerType::Merge:
684 {
685 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
686 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
687 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
688
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000689 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
690 OverrideDataType(input1, dataType),
691 OverrideDataType(output, dataType),
692 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100693 break;
694 }
        case LayerType::Concat:
        {
            // Validates concatenation of all input tensors into one output tensor.
            auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };

            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
            // Materialise the (possibly type-overridden) input infos; 'inputPtrs' below
            // points into this vector, so it must outlive the IsConcatSupported() call.
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };

            auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
725 case LayerType::Multiplication:
726 {
727 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
728 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100729 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000730 result = layerSupportObject.IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100731 OverrideDataType(input0, dataType),
732 OverrideDataType(input1, dataType),
733 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100734 reason);
telsoa014fcda012018-03-09 14:13:49 +0000735 break;
736 }
737 case LayerType::Normalization:
738 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100739 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000740 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
741 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000742 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
743 OverrideDataType(output, dataType),
744 cLayer->GetParameters(),
745 reason);
telsoa014fcda012018-03-09 14:13:49 +0000746 break;
747 }
748 case LayerType::Output:
749 {
750 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000751 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000752 break;
753 }
754 case LayerType::Permute:
755 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100756 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000757 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
758 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000759 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
760 OverrideDataType(output, dataType),
761 cLayer->GetParameters(),
762 reason);
telsoa014fcda012018-03-09 14:13:49 +0000763 break;
764 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100765 case LayerType::Pad:
766 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100767 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100768 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
769 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000770 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100771 OverrideDataType(input, dataType),
772 OverrideDataType(output, dataType),
773 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100774 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100775 break;
776 }
telsoa014fcda012018-03-09 14:13:49 +0000777 case LayerType::Pooling2d:
778 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100779 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000780 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
781 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000782 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
783 OverrideDataType(output, dataType),
784 cLayer->GetParameters(),
785 reason);
telsoa014fcda012018-03-09 14:13:49 +0000786 break;
787 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000788 case LayerType::PreCompiled:
789 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100790 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000791 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000792 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
793 cLayer->GetParameters(),
794 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000795 break;
796 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000797 case LayerType::Quantize:
798 {
799 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
800 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000801 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000802 break;
803 }
        case LayerType::QLstm:
        {
            // Builds an LstmInputParamsInfo describing every weight/bias tensor the
            // QLstm layer carries, then queries backend support. paramsInfo stores
            // raw pointers into the layer's own tensor handles, so the layer must
            // outlive the IsQLstmSupported() call (it does — we only borrow here).
            auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
            const QLstmDescriptor& descriptor = cLayer->GetParameters();

            // Inputs
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();

            // Outputs
            const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();

            // Lstm parameters
            LstmInputParamsInfo paramsInfo;

            // Basic parameters (always required, hence the asserts).
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() != nullptr);
            ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() != nullptr);
            paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToForgetWeights =
                &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights =
                &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();

            // Input-gate tensors only exist when CIFG (coupled input-forget gate) is off.
            if(!descriptor.m_CifgEnabled)
            {
                paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
                paramsInfo.m_RecurrentToInputWeights =
                    &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
                paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
            }

            if(descriptor.m_ProjectionEnabled)
            {
                paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();

                // Projection bias is optional even if projection is enabled
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
                }
            }

            if(descriptor.m_PeepholeEnabled)
            {
                // Cell-to-input peephole weights also depend on CIFG being off.
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_CellToInputWeights =
                        &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
                }

                paramsInfo.m_CellToForgetWeights =
                    &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
                paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
            }

            if(descriptor.m_LayerNormEnabled)
            {
                // Input-gate layer-norm weights likewise require CIFG to be off.
                if (!descriptor.m_CifgEnabled)
                {
                    paramsInfo.m_InputLayerNormWeights =
                        &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
                }

                paramsInfo.m_ForgetLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
                paramsInfo.m_CellLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
                paramsInfo.m_OutputLayerNormWeights =
                    &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
            }

            result = layerSupportObject.IsQLstmSupported(input,
                                                         previousOutputIn,
                                                         previousCellStateIn,
                                                         outputStateOut,
                                                         cellStateOut,
                                                         output,
                                                         descriptor,
                                                         paramsInfo,
                                                         reason);
            break;
        }
James Conroyee18dc82019-07-17 11:27:46 +0100900 case LayerType::QuantizedLstm:
901 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100902 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100903
904 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100905 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
906 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
907 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100908
909 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100910 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
911 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100912
913 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100914 QuantizedLstmInputParamsInfo paramsInfo;
915
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100916 paramsInfo.m_InputToInputWeights =
917 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
918 paramsInfo.m_InputToForgetWeights =
919 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
920 paramsInfo.m_InputToCellWeights =
921 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
922 paramsInfo.m_InputToOutputWeights =
923 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100924
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100925 paramsInfo.m_RecurrentToInputWeights =
926 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
927 paramsInfo.m_RecurrentToForgetWeights =
928 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
929 paramsInfo.m_RecurrentToCellWeights =
930 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
931 paramsInfo.m_RecurrentToOutputWeights =
932 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100933
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100934 paramsInfo.m_InputGateBias =
935 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
936 paramsInfo.m_ForgetGateBias =
937 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
938 paramsInfo.m_CellBias =
939 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
940 paramsInfo.m_OutputGateBias =
941 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100942
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000943 result = layerSupportObject.IsQuantizedLstmSupported(input,
944 previousCellStateIn,
945 previousOutputIn,
946 cellStateOut,
947 output,
948 paramsInfo,
949 reason);
James Conroyee18dc82019-07-17 11:27:46 +0100950 break;
951 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100952 case LayerType::Division:
953 {
954 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
955 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
956 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000957 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100958 OverrideDataType(input0, dataType),
959 OverrideDataType(input1, dataType),
960 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100961 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100962 break;
963 }
Finn Williams2605b232020-06-10 15:53:46 +0100964 case LayerType::Rank:
965 {
966 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
967 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000968 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
969 OverrideDataType(output, dataType),
970 reason);
Finn Williams2605b232020-06-10 15:53:46 +0100971 break;
972 }
telsoa014fcda012018-03-09 14:13:49 +0000973 case LayerType::Reshape:
974 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100975 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000976 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +0000977 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000978 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
979 OverrideDataType(output, dataType),
980 cLayer->GetParameters(),
981 reason);
telsoa014fcda012018-03-09 14:13:49 +0000982 break;
983 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100984 case LayerType::Resize:
985 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100986 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100987 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100988 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000989 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
990 OverrideDataType(output, dataType),
991 cLayer->GetParameters(),
992 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +0100993 break;
994 }
Keith Davis3ae3f972021-05-21 16:33:48 +0100995 case LayerType::Shape:
996 {
997 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
998 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
999
1000 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1001 OverrideDataType(output, dataType),
1002 reason);
1003 break;
1004 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001005 case LayerType::Slice:
1006 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001007 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001008
1009 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1010 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1011
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001012 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1013 OverrideDataType(output, dataType),
1014 cLayer->GetParameters(),
1015 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001016 break;
1017 }
telsoa014fcda012018-03-09 14:13:49 +00001018 case LayerType::Softmax:
1019 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001020 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001021 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001022 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001023 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1024 OverrideDataType(output, dataType),
1025 cLayer->GetParameters(),
1026 reason);
telsoa014fcda012018-03-09 14:13:49 +00001027 break;
1028 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001029 case LayerType::SpaceToBatchNd:
1030 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001031 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001032 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1033 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001034 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1035 OverrideDataType(output, dataType),
1036 cLayer->GetParameters(),
1037 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001038 break;
1039 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001040 case LayerType::SpaceToDepth:
1041 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001042 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001043
1044 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1045 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1046
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001047 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1048 OverrideDataType(output, dataType),
1049 cLayer->GetParameters(),
1050 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001051 break;
1052 }
        case LayerType::Splitter:
        {
            // Validates splitting one input tensor into multiple output tensors.
            auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();

            // Get vector of all outputs.
            auto getTensorInfo = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> outputs(beginI, endI);

            // 'outputPtrs' wraps references into 'outputs', which must stay alive
            // for the duration of the IsSplitterSupported() call.
            const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());

            result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
                                                            outputPtrs,
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Stack:
        {
            // Validates stacking all input tensors into one higher-rank output.
            auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            // 'inputPtrs' points into 'inputs'; both must outlive the call below.
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        case LayerType::StandIn:
        {
            // A StandIn layer represents an unknown operation; all of its inputs
            // and outputs are collected and passed to the backend for a verdict.
            auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfoIn = [&dataType](const InputSlot& slot)
            {
                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
            };
            auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
            auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
            auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
            std::vector<TensorInfo> outputs(beginO, endO);


            auto getTensorInfoPtr = [](const TensorInfo& info)
            {
                return &info;
            };
            // The pointer vectors alias 'inputs'/'outputs' above, so those vectors
            // must stay alive until IsStandInSupported() returns.
            auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
            auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);

            auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
            auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);


            result = layerSupportObject.IsStandInSupported(inputPtrs,
                                                           outputPtrs,
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001143 case LayerType::StridedSlice:
1144 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001145 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001146 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001148 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1149 OverrideDataType(output, dataType),
1150 cLayer->GetParameters(),
1151 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001152 break;
1153 }
David Beckc2044fe2018-09-05 15:00:38 +01001154 case LayerType::Subtraction:
1155 {
1156 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1157 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001159 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001160 OverrideDataType(input0, dataType),
1161 OverrideDataType(input1, dataType),
1162 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001163 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001164 break;
1165 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001166 case LayerType::Switch:
1167 {
1168 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1169 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1170 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1171 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001172 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1173 OverrideDataType(input1, dataType),
1174 OverrideDataType(output0, dataType),
1175 OverrideDataType(output1, dataType),
1176 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001177 break;
1178 }
narpra0132b90462018-09-13 11:07:48 +01001179 case LayerType::Mean:
1180 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001181 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001182 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1183 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001184 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001185 OverrideDataType(input, dataType),
1186 OverrideDataType(output, dataType),
1187 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001188 reason);
narpra0132b90462018-09-13 11:07:48 +01001189 break;
1190 }
kevmay0190539692018-11-29 08:40:19 +00001191 case LayerType::Minimum:
1192 {
1193 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1194 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1195 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001196 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1197 OverrideDataType(input1, dataType),
1198 OverrideDataType(output, dataType),
1199 reason);
kevmay0190539692018-11-29 08:40:19 +00001200 break;
1201 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001202 case LayerType::Prelu:
1203 {
1204 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1205 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1206 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001207 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1208 OverrideDataType(alpha, dataType),
1209 OverrideDataType(output, dataType),
1210 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001211 break;
1212 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001213 case LayerType::Transpose:
1214 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001215 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001216 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1217 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001218 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1219 OverrideDataType(output, dataType),
1220 cLayer->GetParameters(),
1221 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001222 break;
1223 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001224 case LayerType::TransposeConvolution2d:
1225 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001226 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001227
1228 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1229 dataType);
1230 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1231
1232 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1233
1234 Optional<TensorInfo> biases;
1235 if (descriptor.m_BiasEnabled)
1236 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001237 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001238 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1239 GetBiasTypeFromWeightsType(dataType));
1240 }
1241
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001242 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001243 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1244
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001245 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1246 output,
1247 descriptor,
1248 weights,
1249 biases,
1250 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001251
1252 break;
1253 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001254 case LayerType::Reduce:
1255 {
1256 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1257 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1258 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1259
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001260 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1261 OverrideDataType(output, dataType),
1262 cLayer->GetParameters(),
1263 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001264 break;
1265 }
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001266 case LayerType::UnidirectionalSequenceLstm:
1267 {
1268 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1269 const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
1270
1271 // All inputs.
1272 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1273 dataType);
1274 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1275 dataType);
1276 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1277 dataType);
1278 // Outputs
1279 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1280
1281 // Basic parameters
1282 const TensorInfo& inputToForgetWeights
1283 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1284 const TensorInfo& inputToCellWeights
1285 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1286 const TensorInfo& inputToOutputWeights
1287 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1288 const TensorInfo& recurrentToForgetWeights
1289 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1290 const TensorInfo& recurrentToCellWeights
1291 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1292 const TensorInfo& recurrentToOutputWeights
1293 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1294 const TensorInfo& forgetGateBias
1295 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1296 const TensorInfo& cellBias
1297 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1298 const TensorInfo& outputGateBias
1299 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1300
1301 LstmInputParamsInfo paramsInfo;
1302
1303 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1304 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1305 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1306 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1307 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1308 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1309 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1310 paramsInfo.m_CellBias = &cellBias;
1311 paramsInfo.m_OutputGateBias = &outputGateBias;
1312
1313 // Optional parameters
1314 TensorInfo optInputToInputWeights;
1315 TensorInfo optRecurrentToInputWeights;
1316 TensorInfo optCellToInputWeights;
1317 TensorInfo optInputGateBias;
1318 TensorInfo optProjectionWeights;
1319 TensorInfo optProjectionBias;
1320 TensorInfo optCellToForgetWeights;
1321 TensorInfo optCellToOutputWeights;
1322 TensorInfo optInputLayerNormWeights;
1323 TensorInfo optForgetLayerNormWeights;
1324 TensorInfo optCellLayerNormWeights;
1325 TensorInfo optOutputLayerNormWeights;
1326
1327 if(!descriptor.m_CifgEnabled)
1328 {
1329 optInputToInputWeights =
1330 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1331 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1332
1333 optRecurrentToInputWeights =
1334 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1335 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1336 optInputGateBias =
1337 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1338 paramsInfo.m_InputGateBias = &optInputGateBias;
1339 }
1340
1341 if(descriptor.m_ProjectionEnabled)
1342 {
1343 optProjectionWeights =
1344 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1345 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1346 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
1347 {
1348 optProjectionBias =
1349 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1350 paramsInfo.m_ProjectionBias = &optProjectionBias;
1351 }
1352 }
1353
1354 if(descriptor.m_PeepholeEnabled)
1355 {
1356 if(!descriptor.m_CifgEnabled)
1357 {
1358 optCellToInputWeights =
1359 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1360 dataType);
1361 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1362 }
1363 optCellToForgetWeights =
1364 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1365 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1366 optCellToOutputWeights =
1367 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1368 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1369 }
1370
1371 if(descriptor.m_LayerNormEnabled)
1372 {
1373 if (!descriptor.m_CifgEnabled)
1374 {
1375 optInputLayerNormWeights = OverrideDataType(
1376 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1377 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1378 }
1379
1380 optForgetLayerNormWeights = OverrideDataType(
1381 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1382 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1383
1384 optCellLayerNormWeights = OverrideDataType(
1385 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1386 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1387
1388 optOutputLayerNormWeights = OverrideDataType(
1389 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1390 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1391 }
1392
1393 Optional<TensorInfo> hiddenStateOut;
1394 Optional<TensorInfo> cellStateOut;
1395
1396 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1397 outputStateIn,
1398 cellStateIn,
1399 output,
1400 hiddenStateOut,
1401 cellStateOut,
1402 descriptor,
1403 paramsInfo,
1404 reason);
1405 break;
1406 }
telsoa014fcda012018-03-09 14:13:49 +00001407 default:
1408 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001409 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001410 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001411 result = false;
1412 break;
1413 }
1414 }
telsoa014fcda012018-03-09 14:13:49 +00001415 return result;
1416}
1417
// Checks whether a layer configuration is supported, with the target backend supplied
// explicitly by the caller rather than read from the layer. On failure a human-readable
// reason is written to outReasonIfUnsupported. Forwards directly to
// IsLayerConfigurationSupported with a default-constructed ModelOptions.
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
}
1425
David Beckdcb751f2018-10-03 11:42:42 +01001426bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001427 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001428 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001429{
Jan Eilersbb446e52020-04-02 13:56:54 +01001430 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001431 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1432}
1433
1434// TODO merge with defaulted modelOptions above
1435bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1436 Optional<DataType> dataType,
1437 std::string& outReasonIfUnsupported,
1438 const ModelOptions& modelOptions)
1439{
1440 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1441 return IsLayerConfigurationSupported(layer->GetBackendId(),
1442 connectableLayer,
1443 dataType,
1444 outReasonIfUnsupported,
1445 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001446}
1447
// Checks whether a layer configuration is supported by the explicitly supplied backend,
// honouring backend-specific model options. On failure a human-readable reason is
// written to outReasonIfUnsupported. Pure forwarder to IsLayerConfigurationSupported.
bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported,
                                        const ModelOptions& modelOptions)
{
    return IsLayerConfigurationSupported(backendId,
                                         connectableLayer,
                                         dataType,
                                         outReasonIfUnsupported,
                                         modelOptions);
}
1460
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001461// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001462std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1463 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001464{
1465 return std::unique_ptr<IWorkload>();
1466}
1467
Derek Lamberti901ea112019-12-10 22:07:09 +00001468std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1469 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001470{
1471 return std::unique_ptr<IWorkload>();
1472}
1473
Derek Lamberti901ea112019-12-10 22:07:09 +00001474std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1475 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001476{
1477 return std::unique_ptr<IWorkload>();
1478}
1479
Derek Lamberti901ea112019-12-10 22:07:09 +00001480std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1481 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001482{
1483 return std::unique_ptr<IWorkload>();
1484}
1485
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001486std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001487 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001488{
1489 return std::unique_ptr<IWorkload>();
1490}
1491
Derek Lamberti901ea112019-12-10 22:07:09 +00001492std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1493 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001494{
1495 return std::unique_ptr<IWorkload>();
1496}
1497
mathad01b392e982021-04-07 12:07:30 +01001498std::unique_ptr<IWorkload> IWorkloadFactory::CreateCast(const CastQueueDescriptor& /*descriptor*/,
1499 const WorkloadInfo& /*info*/) const
1500{
1501 return std::unique_ptr<IWorkload>();
1502}
1503
Derek Lamberti901ea112019-12-10 22:07:09 +00001504std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1505 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001506{
1507 return std::unique_ptr<IWorkload>();
1508}
1509
Derek Lamberti901ea112019-12-10 22:07:09 +00001510std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1511 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001512{
1513 return std::unique_ptr<IWorkload>();
1514}
1515
Derek Lamberti901ea112019-12-10 22:07:09 +00001516std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1517 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001518{
1519 return std::unique_ptr<IWorkload>();
1520}
1521
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001522std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1523 const WorkloadInfo& /*info*/) const
1524{
1525 return std::unique_ptr<IWorkload>();
1526}
1527
Derek Lamberti901ea112019-12-10 22:07:09 +00001528std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1529 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001530{
1531 return std::unique_ptr<IWorkload>();
1532}
1533
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001534std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1535 const WorkloadInfo& /*info*/) const
1536{
1537 return std::unique_ptr<IWorkload>();
1538}
1539
Derek Lamberti901ea112019-12-10 22:07:09 +00001540std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1541 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001542{
1543 return std::unique_ptr<IWorkload>();
1544}
1545
Derek Lamberti901ea112019-12-10 22:07:09 +00001546std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1547 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001548{
1549 return std::unique_ptr<IWorkload>();
1550}
1551
Derek Lamberti901ea112019-12-10 22:07:09 +00001552std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1553 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001554{
1555 return std::unique_ptr<IWorkload>();
1556}
1557
Derek Lamberti901ea112019-12-10 22:07:09 +00001558std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1559 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001560{
1561 return std::unique_ptr<IWorkload>();
1562}
1563
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001564std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001565 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001566{
1567 return std::unique_ptr<IWorkload>();
1568}
1569
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001570std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001571 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001572{
1573 return std::unique_ptr<IWorkload>();
1574}
1575
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001576std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001577 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001578{
1579 return std::unique_ptr<IWorkload>();
1580}
1581
Derek Lamberti901ea112019-12-10 22:07:09 +00001582std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1583 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001584{
1585 return std::unique_ptr<IWorkload>();
1586}
1587
josh minor4a3c6102020-01-06 16:40:46 -06001588std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1589 const WorkloadInfo& /*info*/) const
1590{
1591 return std::unique_ptr<IWorkload>();
1592}
1593
Derek Lamberti901ea112019-12-10 22:07:09 +00001594std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1595 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001596{
1597 return std::unique_ptr<IWorkload>();
1598}
1599
Derek Lamberti901ea112019-12-10 22:07:09 +00001600std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1601 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001602{
1603 return std::unique_ptr<IWorkload>();
1604}
1605
Ryan OSheaec6c6802020-06-05 17:17:06 +01001606std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1607 const WorkloadInfo& /*info*/) const
1608{
1609 return std::unique_ptr<IWorkload>();
1610}
1611
Derek Lamberti901ea112019-12-10 22:07:09 +00001612std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1613 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001614{
1615 return std::unique_ptr<IWorkload>();
1616}
1617
Derek Lamberti901ea112019-12-10 22:07:09 +00001618std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1619 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001620{
1621 return std::unique_ptr<IWorkload>();
1622}
1623
Derek Lamberti901ea112019-12-10 22:07:09 +00001624std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1625 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001626{
1627 return std::unique_ptr<IWorkload>();
1628}
1629
Derek Lamberti901ea112019-12-10 22:07:09 +00001630std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1631 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001632{
1633 return std::unique_ptr<IWorkload>();
1634}
1635
Kevin Mayce5045a2019-10-02 14:07:47 +01001636std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001637 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1638 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001639{
1640 return std::unique_ptr<IWorkload>();
1641}
1642
Derek Lamberti901ea112019-12-10 22:07:09 +00001643std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1644 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001645{
1646 return std::unique_ptr<IWorkload>();
1647}
1648
James Conroyaba90cd2020-11-06 16:28:18 +00001649std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
1650 const WorkloadInfo& /*info*/) const
1651{
1652 return std::unique_ptr<IWorkload>();
1653}
1654
1655std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1656 const WorkloadInfo& /*info*/) const
1657{
1658 return std::unique_ptr<IWorkload>();
1659}
1660
Derek Lamberti901ea112019-12-10 22:07:09 +00001661std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1662 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001663{
1664 return std::unique_ptr<IWorkload>();
1665}
1666
Derek Lamberti901ea112019-12-10 22:07:09 +00001667std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1668 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001669{
1670 return std::unique_ptr<IWorkload>();
1671}
1672
Derek Lamberti901ea112019-12-10 22:07:09 +00001673std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1674 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001675{
1676 return std::unique_ptr<IWorkload>();
1677}
1678
Derek Lamberti901ea112019-12-10 22:07:09 +00001679std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1680 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001681{
1682 return std::unique_ptr<IWorkload>();
1683}
1684
Derek Lamberti901ea112019-12-10 22:07:09 +00001685std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1686 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001687{
1688 return std::unique_ptr<IWorkload>();
1689}
1690
Derek Lamberti901ea112019-12-10 22:07:09 +00001691std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1692 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001693{
1694 return std::unique_ptr<IWorkload>();
1695}
1696
Derek Lamberti901ea112019-12-10 22:07:09 +00001697std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1698 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001699{
1700 return std::unique_ptr<IWorkload>();
1701}
1702
Derek Lamberti901ea112019-12-10 22:07:09 +00001703std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1704 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001705{
1706 return std::unique_ptr<IWorkload>();
1707}
1708
Derek Lamberti901ea112019-12-10 22:07:09 +00001709std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1710 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001711{
1712 return std::unique_ptr<IWorkload>();
1713}
1714
Derek Lamberti901ea112019-12-10 22:07:09 +00001715std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1716 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001717{
1718 return std::unique_ptr<IWorkload>();
1719}
1720
Derek Lamberti901ea112019-12-10 22:07:09 +00001721std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1722 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001723{
1724 return std::unique_ptr<IWorkload>();
1725}
1726
Derek Lamberti901ea112019-12-10 22:07:09 +00001727std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1728 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001729{
1730 return std::unique_ptr<IWorkload>();
1731}
1732
Derek Lamberti901ea112019-12-10 22:07:09 +00001733std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1734 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001735{
1736 return std::unique_ptr<IWorkload>();
1737}
1738
Derek Lamberti901ea112019-12-10 22:07:09 +00001739std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001740 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001741{
1742 return std::unique_ptr<IWorkload>();
1743}
1744
Derek Lamberti901ea112019-12-10 22:07:09 +00001745std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1746 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001747{
1748 return std::unique_ptr<IWorkload>();
1749}
1750
Derek Lamberti901ea112019-12-10 22:07:09 +00001751std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1752 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001753{
1754 return std::unique_ptr<IWorkload>();
1755}
1756
Derek Lamberti901ea112019-12-10 22:07:09 +00001757std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1758 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001759{
1760 return std::unique_ptr<IWorkload>();
1761}
1762
Derek Lamberti901ea112019-12-10 22:07:09 +00001763std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1764 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001765{
1766 return std::unique_ptr<IWorkload>();
1767}
1768
James Conroy586a9aa2020-03-20 08:49:33 +00001769std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1770 const WorkloadInfo& /*info*/) const
1771{
1772 return std::unique_ptr<IWorkload>();
1773}
1774
Derek Lamberti901ea112019-12-10 22:07:09 +00001775std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1776 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001777{
1778 return std::unique_ptr<IWorkload>();
1779}
Finn Williams2605b232020-06-10 15:53:46 +01001780std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
1781 const WorkloadInfo& /*info*/) const
1782{
1783 return std::unique_ptr<IWorkload>();
1784}
James Conroyee18dc82019-07-17 11:27:46 +01001785
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001786std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
1787 const WorkloadInfo& /*info*/) const
1788{
1789 return std::unique_ptr<IWorkload>();
1790}
1791
Derek Lamberti901ea112019-12-10 22:07:09 +00001792std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1793 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001794{
1795 return std::unique_ptr<IWorkload>();
1796}
1797
Derek Lamberti901ea112019-12-10 22:07:09 +00001798std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1799 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001800{
1801 return std::unique_ptr<IWorkload>();
1802}
1803
Derek Lamberti901ea112019-12-10 22:07:09 +00001804std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1805 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001806{
1807 return std::unique_ptr<IWorkload>();
1808}
1809
Derek Lamberti901ea112019-12-10 22:07:09 +00001810std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1811 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001812{
1813 return std::unique_ptr<IWorkload>();
1814}
1815
Keith Davis3ae3f972021-05-21 16:33:48 +01001816std::unique_ptr<IWorkload> IWorkloadFactory::CreateShape(const ShapeQueueDescriptor& /*descriptor*/,
1817 const WorkloadInfo& /*info*/) const
1818{
1819 return std::unique_ptr<IWorkload>();
1820}
1821
Derek Lamberti901ea112019-12-10 22:07:09 +00001822std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1823 const WorkloadInfo& /*info*/) const
1824{
1825 return std::unique_ptr<IWorkload>();
1826}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001827
Derek Lamberti901ea112019-12-10 22:07:09 +00001828std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1829 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001830{
1831 return std::unique_ptr<IWorkload>();
1832}
1833
Derek Lamberti901ea112019-12-10 22:07:09 +00001834std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1835 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001836{
1837 return std::unique_ptr<IWorkload>();
1838}
1839
Derek Lamberti901ea112019-12-10 22:07:09 +00001840std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1841 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001842{
1843 return std::unique_ptr<IWorkload>();
1844}
1845
Derek Lamberti901ea112019-12-10 22:07:09 +00001846std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1847 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001848{
1849 return std::unique_ptr<IWorkload>();
1850}
1851
Derek Lamberti901ea112019-12-10 22:07:09 +00001852std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1853 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001854{
1855 return std::unique_ptr<IWorkload>();
1856}
1857
Derek Lamberti901ea112019-12-10 22:07:09 +00001858std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1859 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001860{
1861 return std::unique_ptr<IWorkload>();
1862}
1863
Derek Lamberti901ea112019-12-10 22:07:09 +00001864std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1865 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001866{
1867 return std::unique_ptr<IWorkload>();
1868}
1869
Derek Lamberti901ea112019-12-10 22:07:09 +00001870std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1871 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001872{
1873 return std::unique_ptr<IWorkload>();
1874}
1875
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001876std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1877 const WorkloadInfo& /*info*/) const
1878{
1879 return std::unique_ptr<IWorkload>();
1880}
1881
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001882std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001883 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1884 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001885{
1886 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001887}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001888
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001889std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
1890 const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
1891 const WorkloadInfo& /*info*/) const
1892{
1893 return std::unique_ptr<IWorkload>();
1894}
1895
} // namespace armnn