blob: d932eef49fbd52bde26a052ade3bfec4ea5ed663 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <Layer.hpp>
7#include <LayersFwd.hpp>
David Beckdcb751f2018-10-03 11:42:42 +01008
David Beckb4540be2018-09-24 13:18:27 +01009#include <armnn/Types.hpp>
10#include <armnn/LayerSupport.hpp>
David Beck111b5d92018-11-12 14:59:37 +000011#include <armnn/ILayerSupport.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000012#include <armnn/BackendRegistry.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <backendsCommon/WorkloadFactory.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000015#include <armnn/backends/IBackendInternal.hpp>
16#include <backendsCommon/CpuTensorHandle.hpp>
17#include <backendsCommon/WorkloadFactory.hpp>
18
Francis Murtagh46c09d02019-05-28 08:15:28 +010019#include <backendsCommon/test/WorkloadTestUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
21#include <boost/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000022#include <boost/iterator/transform_iterator.hpp>
23
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000024#include <cstring>
David Beck111b5d92018-11-12 14:59:37 +000025#include <sstream>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000026
telsoa014fcda012018-03-09 14:13:49 +000027namespace armnn
28{
29
telsoa01c577f2c2018-08-31 09:22:23 +010030namespace
31{
telsoa01c577f2c2018-08-31 09:22:23 +010032
David Beck29c75de2018-10-23 13:35:58 +010033const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
34{
35 if (!type)
36 {
37 return info;
telsoa01c577f2c2018-08-31 09:22:23 +010038 }
39
David Beck29c75de2018-10-23 13:35:58 +010040 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
telsoa01c577f2c2018-08-31 09:22:23 +010041}
42
David Beck29c75de2018-10-23 13:35:58 +010043} // anonymous namespace
44
David Beck33f0ae02018-10-18 15:13:56 +010045bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
David Beckdcb751f2018-10-03 11:42:42 +010046 const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +010047 Optional<DataType> dataType,
David Beckdcb751f2018-10-03 11:42:42 +010048 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +000049{
David Beck33f0ae02018-10-18 15:13:56 +010050 Optional<std::string&> reason = outReasonIfUnsupported;
telsoa014fcda012018-03-09 14:13:49 +000051 bool result;
David Beckdcb751f2018-10-03 11:42:42 +010052 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
53
David Beck111b5d92018-11-12 14:59:37 +000054 auto const& backendRegistry = BackendRegistryInstance();
55 if (!backendRegistry.IsBackendRegistered(backendId))
56 {
57 std::stringstream ss;
58 ss << connectableLayer.GetName() << " is not supported on " << backendId
59 << " because this backend is not registered.";
60
61 outReasonIfUnsupported = ss.str();
62 return false;
63 }
64
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
67 auto layerSupportObject = backendObject->GetLayerSupport();
David Beck33f0ae02018-10-18 15:13:56 +010068
telsoa014fcda012018-03-09 14:13:49 +000069 switch(layer.GetType())
70 {
71 case LayerType::Activation:
72 {
73 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
74 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010075 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010076 result = layerSupportObject->IsActivationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010077 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +010080 reason);
telsoa014fcda012018-03-09 14:13:49 +000081 break;
82 }
83 case LayerType::Addition:
84 {
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +010088 result = layerSupportObject->IsAdditionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +010089 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +010092 reason);
telsoa014fcda012018-03-09 14:13:49 +000093 break;
94 }
Nikhil Rajee391d52019-09-05 17:50:44 +010095 case LayerType::ArgMinMax:
96 {
97 auto cLayer = boost::polymorphic_downcast<const ArgMinMaxLayer*>(&layer);
98 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
99
100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
102 result = layerSupportObject->IsArgMinMaxSupported(
103 OverrideDataType(input, dataType),
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000104 OverrideDataType(output, DataType::Signed32),
Nikhil Rajee391d52019-09-05 17:50:44 +0100105 descriptor,
106 reason);
107 break;
108 }
telsoa014fcda012018-03-09 14:13:49 +0000109 case LayerType::BatchNormalization:
110 {
111 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
114 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
115 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
116 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
117 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100118 result = layerSupportObject->IsBatchNormalizationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100119 OverrideDataType(input, dataType),
120 OverrideDataType(output, dataType),
121 OverrideDataType(mean, dataType),
122 OverrideDataType(var, dataType),
123 OverrideDataType(beta, dataType),
124 OverrideDataType(gamma, dataType),
125 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100126 reason);
telsoa014fcda012018-03-09 14:13:49 +0000127 break;
128 }
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000129 case LayerType::BatchToSpaceNd:
130 {
131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
133 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
134
135 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
136 OverrideDataType(output, dataType),
137 cLayer->GetParameters(),
138 reason);
139 break;
140 }
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100141 case LayerType::Comparison:
142 {
143 auto cLayer = boost::polymorphic_downcast<const ComparisonLayer*>(&layer);
144
145 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
146 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
148
149 result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
150 OverrideDataType(input1, dataType),
151 OverrideDataType(output, DataType::Boolean),
152 cLayer->GetParameters(),
153 reason);
154 break;
155 }
telsoa014fcda012018-03-09 14:13:49 +0000156 case LayerType::Constant:
157 {
158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100159 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100160 break;
161 }
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000162 case LayerType::ConvertBf16ToFp32:
163 {
164 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
165 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
166 result = layerSupportObject->IsConvertBf16ToFp32Supported(input, output, reason);
167 break;
168 }
telsoa01c577f2c2018-08-31 09:22:23 +0100169 case LayerType::ConvertFp16ToFp32:
170 {
171 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100173 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100174 break;
175 }
176 case LayerType::ConvertFp32ToFp16:
177 {
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100180 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
telsoa014fcda012018-03-09 14:13:49 +0000181 break;
182 }
183 case LayerType::Convolution2d:
184 {
185 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
arovir01a6824102018-08-28 17:40:45 +0100186
187 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
188 dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
surmeh013537c2c2018-05-18 16:31:43 +0100190 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
191
arovir01a6824102018-08-28 17:40:45 +0100192 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
surmeh013537c2c2018-05-18 16:31:43 +0100193
arovir01a6824102018-08-28 17:40:45 +0100194 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100195 Optional<TensorInfo> biases;
surmeh013537c2c2018-05-18 16:31:43 +0100196 if (descriptor.m_BiasEnabled)
197 {
David Beck5eec11d2018-10-04 15:43:17 +0100198 biases =
199 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
surmeh013537c2c2018-05-18 16:31:43 +0100200 }
201
David Beck33f0ae02018-10-18 15:13:56 +0100202 result = layerSupportObject->IsConvolution2dSupported(
surmeh013537c2c2018-05-18 16:31:43 +0100203 input,
204 output,
205 descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +0100206 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100207 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100208 reason);
telsoa014fcda012018-03-09 14:13:49 +0000209 break;
210 }
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +0000211 case LayerType::Debug:
212 {
213 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
214 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
215
216 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
217 OverrideDataType(output, dataType),
218 reason);
219 break;
220 }
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +0100221 case LayerType::DepthToSpace:
222 {
223 auto cLayer = boost::polymorphic_downcast<const DepthToSpaceLayer*>(&layer);
224
225 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
226 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
227
228 result = layerSupportObject->IsDepthToSpaceSupported(OverrideDataType(input, dataType),
229 OverrideDataType(output, dataType),
230 cLayer->GetParameters(),
231 reason);
232 break;
233 }
telsoa014fcda012018-03-09 14:13:49 +0000234 case LayerType::DepthwiseConvolution2d:
235 {
236 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100237 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
238 dataType);
239 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
240 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
241
telsoa01c577f2c2018-08-31 09:22:23 +0100242 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
arovir01a6824102018-08-28 17:40:45 +0100243
244 // Construct optional biases object based on the value of m_BiasEnabled
David Beck5eec11d2018-10-04 15:43:17 +0100245 Optional<TensorInfo> biases;
telsoa01c577f2c2018-08-31 09:22:23 +0100246 if (descriptor.m_BiasEnabled)
247 {
David Beck5eec11d2018-10-04 15:43:17 +0100248 biases =
249 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
telsoa01c577f2c2018-08-31 09:22:23 +0100250 }
telsoa01c577f2c2018-08-31 09:22:23 +0100251
David Beck33f0ae02018-10-18 15:13:56 +0100252 result = layerSupportObject->IsDepthwiseConvolutionSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100253 input,
254 output,
255 descriptor,
256 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
arovir01a6824102018-08-28 17:40:45 +0100257 biases,
David Beck33f0ae02018-10-18 15:13:56 +0100258 reason);
telsoa014fcda012018-03-09 14:13:49 +0000259 break;
260 }
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000261 case LayerType::Dequantize:
262 {
263 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
264 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
265
Aron Virginas-Tar87972be2019-11-13 15:16:28 +0000266 result = layerSupportObject->IsDequantizeSupported(input,
267 OverrideDataType(output, dataType),
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +0000268 reason);
269 break;
270 }
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000271 case LayerType::DetectionPostProcess:
272 {
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000273 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000274 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
275 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
276 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
277
278 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
279 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
280 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
281 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
282
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000283 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000284 result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings,
285 scores,
286 anchors,
287 detectionBoxes,
288 detectionClasses,
289 detectionScores,
290 numDetections,
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000291 descriptor,
292 reason);
293 break;
294 }
josh minor4a3c6102020-01-06 16:40:46 -0600295 case LayerType::ElementwiseUnary:
296 {
297 auto cLayer = boost::polymorphic_downcast<const ElementwiseUnaryLayer*>(&layer);
298
299 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
300 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
301
302 result = layerSupportObject->IsElementwiseUnarySupported(OverrideDataType(input, dataType),
303 OverrideDataType(output, dataType),
304 cLayer->GetParameters(),
305 reason);
306 break;
307 }
telsoa014fcda012018-03-09 14:13:49 +0000308 case LayerType::FakeQuantization:
309 {
310 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
311 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100312 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
313 cLayer->GetParameters(),
314 reason);
telsoa014fcda012018-03-09 14:13:49 +0000315 break;
316 }
317 case LayerType::Floor:
318 {
319 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
320 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100321 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
322 OverrideDataType(output, dataType),
323 reason);
telsoa014fcda012018-03-09 14:13:49 +0000324 break;
325 }
326 case LayerType::FullyConnected:
327 {
328 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
329 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100330 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
331 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
332
333 TensorInfo biasInfo;
334 const TensorInfo * biasInfoPtr = nullptr;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000335 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
telsoa01c577f2c2018-08-31 09:22:23 +0100336 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
337 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
338 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
339
340 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
341 if (descriptor.m_BiasEnabled)
342 {
343 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
344 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
345 biasInfoPtr = &biasInfo;
346 }
347 else
348 {
349 // If biases are not enabled pass a dummy tensorinfo for the validation
350 switch(input.GetDataType())
351 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000352 case DataType::BFloat16:
353 {
354 biasInfoPtr = &dummyBFloat16Bias;
355 break;
356 }
telsoa01c577f2c2018-08-31 09:22:23 +0100357 case DataType::Float16:
358 {
359 biasInfoPtr = &dummyFloat16Bias;
360 break;
361 }
362 case DataType::Float32:
363 {
364 biasInfoPtr = &dummyFloat32Bias;
365 break;
366 }
Derek Lambertif90c56d2020-01-10 17:14:08 +0000367 case DataType::QAsymmU8:
Keith Davisa8565012020-02-14 12:22:40 +0000368 case DataType::QAsymmS8:
Keith Davis9d0ff742020-02-03 14:47:54 +0000369 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +0000370 case DataType::QSymmS16:
telsoa01c577f2c2018-08-31 09:22:23 +0100371 {
372 biasInfoPtr = &dummyQA8Bias;
373 break;
374 }
375 default:
376 {
377 BOOST_ASSERT_MSG(false, "Unexpected bias type");
378 }
379 }
380 }
381
David Beck33f0ae02018-10-18 15:13:56 +0100382 result = layerSupportObject->IsFullyConnectedSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100383 OverrideDataType(input, dataType),
384 OverrideDataType(output, dataType),
385 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
386 *biasInfoPtr,
387 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100388 reason);
telsoa014fcda012018-03-09 14:13:49 +0000389 break;
390 }
narpra01b89b05f2019-01-16 09:53:09 +0000391 case LayerType::Gather:
392 {
393 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
394 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
395 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
396 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100397 input1,
narpra01b89b05f2019-01-16 09:53:09 +0000398 OverrideDataType(output, dataType),
399 reason);
400 break;
401 }
telsoa014fcda012018-03-09 14:13:49 +0000402 case LayerType::Input:
403 {
404 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100405 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000406 break;
407 }
Kevin Mayce5045a2019-10-02 14:07:47 +0100408 case LayerType::InstanceNormalization:
409 {
410 auto cLayer = boost::polymorphic_downcast<const InstanceNormalizationLayer*>(&layer);
411 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
412
413 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
414 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
415
416 result = layerSupportObject->IsInstanceNormalizationSupported(
417 OverrideDataType(input, dataType),
418 OverrideDataType(output, dataType),
419 descriptor,
420 reason);
421 break;
422 }
telsoa014fcda012018-03-09 14:13:49 +0000423 case LayerType::L2Normalization:
424 {
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100425 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
426 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
427
telsoa014fcda012018-03-09 14:13:49 +0000428 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100429 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100430
David Beck33f0ae02018-10-18 15:13:56 +0100431 result = layerSupportObject->IsL2NormalizationSupported(
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100432 OverrideDataType(input, dataType),
433 OverrideDataType(output, dataType),
434 descriptor,
David Beck33f0ae02018-10-18 15:13:56 +0100435 reason);
telsoa01c577f2c2018-08-31 09:22:23 +0100436 break;
437 }
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +0100438 case LayerType::LogSoftmax:
439 {
440 auto cLayer = boost::polymorphic_downcast<const LogSoftmaxLayer*>(&layer);
441
442 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
443 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
444
445 result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType),
446 OverrideDataType(output, dataType),
447 cLayer->GetParameters(),
448 reason);
449 break;
450 }
telsoa01c577f2c2018-08-31 09:22:23 +0100451 case LayerType::Lstm:
452 {
453 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
454 const LstmDescriptor& descriptor = cLayer->GetParameters();
455
456 // All inputs.
457 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
458 dataType);
459 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
460 dataType);
461 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
462 dataType);
463 // All outputs
464 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
465 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
466 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
467 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
468
469 // Basic parameters
470 const TensorInfo& inputToForgetWeights
471 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
472 const TensorInfo& inputToCellWeights
473 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
474 const TensorInfo& inputToOutputWeights
475 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
476 const TensorInfo& recurrentToForgetWeights
477 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
478 const TensorInfo& recurrentToCellWeights
479 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
480 const TensorInfo& recurrentToOutputWeights
481 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
482 const TensorInfo& forgetGateBias
483 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
484 const TensorInfo& cellBias
485 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
486 const TensorInfo& outputGateBias
487 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
488
Jan Eilersd01a83c2019-07-03 18:20:40 +0100489 LstmInputParamsInfo paramsInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100490
Jan Eilersd01a83c2019-07-03 18:20:40 +0100491 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
492 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
493 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
494 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
495 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
496 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
497 paramsInfo.m_ForgetGateBias = &forgetGateBias;
498 paramsInfo.m_CellBias = &cellBias;
499 paramsInfo.m_OutputGateBias = &outputGateBias;
500
501
502 // Optional parameters
telsoa01c577f2c2018-08-31 09:22:23 +0100503 TensorInfo optInputToInputWeights;
504 TensorInfo optRecurrentToInputWeights;
505 TensorInfo optCellToInputWeights;
506 TensorInfo optInputGateBias;
507 TensorInfo optProjectionWeights;
508 TensorInfo optProjectionBias;
509 TensorInfo optCellToForgetWeights;
510 TensorInfo optCellToOutputWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100511 TensorInfo optInputLayerNormWeights;
512 TensorInfo optForgetLayerNormWeights;
513 TensorInfo optCellLayerNormWeights;
514 TensorInfo optOutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100515
516 if(!descriptor.m_CifgEnabled)
517 {
518 optInputToInputWeights =
519 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100520 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100521
522 optRecurrentToInputWeights =
523 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100524 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100525 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
526 {
527 optCellToInputWeights =
528 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100529 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100530 }
531 optInputGateBias =
532 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100533 paramsInfo.m_InputGateBias = &optInputGateBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100534 }
535
536 if(descriptor.m_ProjectionEnabled)
537 {
538 optProjectionWeights =
539 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100540 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100541 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
542 {
543 optProjectionBias =
544 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100545 paramsInfo.m_ProjectionBias = &optProjectionBias;
telsoa01c577f2c2018-08-31 09:22:23 +0100546 }
547 }
548
549 if(descriptor.m_PeepholeEnabled)
550 {
551 optCellToForgetWeights =
552 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100553 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100554 optCellToOutputWeights =
555 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100556 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
telsoa01c577f2c2018-08-31 09:22:23 +0100557 }
558
Jan Eilers38e05bd2019-06-26 13:10:09 +0100559 if(descriptor.m_LayerNormEnabled)
560 {
Ferran Balaguere30c16e2019-07-24 17:03:45 +0100561 if (!descriptor.m_CifgEnabled)
562 {
563 optInputLayerNormWeights = OverrideDataType(
564 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
565 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
566 }
Jan Eilers38e05bd2019-06-26 13:10:09 +0100567
568 optForgetLayerNormWeights = OverrideDataType(
569 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100570 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100571
572 optCellLayerNormWeights = OverrideDataType(
573 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100574 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100575
576 optOutputLayerNormWeights = OverrideDataType(
577 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100578 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
Jan Eilers38e05bd2019-06-26 13:10:09 +0100579 }
580
David Beck33f0ae02018-10-18 15:13:56 +0100581 result = layerSupportObject->IsLstmSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100582 input,
583 outputStateIn,
584 cellStateIn,
585 scratchBuffer,
586 outputStateOut,
587 cellStateOut,
588 output,
589 descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100590 paramsInfo,
591 reason);
telsoa014fcda012018-03-09 14:13:49 +0000592 break;
593 }
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000594 case LayerType::Maximum:
595 {
596 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
597 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
598 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
599
600 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
601 OverrideDataType(input1, dataType),
602 OverrideDataType(output, dataType),
603 reason);
604 break;
605 }
narpra01b89b05f2019-01-16 09:53:09 +0000606 case LayerType::MemCopy:
607 {
608 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
609 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000610
narpra01b89b05f2019-01-16 09:53:09 +0000611 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
612 OverrideDataType(output, dataType),
613 reason);
614 break;
615 }
Derek Lambertif674aa02019-08-01 15:56:25 +0100616 case LayerType::MemImport:
617 {
618 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
619 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
620
621 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
622 OverrideDataType(output, dataType),
623 reason);
624 break;
625 }
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100626 case LayerType::Merge:
627 {
628 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
629 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
630 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
631
632 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
633 OverrideDataType(input1, dataType),
634 OverrideDataType(output, dataType),
635 reason);
636 break;
637 }
Jim Flynne242f2d2019-05-22 14:24:13 +0100638 case LayerType::Concat:
telsoa014fcda012018-03-09 14:13:49 +0000639 {
Jim Flynne242f2d2019-05-22 14:24:13 +0100640 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000641
telsoa01c577f2c2018-08-31 09:22:23 +0100642 // Get vector of all inputs.
643 auto getTensorInfo = [&dataType](const InputSlot& slot)
telsoa014fcda012018-03-09 14:13:49 +0000644 {
telsoa01c577f2c2018-08-31 09:22:23 +0100645 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
telsoa014fcda012018-03-09 14:13:49 +0000646 };
telsoa01c577f2c2018-08-31 09:22:23 +0100647 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
648 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
649 std::vector<TensorInfo> inputs(beginI, endI);
telsoa014fcda012018-03-09 14:13:49 +0000650
telsoa01c577f2c2018-08-31 09:22:23 +0100651 auto getTensorInfoPtr = [](const TensorInfo& info)
652 {
653 return &info;
654 };
655 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
656 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
657 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
telsoa014fcda012018-03-09 14:13:49 +0000658
Nikhil Raj8599a412018-11-19 14:51:07 +0000659 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
660
Jim Flynne242f2d2019-05-22 14:24:13 +0100661 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
662
663
telsoa014fcda012018-03-09 14:13:49 +0000664 break;
665 }
666 case LayerType::Multiplication:
667 {
668 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
669 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100670 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100671 result = layerSupportObject->IsMultiplicationSupported(
telsoa01c577f2c2018-08-31 09:22:23 +0100672 OverrideDataType(input0, dataType),
673 OverrideDataType(input1, dataType),
674 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100675 reason);
telsoa014fcda012018-03-09 14:13:49 +0000676 break;
677 }
678 case LayerType::Normalization:
679 {
680 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
681 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
682 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100683 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
684 OverrideDataType(output, dataType),
685 cLayer->GetParameters(),
686 reason);
telsoa014fcda012018-03-09 14:13:49 +0000687 break;
688 }
689 case LayerType::Output:
690 {
691 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100692 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
telsoa014fcda012018-03-09 14:13:49 +0000693 break;
694 }
695 case LayerType::Permute:
696 {
697 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
698 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
699 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100700 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
701 OverrideDataType(output, dataType),
702 cLayer->GetParameters(),
703 reason);
telsoa014fcda012018-03-09 14:13:49 +0000704 break;
705 }
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100706 case LayerType::Pad:
707 {
708 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
709 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
710 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100711 result = layerSupportObject->IsPadSupported(
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100712 OverrideDataType(input, dataType),
713 OverrideDataType(output, dataType),
714 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +0100715 reason);
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100716 break;
717 }
telsoa014fcda012018-03-09 14:13:49 +0000718 case LayerType::Pooling2d:
719 {
720 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
721 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
722 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100723 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
724 OverrideDataType(output, dataType),
725 cLayer->GetParameters(),
726 reason);
telsoa014fcda012018-03-09 14:13:49 +0000727 break;
728 }
Matteo Martincigh49124022019-01-11 13:25:59 +0000729 case LayerType::PreCompiled:
730 {
731 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
732 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
733 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
734 cLayer->GetParameters(),
735 reason);
736 break;
737 }
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000738 case LayerType::Quantize:
739 {
740 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
741 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
742 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
743 break;
744 }
James Conroyee18dc82019-07-17 11:27:46 +0100745 case LayerType::QuantizedLstm:
746 {
747 auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
748
749 // Inputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100750 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
751 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
752 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100753
754 // Outputs
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100755 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
756 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100757
758 // QuantizedLstm parameters
James Conroyee18dc82019-07-17 11:27:46 +0100759 QuantizedLstmInputParamsInfo paramsInfo;
760
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100761 paramsInfo.m_InputToInputWeights =
762 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
763 paramsInfo.m_InputToForgetWeights =
764 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
765 paramsInfo.m_InputToCellWeights =
766 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
767 paramsInfo.m_InputToOutputWeights =
768 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100769
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100770 paramsInfo.m_RecurrentToInputWeights =
771 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
772 paramsInfo.m_RecurrentToForgetWeights =
773 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
774 paramsInfo.m_RecurrentToCellWeights =
775 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
776 paramsInfo.m_RecurrentToOutputWeights =
777 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
James Conroyee18dc82019-07-17 11:27:46 +0100778
Ferran Balaguer737d9ff2019-08-01 09:58:08 +0100779 paramsInfo.m_InputGateBias =
780 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
781 paramsInfo.m_ForgetGateBias =
782 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
783 paramsInfo.m_CellBias =
784 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
785 paramsInfo.m_OutputGateBias =
786 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
James Conroyee18dc82019-07-17 11:27:46 +0100787
788 result = layerSupportObject->IsQuantizedLstmSupported(input,
789 previousCellStateIn,
790 previousOutputIn,
791 cellStateOut,
792 output,
793 paramsInfo,
794 reason);
795 break;
796 }
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100797 case LayerType::Division:
798 {
799 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
800 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
801 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100802 result = layerSupportObject->IsDivisionSupported(
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100803 OverrideDataType(input0, dataType),
804 OverrideDataType(input1, dataType),
805 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100806 reason);
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100807 break;
808 }
telsoa014fcda012018-03-09 14:13:49 +0000809 case LayerType::Reshape:
810 {
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000811 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
telsoa014fcda012018-03-09 14:13:49 +0000812 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Kevin Maya023c402019-12-12 17:28:05 +0000813 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000814 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
Kevin Maya023c402019-12-12 17:28:05 +0000815 OverrideDataType(output, dataType),
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000816 cLayer->GetParameters(),
817 reason);
telsoa014fcda012018-03-09 14:13:49 +0000818 break;
819 }
Teresa Charlina9075df2019-06-27 15:41:57 +0100820 case LayerType::Resize:
821 {
822 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100823 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Teresa Charlina9075df2019-06-27 15:41:57 +0100824 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
825 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
826 OverrideDataType(output, dataType),
827 cLayer->GetParameters(),
828 reason);
829 break;
830 }
Aron Virginas-Tar636ab402019-09-16 14:27:45 +0100831 case LayerType::Slice:
832 {
833 auto cLayer = boost::polymorphic_downcast<const SliceLayer*>(&layer);
834
835 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
836 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
837
838 result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
839 OverrideDataType(output, dataType),
840 cLayer->GetParameters(),
841 reason);
842 break;
843 }
telsoa014fcda012018-03-09 14:13:49 +0000844 case LayerType::Softmax:
845 {
846 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
847 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +0100848 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100849 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
850 OverrideDataType(output, dataType),
851 cLayer->GetParameters(),
852 reason);
telsoa014fcda012018-03-09 14:13:49 +0000853 break;
854 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000855 case LayerType::SpaceToBatchNd:
856 {
857 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
858 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
859 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
860 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
861 OverrideDataType(output, dataType),
862 cLayer->GetParameters(),
863 reason);
864 break;
865 }
Aron Virginas-Tar972af152019-06-11 14:14:03 +0100866 case LayerType::SpaceToDepth:
867 {
868 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
869
870 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
871 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
872
873 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
874 OverrideDataType(output, dataType),
875 cLayer->GetParameters(),
876 reason);
877 break;
878 }
telsoa014fcda012018-03-09 14:13:49 +0000879 case LayerType::Splitter:
880 {
881 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
882 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100883
884 // Get vector of all outputs.
885 auto getTensorInfo = [&dataType](const OutputSlot& slot)
886 {
887 return OverrideDataType(slot.GetTensorInfo(), dataType);
888 };
889 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
890 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
891 std::vector<TensorInfo> outputs(beginI, endI);
892
893 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
894
David Beck33f0ae02018-10-18 15:13:56 +0100895 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100896 outputPtrs,
David Beck33f0ae02018-10-18 15:13:56 +0100897 cLayer->GetParameters(),
898 reason);
telsoa014fcda012018-03-09 14:13:49 +0000899 break;
900 }
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100901 case LayerType::Stack:
902 {
903 auto cLayer = boost::polymorphic_downcast<const StackLayer*>(&layer);
904
905 // Get vector of all inputs.
906 auto getTensorInfo = [&dataType](const InputSlot& slot)
907 {
908 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
909 };
910 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
911 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
912 std::vector<TensorInfo> inputs(beginI, endI);
913
914 auto getTensorInfoPtr = [](const TensorInfo& info)
915 {
916 return &info;
917 };
918 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
919 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
920 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
921
922 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
923
924 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
925
926 break;
927 }
Derek Lamberti013c3902019-10-21 10:46:16 +0100928 case LayerType::StandIn:
929 {
930 auto cLayer = boost::polymorphic_downcast<const StandInLayer*>(&layer);
931
932 // Get vector of all inputs.
933 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
934 {
935 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
936 };
937 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
938 {
939 return OverrideDataType(slot.GetTensorInfo(), dataType);
940 };
941 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfoIn);
942 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfoIn);
943 std::vector<TensorInfo> inputs(beginI, endI);
944
945 auto beginO = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
946 auto endO = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfoOut);
947 std::vector<TensorInfo> outputs(beginO, endO);
948
949
950 auto getTensorInfoPtr = [](const TensorInfo& info)
951 {
952 return &info;
953 };
954 auto beginPtrI = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
955 auto endPtrI = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
956 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
957
958 auto beginPtrO = boost::make_transform_iterator(outputs.begin(), getTensorInfoPtr);
959 auto endPtrO = boost::make_transform_iterator(outputs.end(), getTensorInfoPtr);
960 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
961
962
963 result = layerSupportObject->IsStandInSupported(inputPtrs,
964 outputPtrs,
965 cLayer->GetParameters(),
966 reason);
967 break;
968 }
Conor Kennedy430b5d82018-11-14 15:28:28 +0000969 case LayerType::StridedSlice:
970 {
971 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
972 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
973 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
974 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
975 OverrideDataType(output, dataType),
976 cLayer->GetParameters(),
977 reason);
978 break;
979 }
David Beckc2044fe2018-09-05 15:00:38 +0100980 case LayerType::Subtraction:
981 {
982 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
983 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
984 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +0100985 result = layerSupportObject->IsSubtractionSupported(
David Beckc2044fe2018-09-05 15:00:38 +0100986 OverrideDataType(input0, dataType),
987 OverrideDataType(input1, dataType),
988 OverrideDataType(output, dataType),
David Beck33f0ae02018-10-18 15:13:56 +0100989 reason);
David Beckc2044fe2018-09-05 15:00:38 +0100990 break;
991 }
Sadik Armaganeff363d2019-04-05 15:25:46 +0100992 case LayerType::Switch:
993 {
994 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
995 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
996 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
997 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
998 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
999 OverrideDataType(input1, dataType),
1000 OverrideDataType(output0, dataType),
1001 OverrideDataType(output1, dataType),
1002 reason);
1003 break;
1004 }
narpra0132b90462018-09-13 11:07:48 +01001005 case LayerType::Mean:
1006 {
1007 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
1008 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1009 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
David Beck33f0ae02018-10-18 15:13:56 +01001010 result = layerSupportObject->IsMeanSupported(
narpra0132b90462018-09-13 11:07:48 +01001011 OverrideDataType(input, dataType),
1012 OverrideDataType(output, dataType),
1013 cLayer->GetParameters(),
David Beck33f0ae02018-10-18 15:13:56 +01001014 reason);
narpra0132b90462018-09-13 11:07:48 +01001015 break;
1016 }
kevmay0190539692018-11-29 08:40:19 +00001017 case LayerType::Minimum:
1018 {
1019 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1020 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1021 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1022 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
1023 OverrideDataType(input1, dataType),
1024 OverrideDataType(output, dataType),
1025 reason);
1026 break;
1027 }
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001028 case LayerType::Prelu:
1029 {
1030 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1031 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1032 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1033 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
1034 OverrideDataType(alpha, dataType),
1035 OverrideDataType(output, dataType),
1036 reason);
1037 break;
1038 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001039 case LayerType::Transpose:
1040 {
1041 auto cLayer = boost::polymorphic_downcast<const TransposeLayer*>(&layer);
1042 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1043 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1044 result = layerSupportObject->IsTransposeSupported(OverrideDataType(input, dataType),
1045 OverrideDataType(output, dataType),
1046 cLayer->GetParameters(),
1047 reason);
1048 break;
1049 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001050 case LayerType::TransposeConvolution2d:
1051 {
1052 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
1053
1054 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1055 dataType);
1056 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1057
1058 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1059
1060 Optional<TensorInfo> biases;
1061 if (descriptor.m_BiasEnabled)
1062 {
1063 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
1064 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1065 GetBiasTypeFromWeightsType(dataType));
1066 }
1067
1068 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
1069 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1070
1071 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
1072 output,
1073 descriptor,
1074 weights,
1075 biases,
1076 reason);
1077
1078 break;
1079 }
telsoa014fcda012018-03-09 14:13:49 +00001080 default:
1081 {
1082 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
David Beck33f0ae02018-10-18 15:13:56 +01001083 reason.value() = "Unrecognised layer type";
telsoa014fcda012018-03-09 14:13:49 +00001084 result = false;
1085 break;
1086 }
1087 }
telsoa014fcda012018-03-09 14:13:49 +00001088 return result;
1089}
1090
David Beckdcb751f2018-10-03 11:42:42 +01001091bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
David Beck29c75de2018-10-23 13:35:58 +01001092 Optional<DataType> dataType,
telsoa01c577f2c2018-08-31 09:22:23 +01001093 std::string& outReasonIfUnsupported)
telsoa014fcda012018-03-09 14:13:49 +00001094{
David Beckdcb751f2018-10-03 11:42:42 +01001095 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
David Beck33f0ae02018-10-18 15:13:56 +01001096 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
telsoa014fcda012018-03-09 14:13:49 +00001097}
1098
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001099// Default Implementations
Derek Lamberti901ea112019-12-10 22:07:09 +00001100std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1101 const WorkloadInfo& /*info*/) const
Kevin May868eb142019-09-04 17:29:31 +01001102{
1103 return std::unique_ptr<IWorkload>();
1104}
1105
Derek Lamberti901ea112019-12-10 22:07:09 +00001106std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1107 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001108{
1109 return std::unique_ptr<IWorkload>();
1110}
1111
Derek Lamberti901ea112019-12-10 22:07:09 +00001112std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1113 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001114{
1115 return std::unique_ptr<IWorkload>();
1116}
1117
Derek Lamberti901ea112019-12-10 22:07:09 +00001118std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1119 const WorkloadInfo& /*info*/) const
Nikhil Rajee391d52019-09-05 17:50:44 +01001120{
1121 return std::unique_ptr<IWorkload>();
1122}
1123
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001124std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001125 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001126{
1127 return std::unique_ptr<IWorkload>();
1128}
1129
Derek Lamberti901ea112019-12-10 22:07:09 +00001130std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1131 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001132{
1133 return std::unique_ptr<IWorkload>();
1134}
1135
Derek Lamberti901ea112019-12-10 22:07:09 +00001136std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1137 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001138{
1139 return std::unique_ptr<IWorkload>();
1140}
1141
Derek Lamberti901ea112019-12-10 22:07:09 +00001142std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1143 const WorkloadInfo& /*info*/) const
Jim Flynn4ed6c832019-05-20 11:02:46 +01001144{
1145 return std::unique_ptr<IWorkload>();
1146}
1147
Derek Lamberti901ea112019-12-10 22:07:09 +00001148std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1149 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001150{
1151 return std::unique_ptr<IWorkload>();
1152}
1153
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00001154std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1155 const WorkloadInfo& /*info*/) const
1156{
1157 return std::unique_ptr<IWorkload>();
1158}
1159
Derek Lamberti901ea112019-12-10 22:07:09 +00001160std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1161 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001162{
1163 return std::unique_ptr<IWorkload>();
1164}
1165
Derek Lamberti901ea112019-12-10 22:07:09 +00001166std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1167 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001168{
1169 return std::unique_ptr<IWorkload>();
1170}
1171
Derek Lamberti901ea112019-12-10 22:07:09 +00001172std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1173 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001174{
1175 return std::unique_ptr<IWorkload>();
1176}
1177
Derek Lamberti901ea112019-12-10 22:07:09 +00001178std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1179 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001180{
1181 return std::unique_ptr<IWorkload>();
1182}
1183
Derek Lamberti901ea112019-12-10 22:07:09 +00001184std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1185 const WorkloadInfo& /*info*/) const
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01001186{
1187 return std::unique_ptr<IWorkload>();
1188}
1189
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001190std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001191 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001192{
1193 return std::unique_ptr<IWorkload>();
1194}
1195
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001196std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
Derek Lamberti901ea112019-12-10 22:07:09 +00001197 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00001198{
1199 return std::unique_ptr<IWorkload>();
1200}
1201
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001202std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
Derek Lamberti901ea112019-12-10 22:07:09 +00001203 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001204{
1205 return std::unique_ptr<IWorkload>();
1206}
1207
Derek Lamberti901ea112019-12-10 22:07:09 +00001208std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1209 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001210{
1211 return std::unique_ptr<IWorkload>();
1212}
1213
josh minor4a3c6102020-01-06 16:40:46 -06001214std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1215 const WorkloadInfo& /*info*/) const
1216{
1217 return std::unique_ptr<IWorkload>();
1218}
1219
Derek Lamberti901ea112019-12-10 22:07:09 +00001220std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1221 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001222{
1223 return std::unique_ptr<IWorkload>();
1224}
1225
Derek Lamberti901ea112019-12-10 22:07:09 +00001226std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1227 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001228{
1229 return std::unique_ptr<IWorkload>();
1230}
1231
Derek Lamberti901ea112019-12-10 22:07:09 +00001232std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1233 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001234{
1235 return std::unique_ptr<IWorkload>();
1236}
1237
Derek Lamberti901ea112019-12-10 22:07:09 +00001238std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1239 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001240{
1241 return std::unique_ptr<IWorkload>();
1242}
1243
Derek Lamberti901ea112019-12-10 22:07:09 +00001244std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1245 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001246{
1247 return std::unique_ptr<IWorkload>();
1248}
1249
Derek Lamberti901ea112019-12-10 22:07:09 +00001250std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1251 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001252{
1253 return std::unique_ptr<IWorkload>();
1254}
1255
Kevin Mayce5045a2019-10-02 14:07:47 +01001256std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
Derek Lamberti901ea112019-12-10 22:07:09 +00001257 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1258 const WorkloadInfo& /*info*/) const
Kevin Mayce5045a2019-10-02 14:07:47 +01001259{
1260 return std::unique_ptr<IWorkload>();
1261}
1262
Derek Lamberti901ea112019-12-10 22:07:09 +00001263std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1264 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001265{
1266 return std::unique_ptr<IWorkload>();
1267}
1268
Derek Lamberti901ea112019-12-10 22:07:09 +00001269std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1270 const WorkloadInfo& /*info*/) const
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001271{
1272 return std::unique_ptr<IWorkload>();
1273}
1274
Derek Lamberti901ea112019-12-10 22:07:09 +00001275std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1276 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001277{
1278 return std::unique_ptr<IWorkload>();
1279}
1280
Derek Lamberti901ea112019-12-10 22:07:09 +00001281std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1282 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001283{
1284 return std::unique_ptr<IWorkload>();
1285}
1286
Derek Lamberti901ea112019-12-10 22:07:09 +00001287std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1288 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001289{
1290 return std::unique_ptr<IWorkload>();
1291}
1292
Derek Lamberti901ea112019-12-10 22:07:09 +00001293std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1294 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001295{
1296 return std::unique_ptr<IWorkload>();
1297}
1298
Derek Lamberti901ea112019-12-10 22:07:09 +00001299std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1300 const WorkloadInfo& /*info*/) const
Derek Lambertif674aa02019-08-01 15:56:25 +01001301{
1302 return std::unique_ptr<IWorkload>();
1303}
1304
Derek Lamberti901ea112019-12-10 22:07:09 +00001305std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1306 const WorkloadInfo& /*info*/) const
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01001307{
1308 return std::unique_ptr<IWorkload>();
1309}
1310
Derek Lamberti901ea112019-12-10 22:07:09 +00001311std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1312 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001313{
1314 return std::unique_ptr<IWorkload>();
1315}
1316
Derek Lamberti901ea112019-12-10 22:07:09 +00001317std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1318 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001319{
1320 return std::unique_ptr<IWorkload>();
1321}
1322
Derek Lamberti901ea112019-12-10 22:07:09 +00001323std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1324 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001325{
1326 return std::unique_ptr<IWorkload>();
1327}
1328
Derek Lamberti901ea112019-12-10 22:07:09 +00001329std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1330 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001331{
1332 return std::unique_ptr<IWorkload>();
1333}
1334
Derek Lamberti901ea112019-12-10 22:07:09 +00001335std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1336 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001337{
1338 return std::unique_ptr<IWorkload>();
1339}
1340
Derek Lamberti901ea112019-12-10 22:07:09 +00001341std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1342 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001343{
1344 return std::unique_ptr<IWorkload>();
1345}
1346
Derek Lamberti901ea112019-12-10 22:07:09 +00001347std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001348 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001349{
1350 return std::unique_ptr<IWorkload>();
1351}
1352
Derek Lamberti901ea112019-12-10 22:07:09 +00001353std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1354 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001355{
1356 return std::unique_ptr<IWorkload>();
1357}
1358
Derek Lamberti901ea112019-12-10 22:07:09 +00001359std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1360 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001361{
1362 return std::unique_ptr<IWorkload>();
1363}
1364
Derek Lamberti901ea112019-12-10 22:07:09 +00001365std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1366 const WorkloadInfo &/*info*/) const
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01001367{
1368 return std::unique_ptr<IWorkload>();
1369}
1370
Derek Lamberti901ea112019-12-10 22:07:09 +00001371std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1372 const WorkloadInfo& /*Info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001373{
1374 return std::unique_ptr<IWorkload>();
1375}
1376
Derek Lamberti901ea112019-12-10 22:07:09 +00001377std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1378 const WorkloadInfo& /*info*/) const
James Conroyee18dc82019-07-17 11:27:46 +01001379{
1380 return std::unique_ptr<IWorkload>();
1381}
1382
Derek Lamberti901ea112019-12-10 22:07:09 +00001383std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1384 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001385{
1386 return std::unique_ptr<IWorkload>();
1387}
1388
Derek Lamberti901ea112019-12-10 22:07:09 +00001389std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1390 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001391{
1392 return std::unique_ptr<IWorkload>();
1393}
1394
Derek Lamberti901ea112019-12-10 22:07:09 +00001395std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1396 const WorkloadInfo& /*info*/) const
Teresa Charlina9075df2019-06-27 15:41:57 +01001397{
1398 return std::unique_ptr<IWorkload>();
1399}
1400
Derek Lamberti901ea112019-12-10 22:07:09 +00001401std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1402 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001403{
1404 return std::unique_ptr<IWorkload>();
1405}
1406
Derek Lamberti901ea112019-12-10 22:07:09 +00001407std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1408 const WorkloadInfo& /*info*/) const
1409{
1410 return std::unique_ptr<IWorkload>();
1411}
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001412
Derek Lamberti901ea112019-12-10 22:07:09 +00001413std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1414 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01001415{
1416 return std::unique_ptr<IWorkload>();
1417}
1418
Derek Lamberti901ea112019-12-10 22:07:09 +00001419std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1420 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001421{
1422 return std::unique_ptr<IWorkload>();
1423}
1424
Derek Lamberti901ea112019-12-10 22:07:09 +00001425std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1426 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001427{
1428 return std::unique_ptr<IWorkload>();
1429}
1430
Derek Lamberti901ea112019-12-10 22:07:09 +00001431std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1432 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001433{
1434 return std::unique_ptr<IWorkload>();
1435}
1436
Derek Lamberti901ea112019-12-10 22:07:09 +00001437std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1438 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar972af152019-06-11 14:14:03 +01001439{
1440 return std::unique_ptr<IWorkload>();
1441}
1442
Derek Lamberti901ea112019-12-10 22:07:09 +00001443std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1444 const WorkloadInfo& /*info*/) const
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001445{
1446 return std::unique_ptr<IWorkload>();
1447}
1448
Derek Lamberti901ea112019-12-10 22:07:09 +00001449std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1450 const WorkloadInfo& /*info*/) const
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001451{
1452 return std::unique_ptr<IWorkload>();
1453}
1454
Derek Lamberti901ea112019-12-10 22:07:09 +00001455std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1456 const WorkloadInfo& /*info*/) const
Sadik Armaganeff363d2019-04-05 15:25:46 +01001457{
1458 return std::unique_ptr<IWorkload>();
1459}
1460
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001461std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1462 const WorkloadInfo& /*info*/) const
1463{
1464 return std::unique_ptr<IWorkload>();
1465}
1466
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001467std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
Derek Lamberti901ea112019-12-10 22:07:09 +00001468 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1469 const WorkloadInfo& /*info*/) const
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001470{
1471 return std::unique_ptr<IWorkload>();
surmeh013537c2c2018-05-18 16:31:43 +01001472}
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01001473
} // namespace armnn