blob: d2565cf21d1d74eef48158e30fcc89108689474b [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000011#include <armnn/ILayerSupport.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000012#include <armnn/BackendRegistry.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010013#include <armnn/utility/PolymorphicDowncast.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015#include <backendsCommon/WorkloadFactory.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000016#include <armnn/backends/IBackendInternal.hpp>
17#include <backendsCommon/CpuTensorHandle.hpp>
18#include <backendsCommon/WorkloadFactory.hpp>
19
Francis Murtagh46c09d02019-05-28 08:15:28 +010020#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
telsoa014fcda012018-03-09 14:13:49 +000022#include <boost/iterator/transform_iterator.hpp>
23
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000025#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000026
telsoa014fcda012018-03-09 14:13:49 +000027namespace armnn
28{
29
telsoa01c577f2c2018-08-31 09:22:23 +010030namespace
31{
telsoa01c577f2c2018-08-31 09:22:23 +010032
David Beck29c75de2018-10-23 13:35:58 +010033const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
34{
35 if (!type)
36 {
37 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010038 }
39
David Beck29c75de2018-10-23 13:35:58 +010040 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010041}
42
David Beck29c75de2018-10-23 13:35:58 +010043} // anonymous namespace
44
David Beck33f0ae02018-10-18 15:13:56 +010045bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010046 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010047 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010048 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000049{
David Beck33f0ae02018-10-18 15:13:56 +010050 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000051 bool result;
Jan Eilersbb446e52020-04-02 13:56:54 +010052 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
David Beckdcb751f2018-10-03 11:42:42 +010053
David Beck111b5d92018-11-12 14:59:37 +000054 auto const& backendRegistry = BackendRegistryInstance();
55 if (!backendRegistry.IsBackendRegistered(backendId))
56 {
57 std::stringstream ss;
58 ss << connectableLayer.GetName() << " is not supported on " << backendId
59 << " because this backend is not registered.";
60
61 outReasonIfUnsupported = ss.str();
62 return false;
63 }
64
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
67 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010068
telsoa014fcda012018-03-09 14:13:49 +000069 switch(layer.GetType())
70 {
71 case LayerType::Activation:
72 {
Jan Eilersbb446e52020-04-02 13:56:54 +010073 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +000074 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010075 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010076 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010077 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010080 reason);
telsoa014fcda012018-03-09 14:13:49 +000081 break;
82 }
83 case LayerType::Addition:
84 {
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010088 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010089 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010092 reason);
telsoa014fcda012018-03-09 14:13:49 +000093 break;
94 }
Nikhil Rajee391d52019-09-05 17:50:44 +010095 case LayerType::ArgMinMax:
96 {
Jan Eilersbb446e52020-04-02 13:56:54 +010097 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
Nikhil Rajee391d52019-09-05 17:50:44 +010098 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
99
100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
102 result = layerSupportObject->IsArgMinMaxSupported(
103 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000104 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100105 descriptor,
106 reason);
107 break;
108 }
telsoa014fcda012018-03-09 14:13:49 +0000109 case LayerType::BatchNormalization:
110 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100111 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
114 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
115 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
116 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
117 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100118 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100119 OverrideDataType(input, dataType),
120 OverrideDataType(output, dataType),
121 OverrideDataType(mean, dataType),
122 OverrideDataType(var, dataType),
123 OverrideDataType(beta, dataType),
124 OverrideDataType(gamma, dataType),
125 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100126 reason);
telsoa014fcda012018-03-09 14:13:49 +0000127 break;
128 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000129 case LayerType::BatchToSpaceNd:
130 {
131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Jan Eilersbb446e52020-04-02 13:56:54 +0100133 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000134
135 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
136 OverrideDataType(output, dataType),
137 cLayer->GetParameters(),
138 reason);
139 break;
140 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100141 case LayerType::Comparison:
142 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100143 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100144
145 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
146 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
148
149 result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
150 OverrideDataType(input1, dataType),
151 OverrideDataType(output, DataType::Boolean),
152 cLayer->GetParameters(),
153 reason);
154 break;
155 }
telsoa014fcda012018-03-09 14:13:49 +0000156 case LayerType::Constant:
157 {
158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100159 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100160 break;
161 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000162 case LayerType::ConvertBf16ToFp32:
163 {
164 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
165 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
166 result = layerSupportObject->IsConvertBf16ToFp32Supported(input, output, reason);
167 break;
168 }
telsoa01c577f2c2018-08-31 09:22:23 +0100169 case LayerType::ConvertFp16ToFp32:
170 {
171 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100173 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100174 break;
175 }
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000176 case LayerType::ConvertFp32ToBf16:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
180 result = layerSupportObject->IsConvertFp32ToBf16Supported(input, output, reason);
181 break;
182 }
telsoa01c577f2c2018-08-31 09:22:23 +0100183 case LayerType::ConvertFp32ToFp16:
184 {
185 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
186 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100187 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000188 break;
189 }
190 case LayerType::Convolution2d:
191 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100192 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100193
194 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
195 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100196 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100197 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
surmeh013537c2c2018-05-18 16:31:43 +0100198
arovir01a6824102018-08-28 17:40:45 +0100199 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100200
arovir01a6824102018-08-28 17:40:45 +0100201 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100202 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100203 if (descriptor.m_BiasEnabled)
204 {
David Beck5eec11d2018-10-04 15:43:17 +0100205 biases =
206 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100207 }
208
David Beck33f0ae02018-10-18 15:13:56 +0100209 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100210 input,
211 output,
212 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100213 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100214 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100215 reason);
telsoa014fcda012018-03-09 14:13:49 +0000216 break;
217 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000218 case LayerType::Debug:
219 {
220 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
221 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
222
223 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
224 OverrideDataType(output, dataType),
225 reason);
226 break;
227 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100228 case LayerType::DepthToSpace:
229 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100230 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100231
232 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
233 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
234
235 result = layerSupportObject->IsDepthToSpaceSupported(OverrideDataType(input, dataType),
236 OverrideDataType(output, dataType),
237 cLayer->GetParameters(),
238 reason);
239 break;
240 }
telsoa014fcda012018-03-09 14:13:49 +0000241 case LayerType::DepthwiseConvolution2d:
242 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100243 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100244 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
245 dataType);
246 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100247 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100248
telsoa01c577f2c2018-08-31 09:22:23 +0100249 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100250
251 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100252 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100253 if (descriptor.m_BiasEnabled)
254 {
David Beck5eec11d2018-10-04 15:43:17 +0100255 biases =
256 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100257 }
telsoa01c577f2c2018-08-31 09:22:23 +0100258
David Beck33f0ae02018-10-18 15:13:56 +0100259 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100260 input,
261 output,
262 descriptor,
263 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100264 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100265 reason);
telsoa014fcda012018-03-09 14:13:49 +0000266 break;
267 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000268 case LayerType::Dequantize:
269 {
270 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
271 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
272
Aron Virginas-Tar87972be2019-11-13 15:16:28 +0000273 result = layerSupportObject->IsDequantizeSupported(input,
274 OverrideDataType(output, dataType),
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000275 reason);
276 break;
277 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000278 case LayerType::DetectionPostProcess:
279 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100280 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000281 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
282 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
283 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
284
285 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
286 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
287 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
288 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
289
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000290 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000291 result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings,
292 scores,
293 anchors,
294 detectionBoxes,
295 detectionClasses,
296 detectionScores,
297 numDetections,
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000298 descriptor,
299 reason);
300 break;
301 }
josh minor4a3c6102020-01-06 16:40:46 -0600302 case LayerType::ElementwiseUnary:
303 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100304 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
josh minor4a3c6102020-01-06 16:40:46 -0600305
306 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
307 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
308
309 result = layerSupportObject->IsElementwiseUnarySupported(OverrideDataType(input, dataType),
310 OverrideDataType(output, dataType),
311 cLayer->GetParameters(),
312 reason);
313 break;
314 }
Ryan OSheaec6c6802020-06-05 17:17:06 +0100315 case LayerType::Fill:
316 {
317 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
318 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
319 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
320 const FillDescriptor& descriptor = cLayer->GetParameters();
321
322 result = layerSupportObject->IsFillSupported(
323 OverrideDataType(input, dataType),
324 OverrideDataType(output, dataType),
325 descriptor,
326 reason);
327 break;
328 }
telsoa014fcda012018-03-09 14:13:49 +0000329 case LayerType::FakeQuantization:
330 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100331 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000332 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100333 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
334 cLayer->GetParameters(),
335 reason);
telsoa014fcda012018-03-09 14:13:49 +0000336 break;
337 }
338 case LayerType::Floor:
339 {
340 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
341 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100342 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
343 OverrideDataType(output, dataType),
344 reason);
telsoa014fcda012018-03-09 14:13:49 +0000345 break;
346 }
347 case LayerType::FullyConnected:
348 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100349 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000350 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100351 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100352 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100353
354 TensorInfo biasInfo;
355 const TensorInfo * biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000356 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100357 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
358 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
359 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
360
361 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
362 if (descriptor.m_BiasEnabled)
363 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100364 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100365 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
366 biasInfoPtr = &biasInfo;
367 }
368 else
369 {
370 // If biases are not enabled pass a dummy tensorinfo for the validation
371 switch(input.GetDataType())
372 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000373 case DataType::BFloat16:
374 {
375 biasInfoPtr = &dummyBFloat16Bias;
376 break;
377 }
telsoa01c577f2c2018-08-31 09:22:23 +0100378 case DataType::Float16:
379 {
380 biasInfoPtr = &dummyFloat16Bias;
381 break;
382 }
383 case DataType::Float32:
384 {
385 biasInfoPtr = &dummyFloat32Bias;
386 break;
387 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000388 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000389 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000390 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000391 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100392 {
393 biasInfoPtr = &dummyQA8Bias;
394 break;
395 }
396 default:
397 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100398 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
telsoa01c577f2c2018-08-31 09:22:23 +0100399 }
400 }
401 }
402
David Beck33f0ae02018-10-18 15:13:56 +0100403 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100404 OverrideDataType(input, dataType),
405 OverrideDataType(output, dataType),
406 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
407 *biasInfoPtr,
408 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100409 reason);
telsoa014fcda012018-03-09 14:13:49 +0000410 break;
411 }
narpra01b89b05f2019-01-16 09:53:09 +0000412 case LayerType::Gather:
413 {
414 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
415 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
416 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
417 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100418 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000419 OverrideDataType(output, dataType),
420 reason);
421 break;
422 }
telsoa014fcda012018-03-09 14:13:49 +0000423 case LayerType::Input:
424 {
425 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100426 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000427 break;
428 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100429 case LayerType::InstanceNormalization:
430 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100431 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
Kevin Mayce5045a2019-10-02 14:07:47 +0100432 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
433
434 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
435 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
436
437 result = layerSupportObject->IsInstanceNormalizationSupported(
438 OverrideDataType(input, dataType),
439 OverrideDataType(output, dataType),
440 descriptor,
441 reason);
442 break;
443 }
telsoa014fcda012018-03-09 14:13:49 +0000444 case LayerType::L2Normalization:
445 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100446 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100447 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
448
telsoa014fcda012018-03-09 14:13:49 +0000449 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100450 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100451
David Beck33f0ae02018-10-18 15:13:56 +0100452 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100453 OverrideDataType(input, dataType),
454 OverrideDataType(output, dataType),
455 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100456 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100457 break;
458 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100459 case LayerType::LogSoftmax:
460 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100461 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100462
463 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
464 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
465
466 result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType),
467 OverrideDataType(output, dataType),
468 cLayer->GetParameters(),
469 reason);
470 break;
471 }
telsoa01c577f2c2018-08-31 09:22:23 +0100472 case LayerType::Lstm:
473 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100474 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100475 const LstmDescriptor& descriptor = cLayer->GetParameters();
476
477 // All inputs.
478 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
479 dataType);
480 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
481 dataType);
482 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
483 dataType);
484 // All outputs
485 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
486 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
487 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
488 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
489
490 // Basic parameters
491 const TensorInfo& inputToForgetWeights
492 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
493 const TensorInfo& inputToCellWeights
494 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
495 const TensorInfo& inputToOutputWeights
496 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
497 const TensorInfo& recurrentToForgetWeights
498 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
499 const TensorInfo& recurrentToCellWeights
500 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
501 const TensorInfo& recurrentToOutputWeights
502 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
503 const TensorInfo& forgetGateBias
504 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
505 const TensorInfo& cellBias
506 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
507 const TensorInfo& outputGateBias
508 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
509
Jan Eilersd01a83c2019-07-03 18:20:40 +0100510 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100511
Jan Eilersd01a83c2019-07-03 18:20:40 +0100512 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
513 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
514 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
515 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
516 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
517 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
518 paramsInfo.m_ForgetGateBias = &forgetGateBias;
519 paramsInfo.m_CellBias = &cellBias;
520 paramsInfo.m_OutputGateBias = &outputGateBias;
521
522
523 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100524 TensorInfo optInputToInputWeights;
525 TensorInfo optRecurrentToInputWeights;
526 TensorInfo optCellToInputWeights;
527 TensorInfo optInputGateBias;
528 TensorInfo optProjectionWeights;
529 TensorInfo optProjectionBias;
530 TensorInfo optCellToForgetWeights;
531 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100532 TensorInfo optInputLayerNormWeights;
533 TensorInfo optForgetLayerNormWeights;
534 TensorInfo optCellLayerNormWeights;
535 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100536
537 if(!descriptor.m_CifgEnabled)
538 {
539 optInputToInputWeights =
540 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100541 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100542
543 optRecurrentToInputWeights =
544 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100545 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100546 optInputGateBias =
547 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100548 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100549 }
550
551 if(descriptor.m_ProjectionEnabled)
552 {
553 optProjectionWeights =
554 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100555 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100556 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
557 {
558 optProjectionBias =
559 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100560 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100561 }
562 }
563
564 if(descriptor.m_PeepholeEnabled)
565 {
Jan Eilerse2062cd2020-03-30 15:07:45 +0100566 if(!descriptor.m_CifgEnabled)
567 {
568 optCellToInputWeights =
569 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
570 dataType);
571 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
572 }
telsoa01c577f2c2018-08-31 09:22:23 +0100573 optCellToForgetWeights =
574 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100575 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100576 optCellToOutputWeights =
577 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100578 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100579 }
580
Jan Eilers38e05bd2019-06-26 13:10:09 +0100581 if(descriptor.m_LayerNormEnabled)
582 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100583 if (!descriptor.m_CifgEnabled)
584 {
585 optInputLayerNormWeights = OverrideDataType(
586 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
587 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
588 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100589
590 optForgetLayerNormWeights = OverrideDataType(
591 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100592 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100593
594 optCellLayerNormWeights = OverrideDataType(
595 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100596 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100597
598 optOutputLayerNormWeights = OverrideDataType(
599 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100600 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100601 }
602
David Beck33f0ae02018-10-18 15:13:56 +0100603 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100604 input,
605 outputStateIn,
606 cellStateIn,
607 scratchBuffer,
608 outputStateOut,
609 cellStateOut,
610 output,
611 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100612 paramsInfo,
613 reason);
telsoa014fcda012018-03-09 14:13:49 +0000614 break;
615 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000616 case LayerType::Maximum:
617 {
618 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
619 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
620 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
621
622 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
623 OverrideDataType(input1, dataType),
624 OverrideDataType(output, dataType),
625 reason);
626 break;
627 }
narpra01b89b05f2019-01-16 09:53:09 +0000628 case LayerType::MemCopy:
629 {
630 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
631 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000632
narpra01b89b05f2019-01-16 09:53:09 +0000633 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
634 OverrideDataType(output, dataType),
635 reason);
636 break;
637 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100638 case LayerType::MemImport:
639 {
640 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
641 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
642
643 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
644 OverrideDataType(output, dataType),
645 reason);
646 break;
647 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100648 case LayerType::Merge:
649 {
650 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
651 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
652 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
653
654 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
655 OverrideDataType(input1, dataType),
656 OverrideDataType(output, dataType),
657 reason);
658 break;
659 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100660 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000661 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100662 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000663
telsoa01c577f2c2018-08-31 09:22:23 +0100664 // Get vector of all inputs.
665 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000666 {
telsoa01c577f2c2018-08-31 09:22:23 +0100667 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000668 };
telsoa01c577f2c2018-08-31 09:22:23 +0100669 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
670 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
671 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000672
telsoa01c577f2c2018-08-31 09:22:23 +0100673 auto getTensorInfoPtr = [](const TensorInfo& info)
674 {
675 return &info;
676 };
677 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
678 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
679 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000680
Nikhil Raj8599a412018-11-19 14:51:07 +0000681 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
682
Jim Flynne242f2d2019-05-22 14:24:13 +0100683 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
684
685
telsoa014fcda012018-03-09 14:13:49 +0000686 break;
687 }
688 case LayerType::Multiplication:
689 {
690 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
691 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100692 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100693 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100694 OverrideDataType(input0, dataType),
695 OverrideDataType(input1, dataType),
696 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100697 reason);
telsoa014fcda012018-03-09 14:13:49 +0000698 break;
699 }
700 case LayerType::Normalization:
701 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100702 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000703 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
704 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100705 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
706 OverrideDataType(output, dataType),
707 cLayer->GetParameters(),
708 reason);
telsoa014fcda012018-03-09 14:13:49 +0000709 break;
710 }
711 case LayerType::Output:
712 {
713 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100714 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000715 break;
716 }
717 case LayerType::Permute:
718 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100719 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000720 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
721 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100722 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
723 OverrideDataType(output, dataType),
724 cLayer->GetParameters(),
725 reason);
telsoa014fcda012018-03-09 14:13:49 +0000726 break;
727 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100728 case LayerType::Pad:
729 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100730 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100731 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
732 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100733 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100734 OverrideDataType(input, dataType),
735 OverrideDataType(output, dataType),
736 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100737 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100738 break;
739 }
telsoa014fcda012018-03-09 14:13:49 +0000740 case LayerType::Pooling2d:
741 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100742 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000743 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
744 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100745 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
746 OverrideDataType(output, dataType),
747 cLayer->GetParameters(),
748 reason);
telsoa014fcda012018-03-09 14:13:49 +0000749 break;
750 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000751 case LayerType::PreCompiled:
752 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100753 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
Matteo Martincigh49124022019-01-11 13:25:59 +0000754 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
755 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
756 cLayer->GetParameters(),
757 reason);
758 break;
759 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000760 case LayerType::Quantize:
761 {
762 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
763 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
764 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
765 break;
766 }
James Conroy586a9aa2020-03-20 08:49:33 +0000767 case LayerType::QLstm:
768 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100769 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
James Conroy586a9aa2020-03-20 08:49:33 +0000770 const QLstmDescriptor& descriptor = cLayer->GetParameters();
771
772 // Inputs
773 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
774 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
775 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
776
777 // Outputs
778 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
779 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
780 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
781
782 // Lstm parameters
783 LstmInputParamsInfo paramsInfo;
784
785 // Basic parameters
786 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
787 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
788 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
789
790 paramsInfo.m_RecurrentToForgetWeights =
791 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
792 paramsInfo.m_RecurrentToCellWeights =
793 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
794 paramsInfo.m_RecurrentToOutputWeights =
795 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
796
797 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
798 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
799 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
800
801 if(!descriptor.m_CifgEnabled)
802 {
803 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
804 paramsInfo.m_RecurrentToInputWeights =
805 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
806 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
807 }
808
809 if(descriptor.m_ProjectionEnabled)
810 {
811 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
James Conroyed324052020-05-18 15:16:42 +0100812
813 // Projection bias is optional even if projection is enabled
814 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
815 {
816 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
817 }
James Conroy586a9aa2020-03-20 08:49:33 +0000818 }
819
820 if(descriptor.m_PeepholeEnabled)
821 {
822 if (!descriptor.m_CifgEnabled)
823 {
824 paramsInfo.m_CellToInputWeights =
825 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
826 }
827
828 paramsInfo.m_CellToForgetWeights =
829 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
830 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
831 }
832
833 if(descriptor.m_LayerNormEnabled)
834 {
835 if (!descriptor.m_CifgEnabled)
836 {
837 paramsInfo.m_InputLayerNormWeights =
838 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
839 }
840
841 paramsInfo.m_ForgetLayerNormWeights =
842 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
843 paramsInfo.m_CellLayerNormWeights =
844 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
845 paramsInfo.m_OutputLayerNormWeights =
846 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
847 }
848
849 result = layerSupportObject->IsQLstmSupported(input,
850 previousOutputIn,
851 previousCellStateIn,
852 outputStateOut,
853 cellStateOut,
854 output,
855 descriptor,
856 paramsInfo,
857 reason);
858 break;
859 }
James Conroyee18dc82019-07-17 11:27:46 +0100860 case LayerType::QuantizedLstm:
861 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100862 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
James Conroyee18dc82019-07-17 11:27:46 +0100863
864 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100865 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
866 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
867 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100868
869 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100870 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
871 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100872
873 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100874 QuantizedLstmInputParamsInfo paramsInfo;
875
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100876 paramsInfo.m_InputToInputWeights =
877 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
878 paramsInfo.m_InputToForgetWeights =
879 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
880 paramsInfo.m_InputToCellWeights =
881 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
882 paramsInfo.m_InputToOutputWeights =
883 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100884
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100885 paramsInfo.m_RecurrentToInputWeights =
886 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
887 paramsInfo.m_RecurrentToForgetWeights =
888 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
889 paramsInfo.m_RecurrentToCellWeights =
890 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
891 paramsInfo.m_RecurrentToOutputWeights =
892 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100893
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100894 paramsInfo.m_InputGateBias =
895 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
896 paramsInfo.m_ForgetGateBias =
897 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
898 paramsInfo.m_CellBias =
899 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
900 paramsInfo.m_OutputGateBias =
901 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100902
903 result = layerSupportObject->IsQuantizedLstmSupported(input,
904 previousCellStateIn,
905 previousOutputIn,
906 cellStateOut,
907 output,
908 paramsInfo,
909 reason);
910 break;
911 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100912 case LayerType::Division:
913 {
914 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
915 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
916 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100917 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100918 OverrideDataType(input0, dataType),
919 OverrideDataType(input1, dataType),
920 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100921 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100922 break;
923 }
telsoa014fcda012018-03-09 14:13:49 +0000924 case LayerType::Reshape:
925 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100926 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000927 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +0000928 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000929 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
Kevin Maya023c402019-12-12 17:28:05 +0000930 OverrideDataType(output, dataType),
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000931 cLayer->GetParameters(),
932 reason);
telsoa014fcda012018-03-09 14:13:49 +0000933 break;
934 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100935 case LayerType::Resize:
936 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100937 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100938 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100939 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
940 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
941 OverrideDataType(output, dataType),
942 cLayer->GetParameters(),
943 reason);
944 break;
945 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100946 case LayerType::Slice:
947 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100948 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100949
950 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
951 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
952
953 result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
954 OverrideDataType(output, dataType),
955 cLayer->GetParameters(),
956 reason);
957 break;
958 }
telsoa014fcda012018-03-09 14:13:49 +0000959 case LayerType::Softmax:
960 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100961 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000962 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100963 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100964 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
965 OverrideDataType(output, dataType),
966 cLayer->GetParameters(),
967 reason);
telsoa014fcda012018-03-09 14:13:49 +0000968 break;
969 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000970 case LayerType::SpaceToBatchNd:
971 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100972 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000973 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
974 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
975 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
976 OverrideDataType(output, dataType),
977 cLayer->GetParameters(),
978 reason);
979 break;
980 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100981 case LayerType::SpaceToDepth:
982 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100983 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100984
985 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
986 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
987
988 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
989 OverrideDataType(output, dataType),
990 cLayer->GetParameters(),
991 reason);
992 break;
993 }
telsoa014fcda012018-03-09 14:13:49 +0000994 case LayerType::Splitter:
995 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100996 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000997 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100998
999 // Get vector of all outputs.
1000 auto getTensorInfo = [&dataType](const OutputSlot& slot)
1001 {
1002 return OverrideDataType(slot.GetTensorInfo(), dataType);
1003 };
1004 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
1005 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
1006 std::vector<TensorInfo> outputs(beginI, endI);
1007
1008 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1009
David Beck33f0ae02018-10-18 15:13:56 +01001010 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001011 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +01001012 cLayer->GetParameters(),
1013 reason);
telsoa014fcda012018-03-09 14:13:49 +00001014 break;
1015 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001016 case LayerType::Stack:
1017 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001018 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001019
1020 // Get vector of all inputs.
1021 auto getTensorInfo = [&dataType](const InputSlot& slot)
1022 {
1023 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1024 };
1025 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
1026 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
1027 std::vector<TensorInfo> inputs(beginI, endI);
1028
1029 auto getTensorInfoPtr = [](const TensorInfo& info)
1030 {
1031 return &info;
1032 };
1033 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
1034 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
1035 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1036
1037 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1038
1039 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
1040
1041 break;
1042 }
Derek Lamberti013c3902019-10-21 10:46:16 +01001043 case LayerType::StandIn:
1044 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001045 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
Derek Lamberti013c3902019-10-21 10:46:16 +01001046
1047 // Get vector of all inputs.
1048 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1049 {
1050 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1051 };
1052 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1053 {
1054 return OverrideDataType(slot.GetTensorInfo(), dataType);
1055 };
1056 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1057 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfoIn);
1058 std::vector<TensorInfo> inputs(beginI, endI);
1059
1060 auto beginO = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1061 auto endO = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfoOut);
1062 std::vector<TensorInfo> outputs(beginO, endO);
1063
1064
1065 auto getTensorInfoPtr = [](const TensorInfo& info)
1066 {
1067 return &info;
1068 };
1069 auto beginPtrI = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
1070 auto endPtrI = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
1071 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1072
1073 auto beginPtrO = boost::make_transform_iterator(outputs.begin(), getTensorInfoPtr);
1074 auto endPtrO = boost::make_transform_iterator(outputs.end(), getTensorInfoPtr);
1075 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1076
1077
1078 result = layerSupportObject->IsStandInSupported(inputPtrs,
1079 outputPtrs,
1080 cLayer->GetParameters(),
1081 reason);
1082 break;
1083 }
Conor Kennedy430b5d82018-11-14 15:28:28 +00001084 case LayerType::StridedSlice:
1085 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001086 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
Conor Kennedy430b5d82018-11-14 15:28:28 +00001087 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1088 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1089 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
1090 OverrideDataType(output, dataType),
1091 cLayer->GetParameters(),
1092 reason);
1093 break;
1094 }
David Beckc2044fe2018-09-05 15:00:38 +01001095 case LayerType::Subtraction:
1096 {
1097 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1098 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1099 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +01001100 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +01001101 OverrideDataType(input0, dataType),
1102 OverrideDataType(input1, dataType),
1103 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +01001104 reason);
David Beckc2044fe2018-09-05 15:00:38 +01001105 break;
1106 }
Sadik Armaganeff363d2019-04-05 15:25:46 +01001107 case LayerType::Switch:
1108 {
1109 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1110 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1111 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1112 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
1113 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
1114 OverrideDataType(input1, dataType),
1115 OverrideDataType(output0, dataType),
1116 OverrideDataType(output1, dataType),
1117 reason);
1118 break;
1119 }
narpra0132b90462018-09-13 11:07:48 +01001120 case LayerType::Mean:
1121 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001122 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
narpra0132b90462018-09-13 11:07:48 +01001123 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1124 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +01001125 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001126 OverrideDataType(input, dataType),
1127 OverrideDataType(output, dataType),
1128 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001129 reason);
narpra0132b90462018-09-13 11:07:48 +01001130 break;
1131 }
kevmay0190539692018-11-29 08:40:19 +00001132 case LayerType::Minimum:
1133 {
1134 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1135 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1137 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
1138 OverrideDataType(input1, dataType),
1139 OverrideDataType(output, dataType),
1140 reason);
1141 break;
1142 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001143 case LayerType::Prelu:
1144 {
1145 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1146 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1148 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
1149 OverrideDataType(alpha, dataType),
1150 OverrideDataType(output, dataType),
1151 reason);
1152 break;
1153 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001154 case LayerType::Transpose:
1155 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001156 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001157 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1159 result = layerSupportObject->IsTransposeSupported(OverrideDataType(input, dataType),
1160 OverrideDataType(output, dataType),
1161 cLayer->GetParameters(),
1162 reason);
1163 break;
1164 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001165 case LayerType::TransposeConvolution2d:
1166 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001167 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001168
1169 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1170 dataType);
1171 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1172
1173 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1174
1175 Optional<TensorInfo> biases;
1176 if (descriptor.m_BiasEnabled)
1177 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001178 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001179 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1180 GetBiasTypeFromWeightsType(dataType));
1181 }
1182
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001183 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001184 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1185
1186 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
1187 output,
1188 descriptor,
1189 weights,
1190 biases,
1191 reason);
1192
1193 break;
1194 }
telsoa014fcda012018-03-09 14:13:49 +00001195 default:
1196 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001197 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001198 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001199 result = false;
1200 break;
1201 }
1202 }
telsoa014fcda012018-03-09 14:13:49 +00001203 return result;
1204}
1205
David Beckdcb751f2018-10-03 11:42:42 +01001206bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001207 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001208 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001209{
Jan Eilersbb446e52020-04-02 13:56:54 +01001210 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +01001211 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +00001212}
1213
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001214// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001215std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1216 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001217{
1218 return std::unique_ptr<IWorkload>();
1219}
1220
Derek Lamberti901ea112019-12-10 22:07:09 +00001221std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1222 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001223{
1224 return std::unique_ptr<IWorkload>();
1225}
1226
Derek Lamberti901ea112019-12-10 22:07:09 +00001227std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1228 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001229{
1230 return std::unique_ptr<IWorkload>();
1231}
1232
Derek Lamberti901ea112019-12-10 22:07:09 +00001233std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1234 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001235{
1236 return std::unique_ptr<IWorkload>();
1237}
1238
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001239std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001240 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001241{
1242 return std::unique_ptr<IWorkload>();
1243}
1244
Derek Lamberti901ea112019-12-10 22:07:09 +00001245std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1246 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001247{
1248 return std::unique_ptr<IWorkload>();
1249}
1250
Derek Lamberti901ea112019-12-10 22:07:09 +00001251std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1252 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001253{
1254 return std::unique_ptr<IWorkload>();
1255}
1256
Derek Lamberti901ea112019-12-10 22:07:09 +00001257std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1258 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001259{
1260 return std::unique_ptr<IWorkload>();
1261}
1262
Derek Lamberti901ea112019-12-10 22:07:09 +00001263std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1264 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001265{
1266 return std::unique_ptr<IWorkload>();
1267}
1268
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001269std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1270 const WorkloadInfo& /*info*/) const
1271{
1272 return std::unique_ptr<IWorkload>();
1273}
1274
Derek Lamberti901ea112019-12-10 22:07:09 +00001275std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1276 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001277{
1278 return std::unique_ptr<IWorkload>();
1279}
1280
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001281std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1282 const WorkloadInfo& /*info*/) const
1283{
1284 return std::unique_ptr<IWorkload>();
1285}
1286
Derek Lamberti901ea112019-12-10 22:07:09 +00001287std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1288 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001289{
1290 return std::unique_ptr<IWorkload>();
1291}
1292
Derek Lamberti901ea112019-12-10 22:07:09 +00001293std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1294 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001295{
1296 return std::unique_ptr<IWorkload>();
1297}
1298
Derek Lamberti901ea112019-12-10 22:07:09 +00001299std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1300 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001301{
1302 return std::unique_ptr<IWorkload>();
1303}
1304
Derek Lamberti901ea112019-12-10 22:07:09 +00001305std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1306 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001307{
1308 return std::unique_ptr<IWorkload>();
1309}
1310
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001311std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001312 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001313{
1314 return std::unique_ptr<IWorkload>();
1315}
1316
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001317std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001318 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001319{
1320 return std::unique_ptr<IWorkload>();
1321}
1322
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001323std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001324 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001325{
1326 return std::unique_ptr<IWorkload>();
1327}
1328
Derek Lamberti901ea112019-12-10 22:07:09 +00001329std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1330 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001331{
1332 return std::unique_ptr<IWorkload>();
1333}
1334
josh minor4a3c6102020-01-06 16:40:46 -06001335std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1336 const WorkloadInfo& /*info*/) const
1337{
1338 return std::unique_ptr<IWorkload>();
1339}
1340
Derek Lamberti901ea112019-12-10 22:07:09 +00001341std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1342 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001343{
1344 return std::unique_ptr<IWorkload>();
1345}
1346
Derek Lamberti901ea112019-12-10 22:07:09 +00001347std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1348 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001349{
1350 return std::unique_ptr<IWorkload>();
1351}
1352
Ryan OSheaec6c6802020-06-05 17:17:06 +01001353std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
1354 const WorkloadInfo& /*info*/) const
1355{
1356 return std::unique_ptr<IWorkload>();
1357}
1358
Derek Lamberti901ea112019-12-10 22:07:09 +00001359std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1360 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001361{
1362 return std::unique_ptr<IWorkload>();
1363}
1364
Derek Lamberti901ea112019-12-10 22:07:09 +00001365std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1366 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001367{
1368 return std::unique_ptr<IWorkload>();
1369}
1370
Derek Lamberti901ea112019-12-10 22:07:09 +00001371std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1372 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001373{
1374 return std::unique_ptr<IWorkload>();
1375}
1376
Derek Lamberti901ea112019-12-10 22:07:09 +00001377std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1378 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001379{
1380 return std::unique_ptr<IWorkload>();
1381}
1382
Kevin Mayce5045a2019-10-02 14:07:47 +01001383std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001384 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1385 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001386{
1387 return std::unique_ptr<IWorkload>();
1388}
1389
Derek Lamberti901ea112019-12-10 22:07:09 +00001390std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1391 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001392{
1393 return std::unique_ptr<IWorkload>();
1394}
1395
Derek Lamberti901ea112019-12-10 22:07:09 +00001396std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1397 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001398{
1399 return std::unique_ptr<IWorkload>();
1400}
1401
Derek Lamberti901ea112019-12-10 22:07:09 +00001402std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1403 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001404{
1405 return std::unique_ptr<IWorkload>();
1406}
1407
Derek Lamberti901ea112019-12-10 22:07:09 +00001408std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1409 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001410{
1411 return std::unique_ptr<IWorkload>();
1412}
1413
Derek Lamberti901ea112019-12-10 22:07:09 +00001414std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1415 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001416{
1417 return std::unique_ptr<IWorkload>();
1418}
1419
Derek Lamberti901ea112019-12-10 22:07:09 +00001420std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1421 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001422{
1423 return std::unique_ptr<IWorkload>();
1424}
1425
Derek Lamberti901ea112019-12-10 22:07:09 +00001426std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1427 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001428{
1429 return std::unique_ptr<IWorkload>();
1430}
1431
Derek Lamberti901ea112019-12-10 22:07:09 +00001432std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1433 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001434{
1435 return std::unique_ptr<IWorkload>();
1436}
1437
Derek Lamberti901ea112019-12-10 22:07:09 +00001438std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1439 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001440{
1441 return std::unique_ptr<IWorkload>();
1442}
1443
Derek Lamberti901ea112019-12-10 22:07:09 +00001444std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1445 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001446{
1447 return std::unique_ptr<IWorkload>();
1448}
1449
Derek Lamberti901ea112019-12-10 22:07:09 +00001450std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1451 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001452{
1453 return std::unique_ptr<IWorkload>();
1454}
1455
Derek Lamberti901ea112019-12-10 22:07:09 +00001456std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1457 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001458{
1459 return std::unique_ptr<IWorkload>();
1460}
1461
Derek Lamberti901ea112019-12-10 22:07:09 +00001462std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1463 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001464{
1465 return std::unique_ptr<IWorkload>();
1466}
1467
Derek Lamberti901ea112019-12-10 22:07:09 +00001468std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1469 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001470{
1471 return std::unique_ptr<IWorkload>();
1472}
1473
Derek Lamberti901ea112019-12-10 22:07:09 +00001474std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001475 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001476{
1477 return std::unique_ptr<IWorkload>();
1478}
1479
Derek Lamberti901ea112019-12-10 22:07:09 +00001480std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1481 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001482{
1483 return std::unique_ptr<IWorkload>();
1484}
1485
Derek Lamberti901ea112019-12-10 22:07:09 +00001486std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1487 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001488{
1489 return std::unique_ptr<IWorkload>();
1490}
1491
Derek Lamberti901ea112019-12-10 22:07:09 +00001492std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1493 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001494{
1495 return std::unique_ptr<IWorkload>();
1496}
1497
Derek Lamberti901ea112019-12-10 22:07:09 +00001498std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1499 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001500{
1501 return std::unique_ptr<IWorkload>();
1502}
1503
James Conroy586a9aa2020-03-20 08:49:33 +00001504std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1505 const WorkloadInfo& /*info*/) const
1506{
1507 return std::unique_ptr<IWorkload>();
1508}
1509
Derek Lamberti901ea112019-12-10 22:07:09 +00001510std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1511 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001512{
1513 return std::unique_ptr<IWorkload>();
1514}
1515
Derek Lamberti901ea112019-12-10 22:07:09 +00001516std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1517 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001518{
1519 return std::unique_ptr<IWorkload>();
1520}
1521
Derek Lamberti901ea112019-12-10 22:07:09 +00001522std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1523 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001524{
1525 return std::unique_ptr<IWorkload>();
1526}
1527
Derek Lamberti901ea112019-12-10 22:07:09 +00001528std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1529 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001530{
1531 return std::unique_ptr<IWorkload>();
1532}
1533
Derek Lamberti901ea112019-12-10 22:07:09 +00001534std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1535 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001536{
1537 return std::unique_ptr<IWorkload>();
1538}
1539
Derek Lamberti901ea112019-12-10 22:07:09 +00001540std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1541 const WorkloadInfo& /*info*/) const
1542{
1543 return std::unique_ptr<IWorkload>();
1544}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001545
Derek Lamberti901ea112019-12-10 22:07:09 +00001546std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1547 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001548{
1549 return std::unique_ptr<IWorkload>();
1550}
1551
Derek Lamberti901ea112019-12-10 22:07:09 +00001552std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1553 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001554{
1555 return std::unique_ptr<IWorkload>();
1556}
1557
Derek Lamberti901ea112019-12-10 22:07:09 +00001558std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1559 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001560{
1561 return std::unique_ptr<IWorkload>();
1562}
1563
Derek Lamberti901ea112019-12-10 22:07:09 +00001564std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1565 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001566{
1567 return std::unique_ptr<IWorkload>();
1568}
1569
Derek Lamberti901ea112019-12-10 22:07:09 +00001570std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1571 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001572{
1573 return std::unique_ptr<IWorkload>();
1574}
1575
Derek Lamberti901ea112019-12-10 22:07:09 +00001576std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1577 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001578{
1579 return std::unique_ptr<IWorkload>();
1580}
1581
Derek Lamberti901ea112019-12-10 22:07:09 +00001582std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1583 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001584{
1585 return std::unique_ptr<IWorkload>();
1586}
1587
Derek Lamberti901ea112019-12-10 22:07:09 +00001588std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1589 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001590{
1591 return std::unique_ptr<IWorkload>();
1592}
1593
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001594std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1595 const WorkloadInfo& /*info*/) const
1596{
1597 return std::unique_ptr<IWorkload>();
1598}
1599
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001600std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001601 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1602 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001603{
1604 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001605}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001606
} // namespace armnn