blob: 20d7134c3a1038dc3305e6f6c37428c897fd7a58 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000011#include <armnn/ILayerSupport.hpp>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000012#include <armnn/BackendHelper.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
Finn Williams3e54d032020-10-22 16:53:35 +010015#include <armnn/utility/TransformIterator.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000017#include <backendsCommon/WorkloadFactory.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000018#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000019
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
David Beck111b5d92018-11-12 14:59:37 +000022#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000023
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
telsoa01c577f2c2018-08-31 09:22:23 +010027namespace
28{
Finn Williams3e54d032020-10-22 16:53:35 +010029using LayerList = std::list<Layer*>;
30using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa01c577f2c2018-08-31 09:22:23 +010031
David Beck29c75de2018-10-23 13:35:58 +010032const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
33{
34 if (!type)
35 {
36 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010037 }
38
David Beck29c75de2018-10-23 13:35:58 +010039 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010040}
41
David Beck29c75de2018-10-23 13:35:58 +010042} // anonymous namespace
43
Sadik Armagan045f6be2020-09-10 13:37:32 +010044bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
45 const IConnectableLayer& connectableLayer,
46 Optional<DataType> dataType,
47 std::string& outReasonIfUnsupported,
48 const ModelOptions& modelOptions)
telsoa014fcda012018-03-09 14:13:49 +000049{
David Beck33f0ae02018-10-18 15:13:56 +010050 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000051 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010052 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010053
David Beck111b5d92018-11-12 14:59:37 +000054 auto const& backendRegistry = BackendRegistryInstance();
55 if (!backendRegistry.IsBackendRegistered(backendId))
56 {
57 std::stringstream ss;
58 ss << connectableLayer.GetName() << " is not supported on " << backendId
59 << " because this backend is not registered.";
60
61 outReasonIfUnsupported = ss.str();
62 return false;
63 }
64
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000067 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
David Beck33f0ae02018-10-18 15:13:56 +010068
telsoa014fcda012018-03-09 14:13:49 +000069 switch(layer.GetType())
70 {
71 case LayerType::Activation:
72 {
Jan Eilersbb446e52020-04-02 13:56:54 +010073 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000074 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010075 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000076 result = layerSupportObject.IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010077 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010080 reason);
telsoa014fcda012018-03-09 14:13:49 +000081 break;
82 }
83 case LayerType::Addition:
84 {
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000088 result = layerSupportObject.IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010089 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010092 reason);
telsoa014fcda012018-03-09 14:13:49 +000093 break;
94 }
Nikhil Rajee391d52019-09-05 17:50:44 +010095 case LayerType::ArgMinMax:
96 {
Jan Eilersbb446e52020-04-02 13:56:54 +010097 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +010098 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
99
100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000102 result = layerSupportObject.IsArgMinMaxSupported(
Nikhil Rajee391d52019-09-05 17:50:44 +0100103 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000104 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100105 descriptor,
106 reason);
107 break;
108 }
telsoa014fcda012018-03-09 14:13:49 +0000109 case LayerType::BatchNormalization:
110 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100111 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
114 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
115 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
116 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
117 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000118 result = layerSupportObject.IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100119 OverrideDataType(input, dataType),
120 OverrideDataType(output, dataType),
121 OverrideDataType(mean, dataType),
122 OverrideDataType(var, dataType),
123 OverrideDataType(beta, dataType),
124 OverrideDataType(gamma, dataType),
125 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100126 reason);
telsoa014fcda012018-03-09 14:13:49 +0000127 break;
128 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000129 case LayerType::BatchToSpaceNd:
130 {
131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100133 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000134
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000135 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
136 OverrideDataType(output, dataType),
137 cLayer->GetParameters(),
138 reason);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000139 break;
140 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100141 case LayerType::Comparison:
142 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100143 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100144
145 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
146 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
148
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000149 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
150 OverrideDataType(input1, dataType),
151 OverrideDataType(output, DataType::Boolean),
152 cLayer->GetParameters(),
153 reason);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100154 break;
155 }
telsoa014fcda012018-03-09 14:13:49 +0000156 case LayerType::Constant:
157 {
158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000159 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100160 break;
161 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000162 case LayerType::ConvertBf16ToFp32:
163 {
164 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
165 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000166 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000167 break;
168 }
telsoa01c577f2c2018-08-31 09:22:23 +0100169 case LayerType::ConvertFp16ToFp32:
170 {
171 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000173 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100174 break;
175 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000176 case LayerType::ConvertFp32ToBf16:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000180 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000181 break;
182 }
telsoa01c577f2c2018-08-31 09:22:23 +0100183 case LayerType::ConvertFp32ToFp16:
184 {
185 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
186 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000187 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000188 break;
189 }
190 case LayerType::Convolution2d:
191 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100192 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100193
194 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
195 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100196 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100197 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100198
arovir01a6824102018-08-28 17:40:45 +0100199 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100200
arovir01a6824102018-08-28 17:40:45 +0100201 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100202 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100203 if (descriptor.m_BiasEnabled)
204 {
David Beck5eec11d2018-10-04 15:43:17 +0100205 biases =
206 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100207 }
208
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000209 result = layerSupportObject.IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100210 input,
211 output,
212 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100213 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100214 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100215 reason);
telsoa014fcda012018-03-09 14:13:49 +0000216 break;
217 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000218 case LayerType::Debug:
219 {
220 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
221 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
222
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000223 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000224 OverrideDataType(output, dataType),
225 reason);
226 break;
227 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100228 case LayerType::DepthToSpace:
229 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100230 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100231
232 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
233 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
234
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000235 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100236 OverrideDataType(output, dataType),
237 cLayer->GetParameters(),
238 reason);
239 break;
240 }
telsoa014fcda012018-03-09 14:13:49 +0000241 case LayerType::DepthwiseConvolution2d:
242 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100243 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100244 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
245 dataType);
246 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100247 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100248
telsoa01c577f2c2018-08-31 09:22:23 +0100249 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100250
251 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100252 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100253 if (descriptor.m_BiasEnabled)
254 {
David Beck5eec11d2018-10-04 15:43:17 +0100255 biases =
256 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100257 }
telsoa01c577f2c2018-08-31 09:22:23 +0100258
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000259 result = layerSupportObject.IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100260 input,
261 output,
262 descriptor,
263 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100264 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100265 reason);
telsoa014fcda012018-03-09 14:13:49 +0000266 break;
267 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000268 case LayerType::Dequantize:
269 {
270 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
271 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
272
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000273 result = layerSupportObject.IsDequantizeSupported(input,
274 OverrideDataType(output, dataType),
275 reason);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000276 break;
277 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000278 case LayerType::DetectionPostProcess:
279 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100280 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000281 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
282 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
283 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
284
285 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
286 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
287 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
288 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
289
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000290 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000291 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
292 scores,
293 anchors,
294 detectionBoxes,
295 detectionClasses,
296 detectionScores,
297 numDetections,
298 descriptor,
299 reason);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000300 break;
301 }
josh minor4a3c6102020-01-06 16:40:46 -0600302 case LayerType::ElementwiseUnary:
303 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100304 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600305
306 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
307 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
308
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000309 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
310 OverrideDataType(output, dataType),
311 cLayer->GetParameters(),
312 reason);
josh minor4a3c6102020-01-06 16:40:46 -0600313 break;
314 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100315 case LayerType::Fill:
316 {
317 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
318 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
319 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
320 const FillDescriptor& descriptor = cLayer->GetParameters();
321
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000322 result = layerSupportObject.IsFillSupported(
Ryan OSheaec6c6802020-06-05 17:17:06 +0100323 OverrideDataType(input, dataType),
324 OverrideDataType(output, dataType),
325 descriptor,
326 reason);
327 break;
328 }
telsoa014fcda012018-03-09 14:13:49 +0000329 case LayerType::FakeQuantization:
330 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100331 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000332 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000333 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
334 cLayer->GetParameters(),
335 reason);
telsoa014fcda012018-03-09 14:13:49 +0000336 break;
337 }
338 case LayerType::Floor:
339 {
340 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
341 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000342 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
343 OverrideDataType(output, dataType),
344 reason);
telsoa014fcda012018-03-09 14:13:49 +0000345 break;
346 }
347 case LayerType::FullyConnected:
348 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100349 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000350 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100351 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000352
353 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
354 TensorInfo weightsInfo;
355 const TensorInfo* weightsInfoPtr = nullptr;
356
357 if (descriptor.m_ConstantWeights)
358 {
359 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
360 weightsInfo = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
361 }
362 else
363 {
364 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
365
366 }
367 weightsInfoPtr = &weightsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100368
369 TensorInfo biasInfo;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000370 const TensorInfo* biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000371 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100372 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
373 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
374 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
375
telsoa01c577f2c2018-08-31 09:22:23 +0100376 if (descriptor.m_BiasEnabled)
377 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000378 if(descriptor.m_ConstantWeights)
379 {
380 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
381 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
382 biasInfoPtr = &biasInfo;
383 }
384 else
385 {
386 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
387 biasInfoPtr = &biasInfo;
388 }
telsoa01c577f2c2018-08-31 09:22:23 +0100389 }
390 else
391 {
392 // If biases are not enabled pass a dummy tensorinfo for the validation
393 switch(input.GetDataType())
394 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000395 case DataType::BFloat16:
396 {
397 biasInfoPtr = &dummyBFloat16Bias;
398 break;
399 }
telsoa01c577f2c2018-08-31 09:22:23 +0100400 case DataType::Float16:
401 {
402 biasInfoPtr = &dummyFloat16Bias;
403 break;
404 }
405 case DataType::Float32:
406 {
407 biasInfoPtr = &dummyFloat32Bias;
408 break;
409 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000410 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000411 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000412 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000413 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100414 {
415 biasInfoPtr = &dummyQA8Bias;
416 break;
417 }
418 default:
419 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100420 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100421 }
422 }
423 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000424 result = layerSupportObject.IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100425 OverrideDataType(input, dataType),
426 OverrideDataType(output, dataType),
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000427 *weightsInfoPtr,
telsoa01c577f2c2018-08-31 09:22:23 +0100428 *biasInfoPtr,
429 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100430 reason);
telsoa014fcda012018-03-09 14:13:49 +0000431 break;
432 }
narpra01b89b05f2019-01-16 09:53:09 +0000433 case LayerType::Gather:
434 {
435 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
436 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
437 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Teresa Charlin52664732020-06-29 16:27:03 +0100438 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
439 const GatherDescriptor& descriptor = cLayer->GetParameters();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000440 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
441 input1,
442 OverrideDataType(output, dataType),
443 descriptor,
444 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000445 break;
446 }
telsoa014fcda012018-03-09 14:13:49 +0000447 case LayerType::Input:
448 {
449 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000450 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000451 break;
452 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100453 case LayerType::InstanceNormalization:
454 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100455 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100456 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
457
458 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
459 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
460
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000461 result = layerSupportObject.IsInstanceNormalizationSupported(
Kevin Mayce5045a2019-10-02 14:07:47 +0100462 OverrideDataType(input, dataType),
463 OverrideDataType(output, dataType),
464 descriptor,
465 reason);
466 break;
467 }
telsoa014fcda012018-03-09 14:13:49 +0000468 case LayerType::L2Normalization:
469 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100470 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100471 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
472
telsoa014fcda012018-03-09 14:13:49 +0000473 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100474 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100475
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000476 result = layerSupportObject.IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100477 OverrideDataType(input, dataType),
478 OverrideDataType(output, dataType),
479 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100480 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100481 break;
482 }
James Conroyaba90cd2020-11-06 16:28:18 +0000483 case LayerType::LogicalBinary:
484 {
485 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
486
487 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
488 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
489 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
490
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000491 result = layerSupportObject.IsLogicalBinarySupported(input0,
492 input1,
493 output,
494 cLayer->GetParameters(),
495 reason);
James Conroyaba90cd2020-11-06 16:28:18 +0000496 break;
497 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100498 case LayerType::LogSoftmax:
499 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100500 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100501
502 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
503 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
504
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000505 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
506 OverrideDataType(output, dataType),
507 cLayer->GetParameters(),
508 reason);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100509 break;
510 }
telsoa01c577f2c2018-08-31 09:22:23 +0100511 case LayerType::Lstm:
512 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100513 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100514 const LstmDescriptor& descriptor = cLayer->GetParameters();
515
516 // All inputs.
517 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
518 dataType);
519 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
520 dataType);
521 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
522 dataType);
523 // All outputs
524 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
525 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
526 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
527 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
528
529 // Basic parameters
530 const TensorInfo& inputToForgetWeights
531 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
532 const TensorInfo& inputToCellWeights
533 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
534 const TensorInfo& inputToOutputWeights
535 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
536 const TensorInfo& recurrentToForgetWeights
537 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
538 const TensorInfo& recurrentToCellWeights
539 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
540 const TensorInfo& recurrentToOutputWeights
541 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
542 const TensorInfo& forgetGateBias
543 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
544 const TensorInfo& cellBias
545 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
546 const TensorInfo& outputGateBias
547 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
548
Jan Eilersd01a83c2019-07-03 18:20:40 +0100549 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100550
Jan Eilersd01a83c2019-07-03 18:20:40 +0100551 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
552 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
553 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
554 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
555 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
556 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
557 paramsInfo.m_ForgetGateBias = &forgetGateBias;
558 paramsInfo.m_CellBias = &cellBias;
559 paramsInfo.m_OutputGateBias = &outputGateBias;
560
561
562 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100563 TensorInfo optInputToInputWeights;
564 TensorInfo optRecurrentToInputWeights;
565 TensorInfo optCellToInputWeights;
566 TensorInfo optInputGateBias;
567 TensorInfo optProjectionWeights;
568 TensorInfo optProjectionBias;
569 TensorInfo optCellToForgetWeights;
570 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100571 TensorInfo optInputLayerNormWeights;
572 TensorInfo optForgetLayerNormWeights;
573 TensorInfo optCellLayerNormWeights;
574 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100575
576 if(!descriptor.m_CifgEnabled)
577 {
578 optInputToInputWeights =
579 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100580 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100581
582 optRecurrentToInputWeights =
583 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100584 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100585 optInputGateBias =
586 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100587 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100588 }
589
590 if(descriptor.m_ProjectionEnabled)
591 {
592 optProjectionWeights =
593 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100594 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100595 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
596 {
597 optProjectionBias =
598 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100599 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100600 }
601 }
602
603 if(descriptor.m_PeepholeEnabled)
604 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100605 if(!descriptor.m_CifgEnabled)
606 {
607 optCellToInputWeights =
608 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
609 dataType);
610 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
611 }
telsoa01c577f2c2018-08-31 09:22:23 +0100612 optCellToForgetWeights =
613 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100614 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100615 optCellToOutputWeights =
616 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100617 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100618 }
619
Jan Eilers38e05bd2019-06-26 13:10:09 +0100620 if(descriptor.m_LayerNormEnabled)
621 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100622 if (!descriptor.m_CifgEnabled)
623 {
624 optInputLayerNormWeights = OverrideDataType(
625 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
626 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
627 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100628
629 optForgetLayerNormWeights = OverrideDataType(
630 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100631 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100632
633 optCellLayerNormWeights = OverrideDataType(
634 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100635 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100636
637 optOutputLayerNormWeights = OverrideDataType(
638 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100639 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100640 }
641
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000642 result = layerSupportObject.IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100643 input,
644 outputStateIn,
645 cellStateIn,
646 scratchBuffer,
647 outputStateOut,
648 cellStateOut,
649 output,
650 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100651 paramsInfo,
652 reason);
telsoa014fcda012018-03-09 14:13:49 +0000653 break;
654 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000655 case LayerType::Maximum:
656 {
657 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
658 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
659 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
660
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000661 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
662 OverrideDataType(input1, dataType),
663 OverrideDataType(output, dataType),
664 reason);
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000665 break;
666 }
narpra01b89b05f2019-01-16 09:53:09 +0000667 case LayerType::MemCopy:
668 {
669 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
670 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000671
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000672 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
673 OverrideDataType(output, dataType),
674 reason);
narpra01b89b05f2019-01-16 09:53:09 +0000675 break;
676 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100677 case LayerType::MemImport:
678 {
679 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
680 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
681
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000682 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
683 OverrideDataType(output, dataType),
684 reason);
Derek Lambertif674aa02019-08-01 15:56:25 +0100685 break;
686 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100687 case LayerType::Merge:
688 {
689 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
690 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
691 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
692
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000693 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
694 OverrideDataType(input1, dataType),
695 OverrideDataType(output, dataType),
696 reason);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100697 break;
698 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100699 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000700 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100701 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000702
telsoa01c577f2c2018-08-31 09:22:23 +0100703 // Get vector of all inputs.
704 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000705 {
telsoa01c577f2c2018-08-31 09:22:23 +0100706 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000707 };
Finn Williams3e54d032020-10-22 16:53:35 +0100708
709 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
710 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100711 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000712
telsoa01c577f2c2018-08-31 09:22:23 +0100713 auto getTensorInfoPtr = [](const TensorInfo& info)
714 {
715 return &info;
716 };
Finn Williams3e54d032020-10-22 16:53:35 +0100717
718 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
719 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
telsoa01c577f2c2018-08-31 09:22:23 +0100720 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000721
Nikhil Raj8599a412018-11-19 14:51:07 +0000722 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
723
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000724 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Jim Flynne242f2d2019-05-22 14:24:13 +0100725
726
telsoa014fcda012018-03-09 14:13:49 +0000727 break;
728 }
729 case LayerType::Multiplication:
730 {
731 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
732 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100733 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000734 result = layerSupportObject.IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100735 OverrideDataType(input0, dataType),
736 OverrideDataType(input1, dataType),
737 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100738 reason);
telsoa014fcda012018-03-09 14:13:49 +0000739 break;
740 }
741 case LayerType::Normalization:
742 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100743 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000744 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
745 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000746 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
747 OverrideDataType(output, dataType),
748 cLayer->GetParameters(),
749 reason);
telsoa014fcda012018-03-09 14:13:49 +0000750 break;
751 }
752 case LayerType::Output:
753 {
754 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000755 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000756 break;
757 }
758 case LayerType::Permute:
759 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100760 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000761 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
762 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000763 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
764 OverrideDataType(output, dataType),
765 cLayer->GetParameters(),
766 reason);
telsoa014fcda012018-03-09 14:13:49 +0000767 break;
768 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100769 case LayerType::Pad:
770 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100771 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100772 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
773 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000774 result = layerSupportObject.IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100775 OverrideDataType(input, dataType),
776 OverrideDataType(output, dataType),
777 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100778 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100779 break;
780 }
telsoa014fcda012018-03-09 14:13:49 +0000781 case LayerType::Pooling2d:
782 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100783 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000784 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
785 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000786 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
787 OverrideDataType(output, dataType),
788 cLayer->GetParameters(),
789 reason);
telsoa014fcda012018-03-09 14:13:49 +0000790 break;
791 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000792 case LayerType::PreCompiled:
793 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100794 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000795 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000796 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
797 cLayer->GetParameters(),
798 reason);
Matteo Martincigh49124022019-01-11 13:25:59 +0000799 break;
800 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000801 case LayerType::Quantize:
802 {
803 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
804 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000805 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000806 break;
807 }
James Conroy586a9aa2020-03-20 08:49:33 +0000808 case LayerType::QLstm:
809 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100810 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000811 const QLstmDescriptor& descriptor = cLayer->GetParameters();
812
813 // Inputs
814 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
815 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
816 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
817
818 // Outputs
819 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
820 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
821 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
822
823 // Lstm parameters
824 LstmInputParamsInfo paramsInfo;
825
826 // Basic parameters
827 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
828 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
829 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
830
831 paramsInfo.m_RecurrentToForgetWeights =
832 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
833 paramsInfo.m_RecurrentToCellWeights =
834 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
835 paramsInfo.m_RecurrentToOutputWeights =
836 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
837
838 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
839 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
840 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
841
842 if(!descriptor.m_CifgEnabled)
843 {
844 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
845 paramsInfo.m_RecurrentToInputWeights =
846 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
847 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
848 }
849
850 if(descriptor.m_ProjectionEnabled)
851 {
852 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +0100853
854 // Projection bias is optional even if projection is enabled
855 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
856 {
857 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
858 }
James Conroy586a9aa2020-03-20 08:49:33 +0000859 }
860
861 if(descriptor.m_PeepholeEnabled)
862 {
863 if (!descriptor.m_CifgEnabled)
864 {
865 paramsInfo.m_CellToInputWeights =
866 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
867 }
868
869 paramsInfo.m_CellToForgetWeights =
870 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
871 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
872 }
873
874 if(descriptor.m_LayerNormEnabled)
875 {
876 if (!descriptor.m_CifgEnabled)
877 {
878 paramsInfo.m_InputLayerNormWeights =
879 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
880 }
881
882 paramsInfo.m_ForgetLayerNormWeights =
883 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
884 paramsInfo.m_CellLayerNormWeights =
885 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
886 paramsInfo.m_OutputLayerNormWeights =
887 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
888 }
889
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000890 result = layerSupportObject.IsQLstmSupported(input,
891 previousOutputIn,
892 previousCellStateIn,
893 outputStateOut,
894 cellStateOut,
895 output,
896 descriptor,
897 paramsInfo,
898 reason);
James Conroy586a9aa2020-03-20 08:49:33 +0000899 break;
900 }
James Conroyee18dc82019-07-17 11:27:46 +0100901 case LayerType::QuantizedLstm:
902 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100903 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100904
905 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100906 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
907 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
908 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100909
910 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100911 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
912 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100913
914 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100915 QuantizedLstmInputParamsInfo paramsInfo;
916
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100917 paramsInfo.m_InputToInputWeights =
918 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
919 paramsInfo.m_InputToForgetWeights =
920 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
921 paramsInfo.m_InputToCellWeights =
922 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
923 paramsInfo.m_InputToOutputWeights =
924 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100925
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100926 paramsInfo.m_RecurrentToInputWeights =
927 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
928 paramsInfo.m_RecurrentToForgetWeights =
929 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
930 paramsInfo.m_RecurrentToCellWeights =
931 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
932 paramsInfo.m_RecurrentToOutputWeights =
933 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100934
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100935 paramsInfo.m_InputGateBias =
936 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
937 paramsInfo.m_ForgetGateBias =
938 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
939 paramsInfo.m_CellBias =
940 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
941 paramsInfo.m_OutputGateBias =
942 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100943
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000944 result = layerSupportObject.IsQuantizedLstmSupported(input,
945 previousCellStateIn,
946 previousOutputIn,
947 cellStateOut,
948 output,
949 paramsInfo,
950 reason);
James Conroyee18dc82019-07-17 11:27:46 +0100951 break;
952 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100953 case LayerType::Division:
954 {
955 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
956 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
957 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000958 result = layerSupportObject.IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100959 OverrideDataType(input0, dataType),
960 OverrideDataType(input1, dataType),
961 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100962 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100963 break;
964 }
Finn Williams2605b232020-06-10 15:53:46 +0100965 case LayerType::Rank:
966 {
967 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
968 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000969 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
970 OverrideDataType(output, dataType),
971 reason);
Finn Williams2605b232020-06-10 15:53:46 +0100972 break;
973 }
telsoa014fcda012018-03-09 14:13:49 +0000974 case LayerType::Reshape:
975 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100976 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000977 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +0000978 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000979 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
980 OverrideDataType(output, dataType),
981 cLayer->GetParameters(),
982 reason);
telsoa014fcda012018-03-09 14:13:49 +0000983 break;
984 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100985 case LayerType::Resize:
986 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100987 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100988 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100989 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000990 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
991 OverrideDataType(output, dataType),
992 cLayer->GetParameters(),
993 reason);
Teresa Charlina9075df2019-06-27 15:41:57 +0100994 break;
995 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100996 case LayerType::Slice:
997 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100998 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100999
1000 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1001 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1002
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001003 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1004 OverrideDataType(output, dataType),
1005 cLayer->GetParameters(),
1006 reason);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001007 break;
1008 }
telsoa014fcda012018-03-09 14:13:49 +00001009 case LayerType::Softmax:
1010 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001011 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001012 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +01001013 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001014 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1015 OverrideDataType(output, dataType),
1016 cLayer->GetParameters(),
1017 reason);
telsoa014fcda012018-03-09 14:13:49 +00001018 break;
1019 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001020 case LayerType::SpaceToBatchNd:
1021 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001022 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001023 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1024 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001025 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1026 OverrideDataType(output, dataType),
1027 cLayer->GetParameters(),
1028 reason);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001029 break;
1030 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001031 case LayerType::SpaceToDepth:
1032 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001033 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001034
1035 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1036 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1037
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001038 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1039 OverrideDataType(output, dataType),
1040 cLayer->GetParameters(),
1041 reason);
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001042 break;
1043 }
telsoa014fcda012018-03-09 14:13:49 +00001044 case LayerType::Splitter:
1045 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001046 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +00001047 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001048
1049 // Get vector of all outputs.
1050 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1051 {
1052 return OverrideDataType(slot.GetTensorInfo(), dataType);
1053 };
Finn Williams3e54d032020-10-22 16:53:35 +01001054 auto beginI = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfo);
1055 auto endI = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfo);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001056 std::vector<TensorInfo> outputs(beginI, endI);
1057
1058 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1059
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001060 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1061 outputPtrs,
1062 cLayer->GetParameters(),
1063 reason);
telsoa014fcda012018-03-09 14:13:49 +00001064 break;
1065 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001066 case LayerType::Stack:
1067 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001068 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001069
1070 // Get vector of all inputs.
1071 auto getTensorInfo = [&dataType](const InputSlot& slot)
1072 {
1073 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1074 };
Finn Williams3e54d032020-10-22 16:53:35 +01001075 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfo);
1076 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfo);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001077 std::vector<TensorInfo> inputs(beginI, endI);
1078
1079 auto getTensorInfoPtr = [](const TensorInfo& info)
1080 {
1081 return &info;
1082 };
Finn Williams3e54d032020-10-22 16:53:35 +01001083 auto beginPtr = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1084 auto endPtr = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001085 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1086
1087 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1088
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001089 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001090
1091 break;
1092 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001093 case LayerType::StandIn:
1094 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001095 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001096
1097 // Get vector of all inputs.
1098 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1099 {
1100 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1101 };
1102 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1103 {
1104 return OverrideDataType(slot.GetTensorInfo(), dataType);
1105 };
Finn Williams3e54d032020-10-22 16:53:35 +01001106 auto beginI = MakeTransformIterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1107 auto endI = MakeTransformIterator(layer.GetInputSlots().end(), getTensorInfoIn);
Derek Lamberti013c3902019-10-21 10:46:16 +01001108 std::vector<TensorInfo> inputs(beginI, endI);
1109
Finn Williams3e54d032020-10-22 16:53:35 +01001110 auto beginO = MakeTransformIterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1111 auto endO = MakeTransformIterator(layer.GetOutputSlots().end(), getTensorInfoOut);
Derek Lamberti013c3902019-10-21 10:46:16 +01001112 std::vector<TensorInfo> outputs(beginO, endO);
1113
1114
1115 auto getTensorInfoPtr = [](const TensorInfo& info)
1116 {
1117 return &info;
1118 };
Finn Williams3e54d032020-10-22 16:53:35 +01001119 auto beginPtrI = MakeTransformIterator(inputs.begin(), getTensorInfoPtr);
1120 auto endPtrI = MakeTransformIterator(inputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001121 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1122
Finn Williams3e54d032020-10-22 16:53:35 +01001123 auto beginPtrO = MakeTransformIterator(outputs.begin(), getTensorInfoPtr);
1124 auto endPtrO = MakeTransformIterator(outputs.end(), getTensorInfoPtr);
Derek Lamberti013c3902019-10-21 10:46:16 +01001125 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1126
1127
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001128 result = layerSupportObject.IsStandInSupported(inputPtrs,
1129 outputPtrs,
1130 cLayer->GetParameters(),
1131 reason);
Derek Lamberti013c3902019-10-21 10:46:16 +01001132 break;
1133 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001134 case LayerType::StridedSlice:
1135 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001136 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001137 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1138 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001139 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1140 OverrideDataType(output, dataType),
1141 cLayer->GetParameters(),
1142 reason);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001143 break;
1144 }
David Beckc2044fe2018-09-05 15:00:38 +01001145 case LayerType::Subtraction:
1146 {
1147 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1148 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1149 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001150 result = layerSupportObject.IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001151 OverrideDataType(input0, dataType),
1152 OverrideDataType(input1, dataType),
1153 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001154 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001155 break;
1156 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001157 case LayerType::Switch:
1158 {
1159 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1160 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1161 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1162 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001163 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1164 OverrideDataType(input1, dataType),
1165 OverrideDataType(output0, dataType),
1166 OverrideDataType(output1, dataType),
1167 reason);
Sadik Armaganeff363d2019-04-05 15:25:46 +01001168 break;
1169 }
narpra0132b90462018-09-13 11:07:48 +01001170 case LayerType::Mean:
1171 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001172 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001173 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1174 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001175 result = layerSupportObject.IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001176 OverrideDataType(input, dataType),
1177 OverrideDataType(output, dataType),
1178 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001179 reason);
narpra0132b90462018-09-13 11:07:48 +01001180 break;
1181 }
kevmay0190539692018-11-29 08:40:19 +00001182 case LayerType::Minimum:
1183 {
1184 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1185 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1186 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001187 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1188 OverrideDataType(input1, dataType),
1189 OverrideDataType(output, dataType),
1190 reason);
kevmay0190539692018-11-29 08:40:19 +00001191 break;
1192 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001193 case LayerType::Prelu:
1194 {
1195 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1196 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1197 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001198 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1199 OverrideDataType(alpha, dataType),
1200 OverrideDataType(output, dataType),
1201 reason);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001202 break;
1203 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001204 case LayerType::Transpose:
1205 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001206 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001207 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1208 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001209 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1210 OverrideDataType(output, dataType),
1211 cLayer->GetParameters(),
1212 reason);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001213 break;
1214 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001215 case LayerType::TransposeConvolution2d:
1216 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001217 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001218
1219 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1220 dataType);
1221 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1222
1223 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1224
1225 Optional<TensorInfo> biases;
1226 if (descriptor.m_BiasEnabled)
1227 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001228 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001229 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1230 GetBiasTypeFromWeightsType(dataType));
1231 }
1232
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001233 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001234 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1235
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001236 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1237 output,
1238 descriptor,
1239 weights,
1240 biases,
1241 reason);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001242
1243 break;
1244 }
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001245 case LayerType::Reduce:
1246 {
1247 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1248 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1249 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1250
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001251 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1252 OverrideDataType(output, dataType),
1253 cLayer->GetParameters(),
1254 reason);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001255 break;
1256 }
telsoa014fcda012018-03-09 14:13:49 +00001257 default:
1258 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001259 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001260 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001261 result = false;
1262 break;
1263 }
1264 }
telsoa014fcda012018-03-09 14:13:49 +00001265 return result;
1266}
1267
Sadik Armagan045f6be2020-09-10 13:37:32 +01001268bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1269 const IConnectableLayer& connectableLayer,
1270 Optional<DataType> dataType,
1271 std::string& outReasonIfUnsupported)
1272{
1273 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1274}
1275
David Beckdcb751f2018-10-03 11:42:42 +01001276bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001277 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001278 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001279{
Jan Eilersbb446e52020-04-02 13:56:54 +01001280 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
Sadik Armagan045f6be2020-09-10 13:37:32 +01001281 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1282}
1283
1284// TODO merge with defaulted modelOptions above
1285bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1286 Optional<DataType> dataType,
1287 std::string& outReasonIfUnsupported,
1288 const ModelOptions& modelOptions)
1289{
1290 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1291 return IsLayerConfigurationSupported(layer->GetBackendId(),
1292 connectableLayer,
1293 dataType,
1294 outReasonIfUnsupported,
1295 modelOptions);
telsoa014fcda012018-03-09 14:13:49 +00001296}
1297
Sadik Armagan04a72972020-09-14 15:44:18 +01001298bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
1299 const IConnectableLayer& connectableLayer,
1300 Optional<DataType> dataType,
1301 std::string& outReasonIfUnsupported,
1302 const ModelOptions& modelOptions)
1303{
1304 return IsLayerConfigurationSupported(backendId,
1305 connectableLayer,
1306 dataType,
1307 outReasonIfUnsupported,
1308 modelOptions);
1309}
1310
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001311// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001312std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1313 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001314{
1315 return std::unique_ptr<IWorkload>();
1316}
1317
Derek Lamberti901ea112019-12-10 22:07:09 +00001318std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1319 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001320{
1321 return std::unique_ptr<IWorkload>();
1322}
1323
Derek Lamberti901ea112019-12-10 22:07:09 +00001324std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1325 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001326{
1327 return std::unique_ptr<IWorkload>();
1328}
1329
Derek Lamberti901ea112019-12-10 22:07:09 +00001330std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1331 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001332{
1333 return std::unique_ptr<IWorkload>();
1334}
1335
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001336std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001337 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001338{
1339 return std::unique_ptr<IWorkload>();
1340}
1341
Derek Lamberti901ea112019-12-10 22:07:09 +00001342std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1343 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001344{
1345 return std::unique_ptr<IWorkload>();
1346}
1347
Derek Lamberti901ea112019-12-10 22:07:09 +00001348std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1349 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001350{
1351 return std::unique_ptr<IWorkload>();
1352}
1353
Derek Lamberti901ea112019-12-10 22:07:09 +00001354std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1355 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001356{
1357 return std::unique_ptr<IWorkload>();
1358}
1359
Derek Lamberti901ea112019-12-10 22:07:09 +00001360std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1361 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001362{
1363 return std::unique_ptr<IWorkload>();
1364}
1365
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001366std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1367 const WorkloadInfo& /*info*/) const
1368{
1369 return std::unique_ptr<IWorkload>();
1370}
1371
Derek Lamberti901ea112019-12-10 22:07:09 +00001372std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1373 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001374{
1375 return std::unique_ptr<IWorkload>();
1376}
1377
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001378std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1379 const WorkloadInfo& /*info*/) const
1380{
1381 return std::unique_ptr<IWorkload>();
1382}
1383
Derek Lamberti901ea112019-12-10 22:07:09 +00001384std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1385 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001386{
1387 return std::unique_ptr<IWorkload>();
1388}
1389
Derek Lamberti901ea112019-12-10 22:07:09 +00001390std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1391 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001392{
1393 return std::unique_ptr<IWorkload>();
1394}
1395
Derek Lamberti901ea112019-12-10 22:07:09 +00001396std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1397 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001398{
1399 return std::unique_ptr<IWorkload>();
1400}
1401
Derek Lamberti901ea112019-12-10 22:07:09 +00001402std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1403 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001404{
1405 return std::unique_ptr<IWorkload>();
1406}
1407
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001408std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001409 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001410{
1411 return std::unique_ptr<IWorkload>();
1412}
1413
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001414std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001415 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001416{
1417 return std::unique_ptr<IWorkload>();
1418}
1419
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001420std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001421 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001422{
1423 return std::unique_ptr<IWorkload>();
1424}
1425
Derek Lamberti901ea112019-12-10 22:07:09 +00001426std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1427 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001428{
1429 return std::unique_ptr<IWorkload>();
1430}
1431
josh minor4a3c6102020-01-06 16:40:46 -06001432std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1433 const WorkloadInfo& /*info*/) const
1434{
1435 return std::unique_ptr<IWorkload>();
1436}
1437
Derek Lamberti901ea112019-12-10 22:07:09 +00001438std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1439 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001440{
1441 return std::unique_ptr<IWorkload>();
1442}
1443
Derek Lamberti901ea112019-12-10 22:07:09 +00001444std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1445 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001446{
1447 return std::unique_ptr<IWorkload>();
1448}
1449
Ryan OSheaec6c6802020-06-05 17:17:06 +01001450std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1451 const WorkloadInfo& /*info*/) const
1452{
1453 return std::unique_ptr<IWorkload>();
1454}
1455
Derek Lamberti901ea112019-12-10 22:07:09 +00001456std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1457 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001458{
1459 return std::unique_ptr<IWorkload>();
1460}
1461
Derek Lamberti901ea112019-12-10 22:07:09 +00001462std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1463 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001464{
1465 return std::unique_ptr<IWorkload>();
1466}
1467
Derek Lamberti901ea112019-12-10 22:07:09 +00001468std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1469 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001470{
1471 return std::unique_ptr<IWorkload>();
1472}
1473
Derek Lamberti901ea112019-12-10 22:07:09 +00001474std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1475 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001476{
1477 return std::unique_ptr<IWorkload>();
1478}
1479
Kevin Mayce5045a2019-10-02 14:07:47 +01001480std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001481 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1482 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001483{
1484 return std::unique_ptr<IWorkload>();
1485}
1486
Derek Lamberti901ea112019-12-10 22:07:09 +00001487std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1488 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001489{
1490 return std::unique_ptr<IWorkload>();
1491}
1492
James Conroyaba90cd2020-11-06 16:28:18 +00001493std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& /*desc*/,
1494 const WorkloadInfo& /*info*/) const
1495{
1496 return std::unique_ptr<IWorkload>();
1497}
1498
1499std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1500 const WorkloadInfo& /*info*/) const
1501{
1502 return std::unique_ptr<IWorkload>();
1503}
1504
Derek Lamberti901ea112019-12-10 22:07:09 +00001505std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1506 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001507{
1508 return std::unique_ptr<IWorkload>();
1509}
1510
Derek Lamberti901ea112019-12-10 22:07:09 +00001511std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1512 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001513{
1514 return std::unique_ptr<IWorkload>();
1515}
1516
Derek Lamberti901ea112019-12-10 22:07:09 +00001517std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1518 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001519{
1520 return std::unique_ptr<IWorkload>();
1521}
1522
Derek Lamberti901ea112019-12-10 22:07:09 +00001523std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1524 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001525{
1526 return std::unique_ptr<IWorkload>();
1527}
1528
Derek Lamberti901ea112019-12-10 22:07:09 +00001529std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1530 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001531{
1532 return std::unique_ptr<IWorkload>();
1533}
1534
Derek Lamberti901ea112019-12-10 22:07:09 +00001535std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1536 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001537{
1538 return std::unique_ptr<IWorkload>();
1539}
1540
Derek Lamberti901ea112019-12-10 22:07:09 +00001541std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1542 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001543{
1544 return std::unique_ptr<IWorkload>();
1545}
1546
Derek Lamberti901ea112019-12-10 22:07:09 +00001547std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1548 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001549{
1550 return std::unique_ptr<IWorkload>();
1551}
1552
Derek Lamberti901ea112019-12-10 22:07:09 +00001553std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1554 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001555{
1556 return std::unique_ptr<IWorkload>();
1557}
1558
Derek Lamberti901ea112019-12-10 22:07:09 +00001559std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1560 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001561{
1562 return std::unique_ptr<IWorkload>();
1563}
1564
Derek Lamberti901ea112019-12-10 22:07:09 +00001565std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1566 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001567{
1568 return std::unique_ptr<IWorkload>();
1569}
1570
Derek Lamberti901ea112019-12-10 22:07:09 +00001571std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1572 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001573{
1574 return std::unique_ptr<IWorkload>();
1575}
1576
Derek Lamberti901ea112019-12-10 22:07:09 +00001577std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1578 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001579{
1580 return std::unique_ptr<IWorkload>();
1581}
1582
Derek Lamberti901ea112019-12-10 22:07:09 +00001583std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001584 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001585{
1586 return std::unique_ptr<IWorkload>();
1587}
1588
Derek Lamberti901ea112019-12-10 22:07:09 +00001589std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1590 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001591{
1592 return std::unique_ptr<IWorkload>();
1593}
1594
Derek Lamberti901ea112019-12-10 22:07:09 +00001595std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1596 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001597{
1598 return std::unique_ptr<IWorkload>();
1599}
1600
Derek Lamberti901ea112019-12-10 22:07:09 +00001601std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1602 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001603{
1604 return std::unique_ptr<IWorkload>();
1605}
1606
Derek Lamberti901ea112019-12-10 22:07:09 +00001607std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1608 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001609{
1610 return std::unique_ptr<IWorkload>();
1611}
1612
James Conroy586a9aa2020-03-20 08:49:33 +00001613std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1614 const WorkloadInfo& /*info*/) const
1615{
1616 return std::unique_ptr<IWorkload>();
1617}
1618
Derek Lamberti901ea112019-12-10 22:07:09 +00001619std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1620 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001621{
1622 return std::unique_ptr<IWorkload>();
1623}
Finn Williams2605b232020-06-10 15:53:46 +01001624std::unique_ptr<IWorkload> IWorkloadFactory::CreateRank(const RankQueueDescriptor& /*descriptor*/,
1625 const WorkloadInfo& /*info*/) const
1626{
1627 return std::unique_ptr<IWorkload>();
1628}
James Conroyee18dc82019-07-17 11:27:46 +01001629
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001630std::unique_ptr<IWorkload> IWorkloadFactory::CreateReduce(const ReduceQueueDescriptor& /*descriptor*/,
1631 const WorkloadInfo& /*info*/) const
1632{
1633 return std::unique_ptr<IWorkload>();
1634}
1635
Derek Lamberti901ea112019-12-10 22:07:09 +00001636std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1637 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001638{
1639 return std::unique_ptr<IWorkload>();
1640}
1641
Derek Lamberti901ea112019-12-10 22:07:09 +00001642std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1643 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001644{
1645 return std::unique_ptr<IWorkload>();
1646}
1647
Derek Lamberti901ea112019-12-10 22:07:09 +00001648std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1649 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001650{
1651 return std::unique_ptr<IWorkload>();
1652}
1653
Derek Lamberti901ea112019-12-10 22:07:09 +00001654std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1655 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001656{
1657 return std::unique_ptr<IWorkload>();
1658}
1659
Derek Lamberti901ea112019-12-10 22:07:09 +00001660std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1661 const WorkloadInfo& /*info*/) const
1662{
1663 return std::unique_ptr<IWorkload>();
1664}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001665
Derek Lamberti901ea112019-12-10 22:07:09 +00001666std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1667 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001668{
1669 return std::unique_ptr<IWorkload>();
1670}
1671
Derek Lamberti901ea112019-12-10 22:07:09 +00001672std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1673 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001674{
1675 return std::unique_ptr<IWorkload>();
1676}
1677
Derek Lamberti901ea112019-12-10 22:07:09 +00001678std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1679 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001680{
1681 return std::unique_ptr<IWorkload>();
1682}
1683
Derek Lamberti901ea112019-12-10 22:07:09 +00001684std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1685 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001686{
1687 return std::unique_ptr<IWorkload>();
1688}
1689
Derek Lamberti901ea112019-12-10 22:07:09 +00001690std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1691 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001692{
1693 return std::unique_ptr<IWorkload>();
1694}
1695
Derek Lamberti901ea112019-12-10 22:07:09 +00001696std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1697 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001698{
1699 return std::unique_ptr<IWorkload>();
1700}
1701
Derek Lamberti901ea112019-12-10 22:07:09 +00001702std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1703 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001704{
1705 return std::unique_ptr<IWorkload>();
1706}
1707
Derek Lamberti901ea112019-12-10 22:07:09 +00001708std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1709 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001710{
1711 return std::unique_ptr<IWorkload>();
1712}
1713
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001714std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1715 const WorkloadInfo& /*info*/) const
1716{
1717 return std::unique_ptr<IWorkload>();
1718}
1719
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001720std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001721 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1722 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001723{
1724 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001725}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001726
} // namespace armnn