blob: 5e3eed086ac52a6b019b5d4d7e482c62ecae2e23 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000011#include <armnn/ILayerSupport.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000012#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010013#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010014#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <backendsCommon/WorkloadFactory.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000017#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000018
Francis Murtagh46c09d02019-05-28 08:15:28 +010019#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
David Beck111b5d92018-11-12 14:59:37 +000021#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000022
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
telsoa01c577f2c2018-08-31 09:22:23 +010026namespace
27{
Finn Williams3e54d032020-10-22 16:53:35 +010028using LayerList = std::list<Layer*>;
29using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010030
David Beck29c75de2018-10-23 13:35:58 +010031const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
32{
33 if (!type)
34 {
35 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010036 }
37
David Beck29c75de2018-10-23 13:35:58 +010038 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010039}
40
David Beck29c75de2018-10-23 13:35:58 +010041} // anonymous namespace
42
Sadik Armagan045f6be2020-09-10 13:37:32 +010043bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
44 const IConnectableLayer& connectableLayer,
45 Optional<DataType> dataType,
46 std::string& outReasonIfUnsupported,
47 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000048{
David Beck33f0ae02018-10-18 15:13:56 +010049 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000050 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010051 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010052
David Beck111b5d92018-11-12 14:59:37 +000053 auto const& backendRegistry = BackendRegistryInstance();
54 if (!backendRegistry.IsBackendRegistered(backendId))
55 {
56 std::stringstream ss;
57 ss << connectableLayer.GetName() << " is not supported on " << backendId
58 << " because this backend is not registered.";
59
60 outReasonIfUnsupported = ss.str();
61 return false;
62 }
63
64 auto backendFactory = backendRegistry.GetFactory(backendId);
65 auto backendObject = backendFactory();
Sadik Armagan045f6be2020-09-10 13:37:32 +010066 auto layerSupportObject = backendObject->GetLayerSupport(modelOptions);
David Beck33f0ae02018-10-18 15:13:56 +010067
telsoa014fcda012018-03-09 14:13:49 +000068 switch(layer.GetType())
69 {
70 case LayerType::Activation:
71 {
Jan Eilersbb446e52020-04-02 13:56:54 +010072 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000073 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010074 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010075 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010076 OverrideDataType(input, dataType),
77 OverrideDataType(output, dataType),
78 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010079 reason);
telsoa014fcda012018-03-09 14:13:49 +000080 break;
81 }
82 case LayerType::Addition:
83 {
84 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
85 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
86 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010087 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010088 OverrideDataType(input0, dataType),
89 OverrideDataType(input1, dataType),
90 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010091 reason);
telsoa014fcda012018-03-09 14:13:49 +000092 break;
93 }
Nikhil Rajee391d52019-09-05 17:50:44 +010094 case LayerType::ArgMinMax:
95 {
Jan Eilersbb446e52020-04-02 13:56:54 +010096 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +010097 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
98
99 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
100 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
101 result = layerSupportObject->IsArgMinMaxSupported(
102 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000103 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100104 descriptor,
105 reason);
106 break;
107 }
telsoa014fcda012018-03-09 14:13:49 +0000108 case LayerType::BatchNormalization:
109 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100110 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000111 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100112 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
113 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
114 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
115 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
116 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100117 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100118 OverrideDataType(input, dataType),
119 OverrideDataType(output, dataType),
120 OverrideDataType(mean, dataType),
121 OverrideDataType(var, dataType),
122 OverrideDataType(beta, dataType),
123 OverrideDataType(gamma, dataType),
124 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100125 reason);
telsoa014fcda012018-03-09 14:13:49 +0000126 break;
127 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000128 case LayerType::BatchToSpaceNd:
129 {
130 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
131 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100132 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000133
134 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
135 OverrideDataType(output, dataType),
136 cLayer->GetParameters(),
137 reason);
138 break;
139 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100140 case LayerType::Comparison:
141 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100142 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100143
144 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
145 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
146 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
147
148 result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
149 OverrideDataType(input1, dataType),
150 OverrideDataType(output, DataType::Boolean),
151 cLayer->GetParameters(),
152 reason);
153 break;
154 }
telsoa014fcda012018-03-09 14:13:49 +0000155 case LayerType::Constant:
156 {
157 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100158 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100159 break;
160 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000161 case LayerType::ConvertBf16ToFp32:
162 {
163 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
164 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
165 result = layerSupportObject->IsConvertBf16ToFp32Supported(input, output, reason);
166 break;
167 }
telsoa01c577f2c2018-08-31 09:22:23 +0100168 case LayerType::ConvertFp16ToFp32:
169 {
170 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
171 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100172 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100173 break;
174 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000175 case LayerType::ConvertFp32ToBf16:
176 {
177 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
178 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
179 result = layerSupportObject->IsConvertFp32ToBf16Supported(input, output, reason);
180 break;
181 }
telsoa01c577f2c2018-08-31 09:22:23 +0100182 case LayerType::ConvertFp32ToFp16:
183 {
184 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
185 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100186 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000187 break;
188 }
189 case LayerType::Convolution2d:
190 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100191 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100192
193 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
194 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100195 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100196 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100197
arovir01a6824102018-08-28 17:40:45 +0100198 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100199
arovir01a6824102018-08-28 17:40:45 +0100200 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100201 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100202 if (descriptor.m_BiasEnabled)
203 {
David Beck5eec11d2018-10-04 15:43:17 +0100204 biases =
205 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100206 }
207
David Beck33f0ae02018-10-18 15:13:56 +0100208 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100209 input,
210 output,
211 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100212 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100213 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100214 reason);
telsoa014fcda012018-03-09 14:13:49 +0000215 break;
216 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000217 case LayerType::Debug:
218 {
219 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
220 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
221
222 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
223 OverrideDataType(output, dataType),
224 reason);
225 break;
226 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100227 case LayerType::DepthToSpace:
228 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100229 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100230
231 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
232 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
233
234 result = layerSupportObject->IsDepthToSpaceSupported(OverrideDataType(input, dataType),
235 OverrideDataType(output, dataType),
236 cLayer->GetParameters(),
237 reason);
238 break;
239 }
telsoa014fcda012018-03-09 14:13:49 +0000240 case LayerType::DepthwiseConvolution2d:
241 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100242 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100243 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
244 dataType);
245 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100246 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100247
telsoa01c577f2c2018-08-31 09:22:23 +0100248 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100249
250 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100251 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100252 if (descriptor.m_BiasEnabled)
253 {
David Beck5eec11d2018-10-04 15:43:17 +0100254 biases =
255 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100256 }
telsoa01c577f2c2018-08-31 09:22:23 +0100257
David Beck33f0ae02018-10-18 15:13:56 +0100258 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100259 input,
260 output,
261 descriptor,
262 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100263 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100264 reason);
telsoa014fcda012018-03-09 14:13:49 +0000265 break;
266 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000267 case LayerType::Dequantize:
268 {
269 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
270 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
271
Aron Virginas-Tar87972be2019-11-13 15:16:28 +0000272 result = layerSupportObject->IsDequantizeSupported(input,
273 OverrideDataType(output, dataType),
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000274 reason);
275 break;
276 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000277 case LayerType::DetectionPostProcess:
278 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100279 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000280 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
281 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
282 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
283
284 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
285 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
286 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
287 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
288
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000289 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000290 result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings,
291 scores,
292 anchors,
293 detectionBoxes,
294 detectionClasses,
295 detectionScores,
296 numDetections,
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000297 descriptor,
298 reason);
299 break;
300 }
josh minor4a3c6102020-01-06 16:40:46 -0600301 case LayerType::ElementwiseUnary:
302 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100303 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600304
305 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
306 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
307
308 result = layerSupportObject->IsElementwiseUnarySupported(OverrideDataType(input, dataType),
309 OverrideDataType(output, dataType),
310 cLayer->GetParameters(),
311 reason);
312 break;
313 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100314 case LayerType::Fill:
315 {
316 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
317 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
318 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
319 const FillDescriptor& descriptor = cLayer->GetParameters();
320
321 result = layerSupportObject->IsFillSupported(
322 OverrideDataType(input, dataType),
323 OverrideDataType(output, dataType),
324 descriptor,
325 reason);
326 break;
327 }
telsoa014fcda012018-03-09 14:13:49 +0000328 case LayerType::FakeQuantization:
329 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100330 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000331 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100332 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
333 cLayer->GetParameters(),
334 reason);
telsoa014fcda012018-03-09 14:13:49 +0000335 break;
336 }
337 case LayerType::Floor:
338 {
339 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
340 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100341 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
342 OverrideDataType(output, dataType),
343 reason);
telsoa014fcda012018-03-09 14:13:49 +0000344 break;
345 }
346 case LayerType::FullyConnected:
347 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100348 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000349 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100350 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100351 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100352
353 TensorInfo biasInfo;
354 const TensorInfo * biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000355 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100356 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
357 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
358 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
359
360 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
361 if (descriptor.m_BiasEnabled)
362 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100363 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100364 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
365 biasInfoPtr = &biasInfo;
366 }
367 else
368 {
369 // If biases are not enabled pass a dummy tensorinfo for the validation
370 switch(input.GetDataType())
371 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000372 case DataType::BFloat16:
373 {
374 biasInfoPtr = &dummyBFloat16Bias;
375 break;
376 }
telsoa01c577f2c2018-08-31 09:22:23 +0100377 case DataType::Float16:
378 {
379 biasInfoPtr = &dummyFloat16Bias;
380 break;
381 }
382 case DataType::Float32:
383 {
384 biasInfoPtr = &dummyFloat32Bias;
385 break;
386 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000387 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000388 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000389 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000390 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100391 {
392 biasInfoPtr = &dummyQA8Bias;
393 break;
394 }
395 default:
396 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100397 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100398 }
399 }
400 }
401
David Beck33f0ae02018-10-18 15:13:56 +0100402 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100403 OverrideDataType(input, dataType),
404 OverrideDataType(output, dataType),
405 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
406 *biasInfoPtr,
407 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100408 reason);
telsoa014fcda012018-03-09 14:13:49 +0000409 break;
410 }
narpra01b89b05f2019-01-16 09:53:09 +0000411 case LayerType::Gather:
412 {
413 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
414 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
415 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100416 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
417 const GatherDescriptor& descriptor = cLayer->GetParameters();
narpra01b89b05f2019-01-16 09:53:09 +0000418 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100419 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000420 OverrideDataType(output, dataType),
Teresa Charlin52664732020-06-29 16:27:03 +0100421 descriptor,
narpra01b89b05f2019-01-16 09:53:09 +0000422 reason);
423 break;
424 }
telsoa014fcda012018-03-09 14:13:49 +0000425 case LayerType::Input:
426 {
427 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100428 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000429 break;
430 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100431 case LayerType::InstanceNormalization:
432 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100433 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100434 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
435
436 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
437 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
438
439 result = layerSupportObject->IsInstanceNormalizationSupported(
440 OverrideDataType(input, dataType),
441 OverrideDataType(output, dataType),
442 descriptor,
443 reason);
444 break;
445 }
telsoa014fcda012018-03-09 14:13:49 +0000446 case LayerType::L2Normalization:
447 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100448 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100449 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
450
telsoa014fcda012018-03-09 14:13:49 +0000451 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100452 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100453
David Beck33f0ae02018-10-18 15:13:56 +0100454 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100455 OverrideDataType(input, dataType),
456 OverrideDataType(output, dataType),
457 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100458 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100459 break;
460 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100461 case LayerType::LogSoftmax:
462 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100463 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100464
465 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
466 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
467
468 result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType),
469 OverrideDataType(output, dataType),
470 cLayer->GetParameters(),
471 reason);
472 break;
473 }
telsoa01c577f2c2018-08-31 09:22:23 +0100474 case LayerType::Lstm:
475 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100476 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100477 const LstmDescriptor& descriptor = cLayer->GetParameters();
478
479 // All inputs.
480 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
481 dataType);
482 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
483 dataType);
484 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
485 dataType);
486 // All outputs
487 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
488 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
489 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
490 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
491
492 // Basic parameters
493 const TensorInfo& inputToForgetWeights
494 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
495 const TensorInfo& inputToCellWeights
496 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
497 const TensorInfo& inputToOutputWeights
498 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
499 const TensorInfo& recurrentToForgetWeights
500 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
501 const TensorInfo& recurrentToCellWeights
502 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
503 const TensorInfo& recurrentToOutputWeights
504 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
505 const TensorInfo& forgetGateBias
506 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
507 const TensorInfo& cellBias
508 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
509 const TensorInfo& outputGateBias
510 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
511
Jan Eilersd01a83c2019-07-03 18:20:40 +0100512 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100513
Jan Eilersd01a83c2019-07-03 18:20:40 +0100514 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
515 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
516 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
517 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
518 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
519 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
520 paramsInfo.m_ForgetGateBias = &forgetGateBias;
521 paramsInfo.m_CellBias = &cellBias;
522 paramsInfo.m_OutputGateBias = &outputGateBias;
523
524
525 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100526 TensorInfo optInputToInputWeights;
527 TensorInfo optRecurrentToInputWeights;
528 TensorInfo optCellToInputWeights;
529 TensorInfo optInputGateBias;
530 TensorInfo optProjectionWeights;
531 TensorInfo optProjectionBias;
532 TensorInfo optCellToForgetWeights;
533 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100534 TensorInfo optInputLayerNormWeights;
535 TensorInfo optForgetLayerNormWeights;
536 TensorInfo optCellLayerNormWeights;
537 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100538
539 if(!descriptor.m_CifgEnabled)
540 {
541 optInputToInputWeights =
542 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100543 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100544
545 optRecurrentToInputWeights =
546 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100547 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100548 optInputGateBias =
549 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100550 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100551 }
552
553 if(descriptor.m_ProjectionEnabled)
554 {
555 optProjectionWeights =
556 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100557 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100558 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
559 {
560 optProjectionBias =
561 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100562 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100563 }
564 }
565
566 if(descriptor.m_PeepholeEnabled)
567 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100568 if(!descriptor.m_CifgEnabled)
569 {
570 optCellToInputWeights =
571 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
572 dataType);
573 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
574 }
telsoa01c577f2c2018-08-31 09:22:23 +0100575 optCellToForgetWeights =
576 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100577 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100578 optCellToOutputWeights =
579 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100580 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100581 }
582
Jan Eilers38e05bd2019-06-26 13:10:09 +0100583 if(descriptor.m_LayerNormEnabled)
584 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100585 if (!descriptor.m_CifgEnabled)
586 {
587 optInputLayerNormWeights = OverrideDataType(
588 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
589 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
590 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100591
592 optForgetLayerNormWeights = OverrideDataType(
593 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100594 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100595
596 optCellLayerNormWeights = OverrideDataType(
597 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100598 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100599
600 optOutputLayerNormWeights = OverrideDataType(
601 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100602 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100603 }
604
David Beck33f0ae02018-10-18 15:13:56 +0100605 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100606 input,
607 outputStateIn,
608 cellStateIn,
609 scratchBuffer,
610 outputStateOut,
611 cellStateOut,
612 output,
613 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100614 paramsInfo,
615 reason);
telsoa014fcda012018-03-09 14:13:49 +0000616 break;
617 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000618 case LayerType::Maximum:
619 {
620 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
621 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
622 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
623
624 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
625 OverrideDataType(input1, dataType),
626 OverrideDataType(output, dataType),
627 reason);
628 break;
629 }
narpra01b89b05f2019-01-16 09:53:09 +0000630 case LayerType::MemCopy:
631 {
632 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
633 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000634
narpra01b89b05f2019-01-16 09:53:09 +0000635 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
636 OverrideDataType(output, dataType),
637 reason);
638 break;
639 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100640 case LayerType::MemImport:
641 {
642 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
643 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
644
645 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
646 OverrideDataType(output, dataType),
647 reason);
648 break;
649 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100650 case LayerType::Merge:
651 {
652 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
653 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
654 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
655
656 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
657 OverrideDataType(input1, dataType),
658 OverrideDataType(output, dataType),
659 reason);
660 break;
661 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100662 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000663 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100664 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000665
telsoa01c577f2c2018-08-31 09:22:23 +0100666 // Get vector of all inputs.
667 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000668 {
telsoa01c577f2c2018-08-31 09:22:23 +0100669 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000670 };
Finn Williams3e54d032020-10-22 16:53:35 +0100671
672 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
673 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100674 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000675
telsoa01c577f2c2018-08-31 09:22:23 +0100676 auto getTensorInfoPtr = [](const TensorInfo& info)
677 {
678 return &info;
679 };
Finn Williams3e54d032020-10-22 16:53:35 +0100680
681 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
682 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100683 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000684
Nikhil Raj8599a412018-11-19 14:51:07 +0000685 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
686
Jim Flynne242f2d2019-05-22 14:24:13 +0100687 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
688
689
telsoa014fcda012018-03-09 14:13:49 +0000690 break;
691 }
692 case LayerType::Multiplication:
693 {
694 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
695 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100696 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100697 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100698 OverrideDataType(input0, dataType),
699 OverrideDataType(input1, dataType),
700 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100701 reason);
telsoa014fcda012018-03-09 14:13:49 +0000702 break;
703 }
704 case LayerType::Normalization:
705 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100706 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000707 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
708 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100709 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
710 OverrideDataType(output, dataType),
711 cLayer->GetParameters(),
712 reason);
telsoa014fcda012018-03-09 14:13:49 +0000713 break;
714 }
715 case LayerType::Output:
716 {
717 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100718 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000719 break;
720 }
721 case LayerType::Permute:
722 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100723 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000724 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
725 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100726 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
727 OverrideDataType(output, dataType),
728 cLayer->GetParameters(),
729 reason);
telsoa014fcda012018-03-09 14:13:49 +0000730 break;
731 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100732 case LayerType::Pad:
733 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100734 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100735 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
736 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100737 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100738 OverrideDataType(input, dataType),
739 OverrideDataType(output, dataType),
740 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100741 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100742 break;
743 }
telsoa014fcda012018-03-09 14:13:49 +0000744 case LayerType::Pooling2d:
745 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100746 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000747 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
748 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100749 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
750 OverrideDataType(output, dataType),
751 cLayer->GetParameters(),
752 reason);
telsoa014fcda012018-03-09 14:13:49 +0000753 break;
754 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000755 case LayerType::PreCompiled:
756 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100757 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000758 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
759 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
760 cLayer->GetParameters(),
761 reason);
762 break;
763 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000764 case LayerType::Quantize:
765 {
766 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
767 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
768 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
769 break;
770 }
James Conroy586a9aa2020-03-20 08:49:33 +0000771 case LayerType::QLstm:
772 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100773 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000774 const QLstmDescriptor& descriptor = cLayer->GetParameters();
775
776 // Inputs
777 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
778 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
779 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
780
781 // Outputs
782 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
783 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
784 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
785
786 // Lstm parameters
787 LstmInputParamsInfo paramsInfo;
788
789 // Basic parameters
790 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
791 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
792 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
793
794 paramsInfo.m_RecurrentToForgetWeights =
795 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
796 paramsInfo.m_RecurrentToCellWeights =
797 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
798 paramsInfo.m_RecurrentToOutputWeights =
799 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
800
801 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
802 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
803 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
804
805 if(!descriptor.m_CifgEnabled)
806 {
807 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
808 paramsInfo.m_RecurrentToInputWeights =
809 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
810 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
811 }
812
813 if(descriptor.m_ProjectionEnabled)
814 {
815 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +0100816
817 // Projection bias is optional even if projection is enabled
818 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
819 {
820 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
821 }
James Conroy586a9aa2020-03-20 08:49:33 +0000822 }
823
824 if(descriptor.m_PeepholeEnabled)
825 {
826 if (!descriptor.m_CifgEnabled)
827 {
828 paramsInfo.m_CellToInputWeights =
829 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
830 }
831
832 paramsInfo.m_CellToForgetWeights =
833 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
834 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
835 }
836
837 if(descriptor.m_LayerNormEnabled)
838 {
839 if (!descriptor.m_CifgEnabled)
840 {
841 paramsInfo.m_InputLayerNormWeights =
842 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
843 }
844
845 paramsInfo.m_ForgetLayerNormWeights =
846 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
847 paramsInfo.m_CellLayerNormWeights =
848 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
849 paramsInfo.m_OutputLayerNormWeights =
850 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
851 }
852
853 result = layerSupportObject->IsQLstmSupported(input,
854 previousOutputIn,
855 previousCellStateIn,
856 outputStateOut,
857 cellStateOut,
858 output,
859 descriptor,
860 paramsInfo,
861 reason);
862 break;
863 }
James Conroyee18dc82019-07-17 11:27:46 +0100864 case LayerType::QuantizedLstm:
865 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100866 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100867
868 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100869 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
870 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
871 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100872
873 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100874 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
875 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100876
877 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100878 QuantizedLstmInputParamsInfo paramsInfo;
879
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100880 paramsInfo.m_InputToInputWeights =
881 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
882 paramsInfo.m_InputToForgetWeights =
883 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
884 paramsInfo.m_InputToCellWeights =
885 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
886 paramsInfo.m_InputToOutputWeights =
887 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100888
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100889 paramsInfo.m_RecurrentToInputWeights =
890 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
891 paramsInfo.m_RecurrentToForgetWeights =
892 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
893 paramsInfo.m_RecurrentToCellWeights =
894 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
895 paramsInfo.m_RecurrentToOutputWeights =
896 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100897
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100898 paramsInfo.m_InputGateBias =
899 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
900 paramsInfo.m_ForgetGateBias =
901 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
902 paramsInfo.m_CellBias =
903 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
904 paramsInfo.m_OutputGateBias =
905 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100906
907 result = layerSupportObject->IsQuantizedLstmSupported(input,
908 previousCellStateIn,
909 previousOutputIn,
910 cellStateOut,
911 output,
912 paramsInfo,
913 reason);
914 break;
915 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100916 case LayerType::Division:
917 {
918 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
919 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
920 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100921 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100922 OverrideDataType(input0, dataType),
923 OverrideDataType(input1, dataType),
924 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100925 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100926 break;
927 }
Finn Williams2605b232020-06-10 15:53:46 +0100928 case LayerType::Rank:
929 {
930 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
931 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
932 result = layerSupportObject->IsRankSupported(OverrideDataType(input, dataType),
933 OverrideDataType(output, dataType),
934 reason);
935 break;
936 }
telsoa014fcda012018-03-09 14:13:49 +0000937 case LayerType::Reshape:
938 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100939 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000940 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +0000941 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000942 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
Kevin Maya023c402019-12-12 17:28:05 +0000943 OverrideDataType(output, dataType),
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000944 cLayer->GetParameters(),
945 reason);
telsoa014fcda012018-03-09 14:13:49 +0000946 break;
947 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100948 case LayerType::Resize:
949 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100950 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100951 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100952 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
953 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
954 OverrideDataType(output, dataType),
955 cLayer->GetParameters(),
956 reason);
957 break;
958 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100959 case LayerType::Slice:
960 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100961 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100962
963 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
964 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
965
966 result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
967 OverrideDataType(output, dataType),
968 cLayer->GetParameters(),
969 reason);
970 break;
971 }
telsoa014fcda012018-03-09 14:13:49 +0000972 case LayerType::Softmax:
973 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100974 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000975 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100976 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100977 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
978 OverrideDataType(output, dataType),
979 cLayer->GetParameters(),
980 reason);
telsoa014fcda012018-03-09 14:13:49 +0000981 break;
982 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000983 case LayerType::SpaceToBatchNd:
984 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100985 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000986 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
987 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
988 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
989 OverrideDataType(output, dataType),
990 cLayer->GetParameters(),
991 reason);
992 break;
993 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100994 case LayerType::SpaceToDepth:
995 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100996 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100997
998 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
999 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1000
1001 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1002 OverrideDataType(output, dataType),
1003 cLayer->GetParameters(),
1004 reason);
1005 break;
1006 }
telsoa014fcda012018-03-09 14:13:49 +00001007 case LayerType::Splitter:
1008 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001009 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001010 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001011
1012 // Get vector of all outputs.
1013 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1014 {
1015 return OverrideDataType(slot.GetTensorInfo(), dataType);
1016 };
Finn Williams3e54d032020-10-22 16:53:35 +01001017 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1018 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001019 std::vector<TensorInfo> outputs(beginI, endI);
1020
1021 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1022
David Beck33f0ae02018-10-18 15:13:56 +01001023 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001024 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +01001025 cLayer->GetParameters(),
1026 reason);
telsoa014fcda012018-03-09 14:13:49 +00001027 break;
1028 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001029 case LayerType::Stack:
1030 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001031 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001032
1033 // Get vector of all inputs.
1034 auto getTensorInfo = [&dataType](const InputSlot& slot)
1035 {
1036 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1037 };
Finn Williams3e54d032020-10-22 16:53:35 +01001038 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1039 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001040 std::vector<TensorInfo> inputs(beginI, endI);
1041
1042 auto getTensorInfoPtr = [](const TensorInfo& info)
1043 {
1044 return &info;
1045 };
Finn Williams3e54d032020-10-22 16:53:35 +01001046 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1047 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001048 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1049
1050 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1051
1052 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
1053
1054 break;
1055 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001056 case LayerType::StandIn:
1057 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001058 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001059
1060 // Get vector of all inputs.
1061 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1062 {
1063 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1064 };
1065 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1066 {
1067 return OverrideDataType(slot.GetTensorInfo(), dataType);
1068 };
Finn Williams3e54d032020-10-22 16:53:35 +01001069 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1070 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001071 std::vector<TensorInfo> inputs(beginI, endI);
1072
Finn Williams3e54d032020-10-22 16:53:35 +01001073 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1074 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001075 std::vector<TensorInfo> outputs(beginO, endO);
1076
1077
1078 auto getTensorInfoPtr = [](const TensorInfo& info)
1079 {
1080 return &info;
1081 };
Finn Williams3e54d032020-10-22 16:53:35 +01001082 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1083 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001084 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1085
Finn Williams3e54d032020-10-22 16:53:35 +01001086 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1087 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001088 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1089
1090
1091 result = layerSupportObject->IsStandInSupported(inputPtrs,
1092 outputPtrs,
1093 cLayer->GetParameters(),
1094 reason);
1095 break;
1096 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001097 case LayerType::StridedSlice:
1098 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001099 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1102 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
1103 OverrideDataType(output, dataType),
1104 cLayer->GetParameters(),
1105 reason);
1106 break;
1107 }
David Beckc2044fe2018-09-05 15:00:38 +01001108 case LayerType::Subtraction:
1109 {
1110 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1111 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1112 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +01001113 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001114 OverrideDataType(input0, dataType),
1115 OverrideDataType(input1, dataType),
1116 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001117 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001118 break;
1119 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001120 case LayerType::Switch:
1121 {
1122 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1123 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1124 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1125 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
1126 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
1127 OverrideDataType(input1, dataType),
1128 OverrideDataType(output0, dataType),
1129 OverrideDataType(output1, dataType),
1130 reason);
1131 break;
1132 }
narpra0132b90462018-09-13 11:07:48 +01001133 case LayerType::Mean:
1134 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001135 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001136 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1137 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +01001138 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001139 OverrideDataType(input, dataType),
1140 OverrideDataType(output, dataType),
1141 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001142 reason);
narpra0132b90462018-09-13 11:07:48 +01001143 break;
1144 }
kevmay0190539692018-11-29 08:40:19 +00001145 case LayerType::Minimum:
1146 {
1147 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1148 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1149 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1150 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
1151 OverrideDataType(input1, dataType),
1152 OverrideDataType(output, dataType),
1153 reason);
1154 break;
1155 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001156 case LayerType::Prelu:
1157 {
1158 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1159 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1160 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1161 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
1162 OverrideDataType(alpha, dataType),
1163 OverrideDataType(output, dataType),
1164 reason);
1165 break;
1166 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001167 case LayerType::Transpose:
1168 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001169 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001170 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1171 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1172 result = layerSupportObject->IsTransposeSupported(OverrideDataType(input, dataType),
1173 OverrideDataType(output, dataType),
1174 cLayer->GetParameters(),
1175 reason);
1176 break;
1177 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001178 case LayerType::TransposeConvolution2d:
1179 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001180 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001181
1182 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1183 dataType);
1184 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1185
1186 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1187
1188 Optional<TensorInfo> biases;
1189 if (descriptor.m_BiasEnabled)
1190 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001191 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001192 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1193 GetBiasTypeFromWeightsType(dataType));
1194 }
1195
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001196 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001197 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1198
1199 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
1200 output,
1201 descriptor,
1202 weights,
1203 biases,
1204 reason);
1205
1206 break;
1207 }
telsoa014fcda012018-03-09 14:13:49 +00001208 default:
1209 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001210 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001211 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001212 result = false;
1213 break;
1214 }
1215 }
telsoa014fcda012018-03-09 14:13:49 +00001216 return result;
1217}
1218
Sadik Armagan045f6be2020-09-10 13:37:32 +01001219bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1220 const IConnectableLayer& connectableLayer,
1221 Optional<DataType> dataType,
1222 std::string& outReasonIfUnsupported)
1223{
1224 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1225}
1226
David Beckdcb751f2018-10-03 11:42:42 +01001227bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001228 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001229 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001230{
Jan Eilersbb446e52020-04-02 13:56:54 +01001231 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001232 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1233}
1234
1235// TODO merge with defaulted modelOptions above
1236bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1237 Optional<DataType> dataType,
1238 std::string& outReasonIfUnsupported,
1239 const ModelOptions& modelOptions)
1240{
1241 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1242 return IsLayerConfigurationSupported(layer->GetBackendId(),
1243 connectableLayer,
1244 dataType,
1245 outReasonIfUnsupported,
1246 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001247}
1248
Sadik Armagan04a72972020-09-14 15:44:18 +01001249bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1250 const IConnectableLayer& connectableLayer,
1251 Optional<DataType> dataType,
1252 std::string& outReasonIfUnsupported,
1253 const ModelOptions& modelOptions)
1254{
1255 return IsLayerConfigurationSupported(backendId,
1256 connectableLayer,
1257 dataType,
1258 outReasonIfUnsupported,
1259 modelOptions);
1260}
1261
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001262// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001263std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1264 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001265{
1266 return std::unique_ptr<IWorkload>();
1267}
1268
Derek Lamberti901ea112019-12-10 22:07:09 +00001269std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1270 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001271{
1272 return std::unique_ptr<IWorkload>();
1273}
1274
Derek Lamberti901ea112019-12-10 22:07:09 +00001275std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1276 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001277{
1278 return std::unique_ptr<IWorkload>();
1279}
1280
Derek Lamberti901ea112019-12-10 22:07:09 +00001281std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1282 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001283{
1284 return std::unique_ptr<IWorkload>();
1285}
1286
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001287std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001288 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001289{
1290 return std::unique_ptr<IWorkload>();
1291}
1292
Derek Lamberti901ea112019-12-10 22:07:09 +00001293std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1294 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001295{
1296 return std::unique_ptr<IWorkload>();
1297}
1298
Derek Lamberti901ea112019-12-10 22:07:09 +00001299std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1300 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001301{
1302 return std::unique_ptr<IWorkload>();
1303}
1304
Derek Lamberti901ea112019-12-10 22:07:09 +00001305std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1306 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001307{
1308 return std::unique_ptr<IWorkload>();
1309}
1310
Derek Lamberti901ea112019-12-10 22:07:09 +00001311std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1312 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001313{
1314 return std::unique_ptr<IWorkload>();
1315}
1316
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001317std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1318 const WorkloadInfo& /*info*/) const
1319{
1320 return std::unique_ptr<IWorkload>();
1321}
1322
Derek Lamberti901ea112019-12-10 22:07:09 +00001323std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1324 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001325{
1326 return std::unique_ptr<IWorkload>();
1327}
1328
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001329std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1330 const WorkloadInfo& /*info*/) const
1331{
1332 return std::unique_ptr<IWorkload>();
1333}
1334
Derek Lamberti901ea112019-12-10 22:07:09 +00001335std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1336 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001337{
1338 return std::unique_ptr<IWorkload>();
1339}
1340
Derek Lamberti901ea112019-12-10 22:07:09 +00001341std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1342 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001343{
1344 return std::unique_ptr<IWorkload>();
1345}
1346
Derek Lamberti901ea112019-12-10 22:07:09 +00001347std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1348 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001349{
1350 return std::unique_ptr<IWorkload>();
1351}
1352
Derek Lamberti901ea112019-12-10 22:07:09 +00001353std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1354 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001355{
1356 return std::unique_ptr<IWorkload>();
1357}
1358
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001359std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001360 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001361{
1362 return std::unique_ptr<IWorkload>();
1363}
1364
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001365std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001366 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001367{
1368 return std::unique_ptr<IWorkload>();
1369}
1370
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001371std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001372 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001373{
1374 return std::unique_ptr<IWorkload>();
1375}
1376
Derek Lamberti901ea112019-12-10 22:07:09 +00001377std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1378 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001379{
1380 return std::unique_ptr<IWorkload>();
1381}
1382
josh minor4a3c6102020-01-06 16:40:46 -06001383std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1384 const WorkloadInfo& /*info*/) const
1385{
1386 return std::unique_ptr<IWorkload>();
1387}
1388
Derek Lamberti901ea112019-12-10 22:07:09 +00001389std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1390 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001391{
1392 return std::unique_ptr<IWorkload>();
1393}
1394
Derek Lamberti901ea112019-12-10 22:07:09 +00001395std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1396 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001397{
1398 return std::unique_ptr<IWorkload>();
1399}
1400
Ryan OSheaec6c6802020-06-05 17:17:06 +01001401std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1402 const WorkloadInfo& /*info*/) const
1403{
1404 return std::unique_ptr<IWorkload>();
1405}
1406
Derek Lamberti901ea112019-12-10 22:07:09 +00001407std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1408 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001409{
1410 return std::unique_ptr<IWorkload>();
1411}
1412
Derek Lamberti901ea112019-12-10 22:07:09 +00001413std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1414 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001415{
1416 return std::unique_ptr<IWorkload>();
1417}
1418
Derek Lamberti901ea112019-12-10 22:07:09 +00001419std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1420 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001421{
1422 return std::unique_ptr<IWorkload>();
1423}
1424
Derek Lamberti901ea112019-12-10 22:07:09 +00001425std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1426 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001427{
1428 return std::unique_ptr<IWorkload>();
1429}
1430
Kevin Mayce5045a2019-10-02 14:07:47 +01001431std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001432 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1433 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001434{
1435 return std::unique_ptr<IWorkload>();
1436}
1437
Derek Lamberti901ea112019-12-10 22:07:09 +00001438std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1439 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001440{
1441 return std::unique_ptr<IWorkload>();
1442}
1443
Derek Lamberti901ea112019-12-10 22:07:09 +00001444std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1445 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001446{
1447 return std::unique_ptr<IWorkload>();
1448}
1449
Derek Lamberti901ea112019-12-10 22:07:09 +00001450std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1451 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001452{
1453 return std::unique_ptr<IWorkload>();
1454}
1455
Derek Lamberti901ea112019-12-10 22:07:09 +00001456std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1457 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001458{
1459 return std::unique_ptr<IWorkload>();
1460}
1461
Derek Lamberti901ea112019-12-10 22:07:09 +00001462std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1463 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001464{
1465 return std::unique_ptr<IWorkload>();
1466}
1467
Derek Lamberti901ea112019-12-10 22:07:09 +00001468std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1469 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001470{
1471 return std::unique_ptr<IWorkload>();
1472}
1473
Derek Lamberti901ea112019-12-10 22:07:09 +00001474std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1475 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001476{
1477 return std::unique_ptr<IWorkload>();
1478}
1479
Derek Lamberti901ea112019-12-10 22:07:09 +00001480std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1481 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001482{
1483 return std::unique_ptr<IWorkload>();
1484}
1485
Derek Lamberti901ea112019-12-10 22:07:09 +00001486std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1487 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001488{
1489 return std::unique_ptr<IWorkload>();
1490}
1491
Derek Lamberti901ea112019-12-10 22:07:09 +00001492std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1493 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001494{
1495 return std::unique_ptr<IWorkload>();
1496}
1497
Derek Lamberti901ea112019-12-10 22:07:09 +00001498std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1499 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001500{
1501 return std::unique_ptr<IWorkload>();
1502}
1503
Derek Lamberti901ea112019-12-10 22:07:09 +00001504std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1505 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001506{
1507 return std::unique_ptr<IWorkload>();
1508}
1509
Derek Lamberti901ea112019-12-10 22:07:09 +00001510std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1511 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001512{
1513 return std::unique_ptr<IWorkload>();
1514}
1515
Derek Lamberti901ea112019-12-10 22:07:09 +00001516std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1517 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001518{
1519 return std::unique_ptr<IWorkload>();
1520}
1521
Derek Lamberti901ea112019-12-10 22:07:09 +00001522std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001523 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001524{
1525 return std::unique_ptr<IWorkload>();
1526}
1527
Derek Lamberti901ea112019-12-10 22:07:09 +00001528std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1529 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001530{
1531 return std::unique_ptr<IWorkload>();
1532}
1533
Derek Lamberti901ea112019-12-10 22:07:09 +00001534std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1535 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001536{
1537 return std::unique_ptr<IWorkload>();
1538}
1539
Derek Lamberti901ea112019-12-10 22:07:09 +00001540std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1541 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001542{
1543 return std::unique_ptr<IWorkload>();
1544}
1545
Derek Lamberti901ea112019-12-10 22:07:09 +00001546std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1547 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001548{
1549 return std::unique_ptr<IWorkload>();
1550}
1551
James Conroy586a9aa2020-03-20 08:49:33 +00001552std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1553 const WorkloadInfo& /*info*/) const
1554{
1555 return std::unique_ptr<IWorkload>();
1556}
1557
Derek Lamberti901ea112019-12-10 22:07:09 +00001558std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1559 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001560{
1561 return std::unique_ptr<IWorkload>();
1562}
Finn Williams2605b232020-06-10 15:53:46 +01001563std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
1564 const WorkloadInfo& /*info*/) const
1565{
1566 return std::unique_ptr<IWorkload>();
1567}
James Conroyee18dc82019-07-17 11:27:46 +01001568
Derek Lamberti901ea112019-12-10 22:07:09 +00001569std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1570 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001571{
1572 return std::unique_ptr<IWorkload>();
1573}
1574
Derek Lamberti901ea112019-12-10 22:07:09 +00001575std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1576 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001577{
1578 return std::unique_ptr<IWorkload>();
1579}
1580
Derek Lamberti901ea112019-12-10 22:07:09 +00001581std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1582 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001583{
1584 return std::unique_ptr<IWorkload>();
1585}
1586
Derek Lamberti901ea112019-12-10 22:07:09 +00001587std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1588 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001589{
1590 return std::unique_ptr<IWorkload>();
1591}
1592
Derek Lamberti901ea112019-12-10 22:07:09 +00001593std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1594 const WorkloadInfo& /*info*/) const
1595{
1596 return std::unique_ptr<IWorkload>();
1597}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001598
Derek Lamberti901ea112019-12-10 22:07:09 +00001599std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1600 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001601{
1602 return std::unique_ptr<IWorkload>();
1603}
1604
Derek Lamberti901ea112019-12-10 22:07:09 +00001605std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1606 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001607{
1608 return std::unique_ptr<IWorkload>();
1609}
1610
Derek Lamberti901ea112019-12-10 22:07:09 +00001611std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1612 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001613{
1614 return std::unique_ptr<IWorkload>();
1615}
1616
Derek Lamberti901ea112019-12-10 22:07:09 +00001617std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1618 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001619{
1620 return std::unique_ptr<IWorkload>();
1621}
1622
Derek Lamberti901ea112019-12-10 22:07:09 +00001623std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1624 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001625{
1626 return std::unique_ptr<IWorkload>();
1627}
1628
Derek Lamberti901ea112019-12-10 22:07:09 +00001629std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1630 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001631{
1632 return std::unique_ptr<IWorkload>();
1633}
1634
Derek Lamberti901ea112019-12-10 22:07:09 +00001635std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1636 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001637{
1638 return std::unique_ptr<IWorkload>();
1639}
1640
Derek Lamberti901ea112019-12-10 22:07:09 +00001641std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1642 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001643{
1644 return std::unique_ptr<IWorkload>();
1645}
1646
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001647std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1648 const WorkloadInfo& /*info*/) const
1649{
1650 return std::unique_ptr<IWorkload>();
1651}
1652
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001653std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001654 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1655 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001656{
1657 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001658}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001659
} // namespace armnn